pyhealth.models.CNN#

The separate callable CNNLayer and the complete CNN model.

class pyhealth.models.CNNLayer(input_size, hidden_size, num_layers=1)[source]#

Bases: Module

Convolutional neural network layer.

This layer stacks multiple CNN blocks and applies adaptive average pooling at the end. It is used in the CNN model but can also be used as a standalone layer.
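To make the "stacked blocks plus adaptive average pooling" idea concrete, here is a minimal PyTorch sketch. It is not PyHealth's implementation: the class name TinyCNNLayer, the kernel size, and the use of a plain Conv1d + ReLU as a "block" are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class TinyCNNLayer(nn.Module):
    """Hypothetical stand-in for CNNLayer; NOT PyHealth's actual code."""

    def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1):
        super().__init__()
        blocks = []
        in_channels = input_size
        for _ in range(num_layers):
            # Assumption: kernel_size=3 with padding=1 keeps the sequence length.
            blocks.append(nn.Conv1d(in_channels, hidden_size, kernel_size=3, padding=1))
            blocks.append(nn.ReLU())
            in_channels = hidden_size
        self.blocks = nn.Sequential(*blocks)
        # Output size 1 averages over the whole sequence dimension.
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor):
        # Conv1d expects [batch, channels, length], so move features to dim 1.
        h = self.blocks(x.permute(0, 2, 1))
        outputs = h.permute(0, 2, 1)        # [batch size, sequence len, hidden size]
        pooled = self.pool(h).squeeze(-1)   # [batch size, hidden size]
        return outputs, pooled


x = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
outputs, pooled = TinyCNNLayer(5, 64, num_layers=2)(x)
print(outputs.shape, pooled.shape)
```

In this sketch, pooling to a single position is the same as taking the mean of the per-step outputs over the sequence dimension, which is why the pooled tensor drops the sequence axis.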

Parameters:
  • input_size (int) – input feature size.

  • hidden_size (int) – hidden feature size.

  • num_layers (int) – number of convolutional layers. Default is 1.

Examples

>>> import torch
>>> from pyhealth.models import CNNLayer
>>> input = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
>>> layer = CNNLayer(5, 64)
>>> outputs, last_outputs = layer(input)
>>> outputs.shape
torch.Size([3, 128, 64])
>>> last_outputs.shape
torch.Size([3, 64])
forward(x)[source]#

Forward propagation.

Parameters:

x (tensor) – a tensor of shape [batch size, sequence len, input size].

Returns:

outputs: a tensor of shape [batch size, sequence len, hidden size], containing the output features for each time step.

pooled_outputs: a tensor of shape [batch size, hidden size], containing the pooled output features.

Return type:

tuple of two tensors (outputs, pooled_outputs)

training: bool#
class pyhealth.models.CNN(dataset, feature_keys, label_key, mode, embedding_dim=128, hidden_dim=128, **kwargs)[source]#

Bases: BaseModel

Convolutional neural network model.

This model applies a separate CNN layer for each feature, and then concatenates the final hidden states of each CNN layer. The concatenated hidden states are then fed into a fully connected layer to make predictions.
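The per-feature design described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the library's code: a single Conv1d stands in for each feature's CNN layer, the inputs are assumed to be already embedded, and the sizes in feature_dims are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical per-feature input sizes (after embedding); not from PyHealth.
feature_dims = {"conditions": 8, "procedures": 4}
hidden_dim = 16
output_size = 1  # e.g. a single logit for binary prediction

# One separate encoder per feature key (a single Conv1d stands in for a CNN layer).
encoders = nn.ModuleDict({
    key: nn.Conv1d(dim, hidden_dim, kernel_size=3, padding=1)
    for key, dim in feature_dims.items()
})
pool = nn.AdaptiveAvgPool1d(1)
# The fully connected layer sees the concatenation of all pooled features.
fc = nn.Linear(len(feature_dims) * hidden_dim, output_size)

# Fake batch: 2 patients, 10 time steps per feature.
batch = {key: torch.randn(2, 10, dim) for key, dim in feature_dims.items()}

pooled = []
for key in feature_dims:
    h = encoders[key](batch[key].permute(0, 2, 1))  # [batch, hidden_dim, seq len]
    pooled.append(pool(h).squeeze(-1))              # [batch, hidden_dim]

patient_emb = torch.cat(pooled, dim=1)  # [batch, num_features * hidden_dim]
logits = fc(patient_emb)                # [batch, output_size]
print(patient_emb.shape, logits.shape)
```

Because each feature has its own encoder, features with different input sizes or sequence lengths can be combined; only the pooled vectors need a common hidden size.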

Note

We use separate CNN layers for different feature_keys. Currently, we automatically support the following input formats:

  • code-based input (codes are mapped to vectors through an embedding table first)

  • float/int-based value input

We follow the current convention for the CNN model:
  • case 1. [code1, code2, code3, …]
    • we assume the codes are ordered; the model encodes each code into a vector and applies the CNN at the code level.

  • case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], …]
    • we assume the inner brackets are ordered; the model first uses the embedding table to encode each code into a vector, then applies mean pooling to get one vector per inner bracket, and then applies the CNN at the bracket level.

  • case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], …]
    • this case only makes sense when every inner bracket has the same length; we assume each dimension has the same meaning; the CNN runs directly at the inner bracket level, similar to case 1 after the embedding table.

  • case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], …]
    • this case only makes sense when every inner bracket has the same length; we assume each dimension has the same meaning; the CNN runs directly at the inner bracket level, similar to case 2 after the embedding table.
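One way to tell the four cases apart is by nesting depth and element type. The helper below is a hypothetical sketch of that idea using only the standard library; the function names are invented here, it assumes codes are Python strings and values are numbers, and it is not how PyHealth itself inspects the dataset.

```python
def describe_input(value):
    """Return (nesting depth, element kind) for one nested-list sample field."""
    depth = 0
    while isinstance(value, list):
        if not value:
            raise ValueError("empty list: cannot infer the element type")
        depth += 1
        value = value[0]  # follow the first element down to a leaf
    # Assumption: codes are strings, everything else is a numeric value.
    kind = "code" if isinstance(value, str) else "numeric"
    return depth, kind


# Mapping from (depth, kind) to the cases listed above.
CASES = {
    (1, "code"): "case 1",
    (2, "code"): "case 2",
    (2, "numeric"): "case 3",
    (3, "numeric"): "case 4",
}


def which_case(value):
    """Name the matching case, or 'unsupported' if none applies."""
    return CASES.get(describe_input(value), "unsupported")


print(which_case(["505800458", "50580045810"]))                 # case 1
print(which_case([["A05B", "A05C", "A06A"], ["A11D", "A11E"]]))  # case 2
print(which_case([[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]]))           # case 3
print(which_case([[[1.8, 2.25, 3.41]], [[7.7, 8.5, 9.4]]]))      # case 4
```

The sketch only looks at the first element at each level, so it assumes every sample field is homogeneous.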

Parameters:
  • dataset (SampleEHRDataset) – the dataset to train the model. It is used to query certain information such as the set of all tokens.

  • feature_keys (List[str]) – list of keys in samples to use as features, e.g. [“conditions”, “procedures”].

  • label_key (str) – key in samples to use as label (e.g., “drugs”).

  • mode (str) – one of “binary”, “multiclass”, or “multilabel”.

  • embedding_dim (int) – the embedding dimension. Default is 128.

  • hidden_dim (int) – the hidden dimension. Default is 128.

  • **kwargs – other parameters for the CNN layer.

Examples

>>> from pyhealth.datasets import SampleEHRDataset
>>> samples = [
...         {
...             "patient_id": "patient-0",
...             "visit_id": "visit-0",
...             "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
...             "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
...             "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
...             "list_list_vectors": [
...                 [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
...                 [[7.7, 8.5, 9.4]],
...             ],
...             "label": 1,
...         },
...         {
...             "patient_id": "patient-0",
...             "visit_id": "visit-1",
...             "list_codes": [
...                 "55154191800",
...                 "551541928",
...                 "55154192800",
...                 "705182798",
...                 "70518279800",
...             ],
...             "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7]],
...             "list_list_codes": [["A04A", "B035", "C129"], ["A07B", "A07C"]],
...             "list_list_vectors": [
...                 [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6]],
...                 [[7.7, 8.4, 1.3]],
...             ],
...             "label": 0,
...         },
...     ]
>>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
>>>
>>> from pyhealth.models import CNN
>>> model = CNN(
...         dataset=dataset,
...         feature_keys=[
...             "list_codes",
...             "list_vectors",
...             "list_list_codes",
...             "list_list_vectors",
...         ],
...         label_key="label",
...         mode="binary",
...     )
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
    'loss': tensor(0.8872, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
    'y_prob': tensor([[0.5008], [0.6614]], grad_fn=<SigmoidBackward0>),
    'y_true': tensor([[1.], [0.]]),
    'logit': tensor([[0.0033], [0.6695]], grad_fn=<AddmmBackward0>)
}
>>>
forward(**kwargs)[source]#

Forward propagation.

The label kwargs[self.label_key] is a list of labels for each patient.

Parameters:

**kwargs – keyword arguments for the model. The keys must contain all the feature keys and the label key.

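Returns:

A dictionary with the following keys:

  • loss: a scalar tensor representing the loss.

  • y_prob: a tensor representing the predicted probabilities.

  • y_true: a tensor representing the true labels.

The example output above also contains a logit key, which holds the raw scores before the sigmoid in binary mode.

Return type:

dict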
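For binary mode, the values printed in the example output can be related to each other by hand. The sketch below uses only the standard library and takes the logits and labels from that printed output; it assumes the loss is the batch mean of binary cross-entropy on sigmoid probabilities, which is what the grad_fn names in the output suggest.

```python
import math

# Values copied from the example output (logits are rounded to 4 decimals there).
logits = [0.0033, 0.6695]
y_true = [1.0, 0.0]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


# y_prob is the sigmoid of each logit.
y_prob = [sigmoid(z) for z in logits]

# Binary cross-entropy per sample, then averaged over the batch.
per_sample = [
    -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    for y, p in zip(y_true, y_prob)
]
loss = sum(per_sample) / len(per_sample)

print([round(p, 4) for p in y_prob])  # approximately [0.5008, 0.6614]
print(round(loss, 4))                 # approximately 0.8872
```

The recomputed numbers agree with the printed y_prob and loss to about four decimal places, which is as close as the rounded logits allow.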

training: bool#