pyhealth.models.StageNet

The standalone, callable StageNetLayer and the complete StageNet model.

class pyhealth.models.StageNetLayer(input_dim, chunk_size=128, conv_size=10, levels=3, dropconnect=0.3, dropout=0.3, dropres=0.3)

Bases: Module

StageNet layer.

Paper: StageNet: Stage-aware neural networks for health risk prediction. WWW 2020.

This layer is used in the StageNet model, but it can also be used as a standalone layer.

Parameters:
  • input_dim (int) – the size of the dynamic (time-varying) input features, i.e., the last dimension of x.

  • chunk_size (int) – the chunk size for the StageNet layer. Default is 128.

  • levels (int) – the number of levels for the StageNet layer. levels * chunk_size = hidden_dim in the RNN. Smaller chunk size and more levels can capture more detailed patient status variations. Default is 3.

  • conv_size (int) – the size of the convolutional kernel. Default is 10.

  • dropconnect (float) – the DropConnect rate applied to the recurrent weights. Default is 0.3.

  • dropout (float) – the standard dropout rate. Default is 0.3.

  • dropres (float) – the dropout rate for the residual connection. Default is 0.3.
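The relation levels * chunk_size = hidden_dim can be illustrated with a minimal NumPy sketch. It assumes, for illustration only, that the hidden state is viewed as `levels` consecutive chunks of `chunk_size` units each; the helper `split_into_chunks` is hypothetical and not part of PyHealth, and NumPy stands in for torch to keep the sketch dependency-light.

```python
import numpy as np


def split_into_chunks(hidden: np.ndarray, chunk_size: int, levels: int) -> np.ndarray:
    """View a [batch, hidden_dim] array as [batch, levels, chunk_size].

    Hypothetical helper: requires hidden_dim == chunk_size * levels.
    """
    batch, hidden_dim = hidden.shape
    if hidden_dim != chunk_size * levels:
        raise ValueError(
            f"hidden_dim={hidden_dim} != chunk_size*levels={chunk_size * levels}"
        )
    return hidden.reshape(batch, levels, chunk_size)


# With the defaults (chunk_size=128, levels=3) the hidden size is 384,
# the same size as the second dimension of the output in the example below.
chunk_size, levels = 128, 3
hidden = np.arange(3 * chunk_size * levels, dtype=float).reshape(3, chunk_size * levels)
chunks = split_into_chunks(hidden, chunk_size, levels)
print(hidden.shape, chunks.shape)  # (3, 384) (3, 3, 128)
```

Smaller chunks with more levels give a finer-grained split of the same hidden vector, which is the trade-off the levels parameter describes.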

Examples

>>> import torch
>>> from pyhealth.models import StageNetLayer
>>> x = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
>>> layer = StageNetLayer(64)
>>> c, _, _ = layer(x)
>>> c.shape
torch.Size([3, 384])
cumax(x, mode='l2r')

step(inputs, c_last, h_last, interval, device)
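cumax and step are listed here without descriptions. The StageNet paper builds on the cumulative softmax ("cumax") activation from ON-LSTM, a softmax followed by a cumulative sum, which yields monotonic gate values that end (or start) at 1. The NumPy sketch below illustrates that operation only; reading mode='r2l' as "flip, accumulate, flip back" is an assumption about what the option means, not a copy of the library's source.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max before exponentiating for numerical stability.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def cumax(x: np.ndarray, mode: str = "l2r") -> np.ndarray:
    """Cumulative softmax along the last axis (sketch).

    'l2r' accumulates left to right; 'r2l' is assumed to accumulate
    right to left (flip, cumulative softmax, flip back).
    """
    if mode == "l2r":
        return np.cumsum(softmax(x), axis=-1)
    if mode == "r2l":
        flipped = np.flip(x, axis=-1)
        return np.flip(np.cumsum(softmax(flipped), axis=-1), axis=-1)
    raise ValueError(f"unknown mode: {mode!r}")


x = np.array([[0.5, -1.0, 2.0]])
l2r = cumax(x, "l2r")  # non-decreasing, reaches 1.0 at the last position
r2l = cumax(x, "r2l")  # non-increasing, equals 1.0 at the first position
print(l2r.round(3), r2l.round(3))
```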
forward(x, time=None, mask=None)

Forward propagation.

Parameters:
  • x (tensor) – a tensor of shape [batch size, sequence len, input_dim].

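  • time (Optional[tensor]) – an optional tensor of time intervals of shape [batch size, sequence len]. If None, every time step is assigned the same interval.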

  • mask (Optional[tensor]) – an optional tensor of shape [batch size, sequence len], where 1 indicates valid and 0 indicates invalid.

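Returns:

last_output: a tensor of shape [batch size, chunk_size*levels] representing the patient embedding.

outputs: a tensor of shape [batch size, sequence len, chunk_size*levels] representing the patient at each time step.

distance: the stage variation computed at each time step (the third value unpacked in the example above).

Return type:

Tuple[tensor]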
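The mask convention (1 = valid, 0 = invalid or padded) and the two tensor outputs can be related with a small NumPy sketch. It assumes, as an illustration rather than a description of the library's internals, that a per-patient embedding is taken from each patient's last valid time step; last_valid_output is a hypothetical helper.

```python
import numpy as np


def last_valid_output(outputs: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Return outputs[b, t_last(b)], where t_last is the last step with mask == 1.

    outputs: [batch, seq_len, hidden_dim]; mask: [batch, seq_len] of 0/1.
    Assumes every sequence has at least one valid step.
    """
    mask = np.asarray(mask, dtype=bool)
    seq_len = mask.shape[1]
    # Find the last True per row: reverse the row, take argmax, map back.
    last_idx = seq_len - 1 - np.argmax(mask[:, ::-1], axis=1)
    return outputs[np.arange(outputs.shape[0]), last_idx]


batch, seq_len, hidden_dim = 2, 4, 384  # 384 = chunk_size * levels with the defaults
outputs = np.arange(batch * seq_len * hidden_dim, dtype=float).reshape(batch, seq_len, hidden_dim)
mask = np.array([[1, 1, 1, 0],   # patient 0: 3 valid steps
                 [1, 1, 0, 0]])  # patient 1: 2 valid steps
emb = last_valid_output(outputs, mask)
print(emb.shape)  # (2, 384)
```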

class pyhealth.models.StageNet(dataset, feature_keys, label_key, mode, time_keys=None, embedding_dim=128, chunk_size=128, levels=3, **kwargs)

Bases: BaseModel

StageNet model.

Paper: Junyi Gao et al. StageNet: Stage-aware neural networks for health risk prediction. WWW 2020.

Note

We use separate StageNet layers for different feature_keys. Currently, we automatically support different input formats:

  • code-based input (encoded with an embedding table)

  • float/int-based value input

We follow the current convention for the StageNet model:
  • case 1. [code1, code2, code3, …]
    • we assume the codes follow the order; our model encodes each code into a vector and applies StageNet at the code level

  • case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], …]
    • we assume the inner brackets follow the order; our model first uses the embedding table to encode each code into a vector, then uses average (mean) pooling to get one vector per inner bracket, and then applies StageNet at the bracket level

  • case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], …]
    • this case only makes sense when each inner bracket has the same length; we assume each dimension has the same meaning; we run StageNet directly at the inner bracket level, similar to case 1 after the embedding table

  • case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], …]
    • this case only makes sense when each inner bracket has the same length; we assume each dimension has the same meaning; we run StageNet directly at the inner bracket level, similar to case 2 after the embedding table
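The case 2 encoding described above (embed each code, then pool the codes inside one inner bracket into a single vector) can be sketched in NumPy. The vocabulary, the random embedding table, and the helper encode_visits are hypothetical; the sketch uses mean pooling as the note describes, and the library's own pooling and padding details may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical vocabulary built from the codes seen in the data.
vocab = {"A05B": 0, "A05C": 1, "A06A": 2, "A11D": 3, "A11E": 4}
embedding_dim = 8
embedding_table = rng.normal(size=(len(vocab), embedding_dim))


def encode_visits(visits):
    """Case 2: [[code, ...], [code, ...], ...] -> [num_visits, embedding_dim].

    Each code is looked up in the embedding table, and the vectors within
    one inner bracket (visit) are mean-pooled into a single vector.
    """
    pooled = []
    for codes in visits:
        vectors = embedding_table[[vocab[c] for c in codes]]  # [n_codes, dim]
        pooled.append(vectors.mean(axis=0))
    return np.stack(pooled)


visits = [["A05B", "A05C", "A06A"], ["A11D", "A11E"]]
sequence = encode_visits(visits)
print(sequence.shape)  # (2, 8): one vector per visit, in visit order
```

The resulting [num_visits, embedding_dim] sequence is what a recurrent layer such as StageNet would then consume step by step.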

The time interval information specified by time_keys is used to calculate the memory decay between consecutive visits. If time_keys is None, all visits are treated as having the same time interval. For each feature, the time intervals should be a two-dimensional float array with shape (time_step, 1).
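One way to build a time-interval field of shape (time_step, 1) is from per-visit timestamps. The helper below is hypothetical preprocessing, not part of PyHealth; it assumes the first visit gets 0.0 (consistent with every *_time field in the example below starting at 0.0) and that each later entry is the gap since the previous visit, in whatever unit you choose.

```python
from typing import List


def to_interval_array(timestamps: List[float]) -> List[List[float]]:
    """Convert sorted visit timestamps into a (time_step, 1) nested list of gaps.

    Hypothetical helper: the first visit gets 0.0 and each later visit gets
    the elapsed time since the previous one.
    """
    if not timestamps:
        return []
    if any(b < a for a, b in zip(timestamps, timestamps[1:])):
        raise ValueError("timestamps must be sorted in ascending order")
    intervals = [0.0] + [b - a for a, b in zip(timestamps, timestamps[1:])]
    return [[float(gap)] for gap in intervals]


# Visits on days 0, 2 and 5 -> gaps of 0, 2 and 3 days.
print(to_interval_array([0.0, 2.0, 5.0]))  # [[0.0], [2.0], [3.0]]
```

The number of rows must match the number of time steps in the corresponding feature, as it does for each feature/time pair in the example below.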

Parameters:
  • dataset (SampleEHRDataset) – the dataset to train the model. It is used to query certain information such as the set of all tokens.

  • feature_keys (List[str]) – list of keys in samples to use as features, e.g. [“conditions”, “procedures”].

  • label_key (str) – key in samples to use as label (e.g., “drugs”).

  • mode (str) – one of “binary”, “multiclass”, or “multilabel”.

  • time_keys (Optional[List[str]]) – list of keys in samples to use as time interval information for each feature. Default is None, in which case all visits are treated as having the same time interval.

  • embedding_dim (int) – the embedding dimension. Default is 128.

  • chunk_size (int) – the chunk size for the StageNet layer. Default is 128.

  • levels (int) – the number of levels for the StageNet layer. levels * chunk_size = hidden_dim in the RNN. Smaller chunk size and more levels can capture more detailed patient status variations. Default is 3.

  • **kwargs – other parameters passed to the StageNet layer (e.g., conv_size, dropconnect, dropout, dropres).

Examples

>>> from pyhealth.datasets import SampleEHRDataset
>>> from pyhealth.models import StageNet
>>> samples = [
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-0",
...         # "single_vector": [1, 2, 3],
...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
...         "list_list_vectors": [
...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
...             [[7.7, 8.5, 9.4]],
...         ],
...         "label": 1,
...         "list_vectors_time": [[0.0], [1.3]],
...         "list_codes_time": [[0.0], [2.0], [1.3]],
...         "list_list_codes_time": [[0.0], [1.5]],
...     },
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-1",
...         # "single_vector": [1, 5, 8],
...         "list_codes": [
...             "55154191800",
...             "551541928",
...             "55154192800",
...             "705182798",
...             "70518279800",
...         ],
...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
...         "list_list_codes": [["A04A", "B035", "C129"]],
...         "list_list_vectors": [
...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
...         ],
...         "label": 0,
...         "list_vectors_time": [[0.0], [2.0], [1.0]],
...         "list_codes_time": [[0.0], [2.0], [1.3], [1.0], [2.0]],
...         "list_list_codes_time": [[0.0]],
...     },
... ]
>>>
>>> # dataset
>>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
>>>
>>> # data loader
>>> from pyhealth.datasets import get_dataloader
>>>
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>>
>>> # model
>>> model = StageNet(
...     dataset=dataset,
...     feature_keys=[
...         "list_codes",
...         "list_vectors",
...         "list_list_codes",
...         # "list_list_vectors",
...     ],
...     time_keys=["list_codes_time", "list_vectors_time", "list_list_codes_time"],
...     label_key="label",
...     mode="binary",
... )
>>>
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
    'loss': tensor(0.7111, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
    'y_prob': tensor([[0.4815],
                [0.4991]], grad_fn=<SigmoidBackward0>),
    'y_true': tensor([[1.],
                [0.]]),
    'logit': tensor([[-0.0742],
                [-0.0038]], grad_fn=<AddmmBackward0>)
}
>>>
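In binary mode the printed values are consistent with y_prob = sigmoid(logit) and with loss being the mean binary cross-entropy against y_true. The pure-Python check below recomputes them from the rounded numbers printed above, so it compares with a small tolerance. The numbers come from one random initialization and will differ between runs.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def binary_cross_entropy(p: float, y: float) -> float:
    # Per-sample BCE for a predicted probability p and a 0/1 label y.
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


# Rounded values copied from the printed output above.
logits = [-0.0742, -0.0038]
y_true = [1.0, 0.0]

y_prob = [sigmoid(z) for z in logits]
loss = sum(binary_cross_entropy(p, y) for p, y in zip(y_prob, y_true)) / len(y_true)
print([round(p, 4) for p in y_prob], round(loss, 4))  # [0.4815, 0.4991] 0.7111
```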
forward(**kwargs)

Forward propagation.

The label kwargs[self.label_key] is a list of labels for each patient.

Parameters:

**kwargs – keyword arguments for the model. The keys must contain all the feature keys and the label key.

Returns:

A dictionary with the following keys:
  • loss: a scalar tensor representing the final loss.

  • distance: list of tensors representing the stage variation of the patient.

  • y_prob: a tensor representing the predicted probabilities.

  • y_true: a tensor representing the true labels.

  • logit: a tensor of raw scores before the final activation (shown as 'logit' in the example output above).

Return type:

dict
