from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel
# Names of the two operation levels ("visit" and "event").
# NOTE(review): nothing in the visible part of this file references this
# constant -- confirm whether external importers still rely on it.
VALID_OPERATION_LEVEL = ["visit", "event"]
class CNNBlock(nn.Module):
    """Residual 1-D convolution block.

    The main path is two kernel-size-3 convolutions, each followed by batch
    normalization, with a ReLU after the first one. The block input is added
    back onto the main-path output before a final ReLU. When the input and
    output channel counts differ, the input is first projected by a
    kernel-size-1 convolution plus batch normalization so the shapes match.
    This block is used in the CNN layer.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # Both 3-wide convolutions keep the default stride of 1 and use
        # padding=1, so the size of the last dimension is unchanged.
        first_conv = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv1 = nn.Sequential(
            first_conv,
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
        )

        second_conv = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Sequential(
            second_conv,
            nn.BatchNorm1d(out_channels),
        )

        # The shortcut only needs a projection when the channel count changes;
        # otherwise the input is added back unchanged.
        if in_channels == out_channels:
            self.downsample = None
        else:
            # kernel_size=1 with default stride=1 and padding=0
            self.downsample = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagation.

        Args:
            x: input tensor of shape [batch size, in_channels, *].

        Returns:
            output tensor of shape [batch size, out_channels, *].
        """
        main_path = self.conv2(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place add of the (possibly projected) input onto the main path.
        main_path += shortcut
        return self.relu(main_path)
class CNNLayer(nn.Module):
    """Convolutional neural network layer.

    This layer stacks multiple CNN blocks and applies adaptive average pooling
    at the end. It is used in the CNN model. But it can also be used as a
    standalone layer.

    Args:
        input_size: input feature size.
        hidden_size: hidden feature size.
        num_layers: number of stacked CNN blocks. Default is 1.

    Examples:
        >>> from pyhealth.models import CNNLayer
        >>> x = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
        >>> layer = CNNLayer(5, 64)
        >>> outputs, pooled_outputs = layer(x)
        >>> outputs.shape
        torch.Size([3, 128, 64])
        >>> pooled_outputs.shape
        torch.Size([3, 64])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
    ):
        super(CNNLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Blocks are registered under the keys "CNN-0", "CNN-1", ... in
        # insertion order; only the first block maps input_size -> hidden_size,
        # every later block maps hidden_size -> hidden_size.
        self.cnn = nn.ModuleDict()
        for i in range(num_layers):
            in_channels = input_size if i == 0 else hidden_size
            out_channels = hidden_size
            self.cnn[f"CNN-{i}"] = CNNBlock(in_channels, out_channels)

        # Collapses the sequence dimension to length 1 by averaging.
        self.pooling = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input size].

        Returns:
            outputs: a tensor of shape [batch size, sequence len, hidden size],
                containing the output features for each time step.
            pooled_outputs: a tensor of shape [batch size, hidden size], containing
                the pooled output features.
        """
        # [batch size, input size, sequence len] -- channels-first for the blocks
        x = x.permute(0, 2, 1)
        # ModuleDict iterates in insertion order, i.e. CNN-0, CNN-1, ...
        for block in self.cnn.values():
            x = block(x)
        # back to [batch size, sequence len, hidden size]
        outputs = x.permute(0, 2, 1)
        # pooling over the sequence dimension, then drop the length-1 axis
        pooled_outputs = self.pooling(x).squeeze(-1)
        return outputs, pooled_outputs
class CNN(BaseModel):
    """Convolutional neural network model.

    This model applies a separate CNN layer for each feature, and then concatenates
    the pooled outputs of each CNN layer. The concatenated vector is then fed
    into a fully connected layer to make predictions.

    Note:
        We use separate CNN layers for different feature_keys.
        Currently, we automatically support different input formats:
            - code based input (need to use the embedding table later)
            - float/int based value input
        We follow the current convention for the CNN model:
            - case 1. [code1, code2, code3, ...]
                - we will assume the code follows the order; our model will encode
                each code into a vector and apply CNN on the code level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner bracket follows the order; our model first
                uses the embedding table to encode each code into a vector and then
                uses sum pooling to get one vector for one inner bracket; then uses
                CNN on the bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the same length;
                we assume each dimension has the same meaning; we run CNN directly
                on the inner bracket level, similar to case 1 after embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the same length;
                we assume each dimension has the same meaning; we run CNN directly
                on the inner bracket level, similar to case 2 after embedding table

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features,
            e.g. ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension. Default is 128.
        **kwargs: other parameters for the CNN layer. ``input_size`` and
            ``hidden_size`` are rejected because they are fixed by
            ``embedding_dim`` and ``hidden_dim``.

    Raises:
        ValueError: if ``input_size`` or ``hidden_size`` is passed in
            ``**kwargs``, or if a feature has an unsupported type or dimension.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...         {
        ...             "patient_id": "patient-0",
        ...             "visit_id": "visit-0",
        ...             "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...             "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...             "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...             "list_list_vectors": [
        ...                 [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...                 [[7.7, 8.5, 9.4]],
        ...             ],
        ...             "label": 1,
        ...         },
        ...         {
        ...             "patient_id": "patient-0",
        ...             "visit_id": "visit-1",
        ...             "list_codes": [
        ...                 "55154191800",
        ...                 "551541928",
        ...                 "55154192800",
        ...                 "705182798",
        ...                 "70518279800",
        ...             ],
        ...             "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7]],
        ...             "list_list_codes": [["A04A", "B035", "C129"], ["A07B", "A07C"]],
        ...             "list_list_vectors": [
        ...                 [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6]],
        ...                 [[7.7, 8.4, 1.3]],
        ...             ],
        ...             "label": 0,
        ...         },
        ...     ]
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import CNN
        >>> model = CNN(
        ...         dataset=dataset,
        ...         feature_keys=[
        ...             "list_codes",
        ...             "list_vectors",
        ...             "list_list_codes",
        ...             "list_list_vectors",
        ...         ],
        ...         label_key="label",
        ...         mode="binary",
        ...     )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.8872, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
            'y_prob': tensor([[0.5008], [0.6614]], grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[1.], [0.]]),
            'logit': tensor([[0.0033], [0.6695]], grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        **kwargs,
    ):
        super(CNN, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # validate kwargs for CNN layer
        if "input_size" in kwargs:
            raise ValueError("input_size is determined by embedding_dim")
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")

        # the key of self.feat_tokenizers only contains the code based inputs
        self.feat_tokenizers = {}
        self.label_tokenizer = self.get_label_tokenizer()
        # the key of self.embeddings only contains the code based inputs
        self.embeddings = nn.ModuleDict()
        # the key of self.linear_layers only contains the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        # add feature CNN layers
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "CNN only supports str code, float and int as input types"
                )
            elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]):
                raise ValueError(
                    "CNN only supports 2-dim or 3-dim str code as input types"
                )
            elif (input_info["type"] in [float, int]) and (
                input_info["dim"] not in [2, 3]
            ):
                raise ValueError(
                    "CNN only supports 2-dim or 3-dim float and int as input types"
                )
            # for code based input, we need Type
            # for float/int based input, we need Type, input_dim
            # NOTE(review): add_feature_transform_layer is not defined in this
            # class; presumably it fills feat_tokenizers / embeddings /
            # linear_layers for this key -- confirm against BaseModel.
            self.add_feature_transform_layer(feature_key, input_info)

        # One CNN layer per feature. Iterates self.feature_keys (not the raw
        # constructor argument) so the keys match those looked up in forward().
        self.cnn = nn.ModuleDict()
        for feature_key in self.feature_keys:
            self.cnn[feature_key] = CNNLayer(
                input_size=embedding_dim, hidden_size=hidden_dim, **kwargs
            )

        output_size = self.get_output_size(self.label_tokenizer)
        # Each feature contributes one pooled vector of size hidden_dim.
        self.fc = nn.Linear(len(self.feature_keys) * self.hidden_dim, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key. If a truthy ``embed``
                entry is present, the concatenated patient embedding is also
                returned.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: a tensor holding the output of the final linear layer.
                embed: (only when ``embed`` is truthy in ``kwargs``) the
                    concatenated per-feature pooled CNN outputs.

        Raises:
            NotImplementedError: if a feature's (dim, type) combination is not
                one of the four supported cases.
        """
        patient_emb = []
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            dim_, type_ = input_info["dim"], input_info["type"]

            # for case 1: [code1, code2, code3, ...]
            if (dim_ == 2) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_2d(
                    kwargs[feature_key]
                )
                # (patient, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, event, embedding_dim)
                x = self.embeddings[feature_key](x)

            # for case 2: [[code1, code2], [code3, ...], ...]
            elif (dim_ == 3) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_3d(
                    kwargs[feature_key]
                )
                # (patient, visit, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, visit, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # sum over the event axis -> (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)

            # for case 3: [[1.5, 2.0, 0.0], ...]
            elif (dim_ == 2) and (type_ in [float, int]):
                # second return value of padding2d is not needed here
                x, _ = self.padding2d(kwargs[feature_key])
                # (patient, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, event, embedding_dim)
                x = self.linear_layers[feature_key](x)

            # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
            elif (dim_ == 3) and (type_ in [float, int]):
                # second return value of padding3d is not needed here
                x, _ = self.padding3d(kwargs[feature_key])
                # (patient, visit, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # sum over the event axis -> (patient, visit, values)
                x = torch.sum(x, dim=2)
                # (patient, visit, embedding_dim)
                x = self.linear_layers[feature_key](x)

            else:
                raise NotImplementedError

            # keep only the pooled output: (patient, hidden_dim)
            _, x = self.cnn[feature_key](x)
            patient_emb.append(x)

        # (patient, num_features * hidden_dim)
        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
if __name__ == "__main__":
    # Smoke test: build a two-sample toy dataset covering all four supported
    # input formats, run one forward pass and back-propagate the loss.
    from pyhealth.datasets import SampleEHRDataset, get_dataloader

    first_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        "list_list_vectors": [
            [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
            [[7.7, 8.5, 9.4]],
        ],
        "label": 1,
    }
    second_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-1",
        "list_codes": [
            "55154191800",
            "551541928",
            "55154192800",
            "705182798",
            "70518279800",
        ],
        "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        "list_list_codes": [["A04A", "B035", "C129"]],
        "list_list_vectors": [
            [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ],
        "label": 0,
    }
    samples = [first_visit, second_visit]

    # dataset
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")
    print(dataset.input_info)

    # data loader: a single batch holds both samples
    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model using every feature of the toy samples
    demo_feature_keys = [
        "list_codes",
        "list_vectors",
        "list_list_codes",
        "list_list_vectors",
    ]
    model = CNN(
        dataset=dataset,
        feature_keys=demo_feature_keys,
        label_key="label",
        mode="binary",
    )

    # run the model on one batch
    data_batch = next(iter(train_loader))
    ret = model(**data_batch)
    print(ret)

    # check that gradients can flow back from the loss
    ret["loss"].backward()