# Source code for pyhealth.models.tcn

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit

# VALID_OPERATION_LEVEL = ["visit", "event"]


# From the original TCN implementation: https://github.com/locuslab/TCN
class Chomp1d(nn.Module):
    """Trims trailing padding so the convolution stays causal.

    Conv1d pads both ends of the time axis; removing the extra
    ``chomp_size`` steps on the right restores the input length.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x: (batch, channels, time); drop the last chomp_size time steps
        return x[:, :, : -self.chomp_size].contiguous()
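
# A minimal shape sketch (illustrative, not part of the original module):
# with padding=(kernel_size - 1) * dilation, Conv1d emits that many extra
# steps on the right; Chomp1d trims them so the output length matches the
# input length and no position sees future time steps.
#
#     >>> conv = nn.Conv1d(4, 8, kernel_size=3, padding=2, dilation=1)
#     >>> x = torch.randn(2, 4, 10)  # (batch, channels, time)
#     >>> conv(x).shape
#     torch.Size([2, 8, 12])
#     >>> Chomp1d(2)(conv(x)).shape
#     torch.Size([2, 8, 10])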


class TemporalBlock(nn.Module):
    """Residual block of two weight-normalized, dilated causal convolutions.

    Each convolution is chomped back to the input length and followed by
    ReLU and dropout; a 1x1 convolution aligns the channel count of the
    residual connection when n_inputs != n_outputs.
    """

    def __init__(
        self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2
    ):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(
            nn.Conv1d(
                n_inputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            dim=None,
        )
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(
            nn.Conv1d(
                n_outputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            dim=None,
        )
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1,
            self.chomp1,
            self.relu1,
            self.dropout1,
            self.conv2,
            self.chomp2,
            self.relu2,
            self.dropout2,
        )
        self.downsample = (
            nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        )
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
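
# Illustrative usage (assumed shapes, not part of the original module): a
# TemporalBlock maps (batch, n_inputs, time) to (batch, n_outputs, time),
# projecting the residual branch with the 1x1 downsample convolution when
# the channel counts differ.
#
#     >>> block = TemporalBlock(
#     ...     n_inputs=4, n_outputs=8, kernel_size=2, stride=1,
#     ...     dilation=2, padding=(2 - 1) * 2,
#     ... )
#     >>> block(torch.randn(3, 4, 16)).shape
#     torch.Size([3, 8, 16])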


class TCNLayer(nn.Module):
    """Temporal Convolutional Networks layer.

    Shaojie Bai et al. An Empirical Evaluation of Generic Convolutional and
    Recurrent Networks for Sequence Modeling.

    This layer wraps the PyTorch TCN layer with masking and dropout support.
    It is used in the TCN model, but it can also be used as a standalone
    layer.

    Args:
        input_dim: input feature size.
        num_channels: int or list of ints. If int, the depth will be
            automatically decided by max_seq_length. If list, the number of
            channels in each layer.
        max_seq_length: max sequence length. Used to compute the depth of
            the TCN.
        kernel_size: kernel size of the TCN.
        dropout: dropout rate. If non-zero, introduces a Dropout layer after
            each convolution in each TCN block. Default is 0.5.

    Examples:
        >>> from pyhealth.models import TCNLayer
        >>> input = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
        >>> layer = TCNLayer(5, 64)
        >>> last_outputs, outputs = layer(input)
        >>> outputs.shape
        torch.Size([3, 128, 64])
        >>> last_outputs.shape
        torch.Size([3, 64])
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int = 128,
        max_seq_length: int = 20,
        kernel_size: int = 2,
        dropout: float = 0.5,
    ):
        super(TCNLayer, self).__init__()
        self.num_channels = num_channels
        layers = []

        # Automatically compute the depth from the desired sequence length.
        if isinstance(num_channels, int) and max_seq_length:
            num_channels = [num_channels] * int(
                np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size))
            )
        elif isinstance(num_channels, int) and not max_seq_length:
            raise ValueError(
                "a maximum sequence length needs to be provided if num_channels is int"
            )

        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2**i
            in_channels = input_dim if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers += [
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size,
                    dropout=dropout,
                )
            ]
        self.network = nn.Sequential(*layers)
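
    # Depth heuristic, sketched (follows from the formula above; each
    # TemporalBlock stacks two dilated convolutions, so an L-level TCN with
    # kernel size k has receptive field 1 + 2 * (k - 1) * (2**L - 1)):
    #
    #     >>> int(np.ceil(np.log(20 / 2) / np.log(2)))  # max_seq_length=20, k=2
    #     4
    #
    # i.e. 4 levels, whose receptive field 1 + 2 * (2**4 - 1) = 31 covers a
    # 20-step sequence.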
    def forward(
        self,
        x: torch.tensor,
        mask: Optional[torch.tensor] = None,
    ) -> Tuple[torch.tensor, torch.tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input size].
            mask: an optional tensor of shape [batch size, sequence len],
                where 1 indicates valid and 0 indicates invalid.

        Returns:
            last_out: a tensor of shape [batch size, hidden size], containing
                the output features for the last time step.
            out: a tensor of shape [batch size, sequence len, hidden size],
                containing the output features for each time step.
        """
        # Conv1d expects (batch, channels, time), so permute in and out.
        out = self.network(x.permute(0, 2, 1)).permute(0, 2, 1)
        last_out = get_last_visit(out, mask)
        return last_out, out
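
# Illustrative masked call (hypothetical tensors, not part of the original
# module): positions where the mask is False are treated as padding, and
# get_last_visit picks each sample's hidden state at its last valid step.
#
#     >>> layer = TCNLayer(input_dim=5, num_channels=64, max_seq_length=20)
#     >>> x = torch.randn(3, 20, 5)
#     >>> mask = torch.ones(3, 20, dtype=torch.bool)
#     >>> mask[0, 10:] = False  # first sample has only 10 valid steps
#     >>> last_out, out = layer(x, mask)
#     >>> last_out.shape, out.shape
#     (torch.Size([3, 64]), torch.Size([3, 20, 64]))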
class TCN(BaseModel):
    """Temporal Convolutional Networks model.

    This model applies a separate TCN layer for each feature, and then
    concatenates the final hidden states of each TCN layer. The concatenated
    hidden states are then fed into a fully connected layer to make
    predictions.

    Note:
        We use separate TCN layers for different feature_keys. Currently, we
        automatically support different input formats:

            - code based input (need to use the embedding table later)
            - float/int based value input

        We follow the current convention for the TCN model:

            - case 1. [code1, code2, code3, ...]
                - we will assume the codes follow the order; our model will
                  encode each code into a vector and apply TCN on the code
                  level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner brackets follow the order; our
                  model first uses the embedding table to encode each code
                  into a vector and then uses mean pooling to get one vector
                  per inner bracket; it then applies TCN on the bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the
                  same length; we assume each dimension has the same meaning;
                  we run TCN directly on the inner bracket level, similar to
                  case 1 after the embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the
                  same length; we assume each dimension has the same meaning;
                  we run TCN directly on the inner bracket level, similar to
                  case 2 after the embedding table

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features, e.g.
            ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: the embedding dimension. Default is 128.
        num_channels: the number of channels in the TCN layer. Default is 128.
        **kwargs: other parameters for the TCN layer.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ...         ],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import TCN
        >>> model = TCN(
        ...     dataset=dataset,
        ...     feature_keys=[
        ...         "list_codes",
        ...         "list_vectors",
        ...         "list_list_codes",
        ...         "list_list_vectors",
        ...     ],
        ...     label_key="label",
        ...     mode="binary",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(1.1641, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
            'y_prob': tensor([[0.6837], [0.3081]], grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[0.], [1.]]),
            'logit': tensor([[ 0.7706], [-0.8091]], grad_fn=<AddmmBackward0>)
        }
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        num_channels: int = 128,
        **kwargs,
    ):
        super(TCN, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.num_channels = num_channels

        if "input_dim" in kwargs:
            raise ValueError("input_dim is determined by embedding_dim")

        # the keys of self.feat_tokenizers only contain the code based inputs
        self.feat_tokenizers = {}
        self.label_tokenizer = self.get_label_tokenizer()
        # the keys of self.embeddings only contain the code based inputs
        self.embeddings = nn.ModuleDict()
        # the keys of self.linear_layers only contain the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "TCN only supports str code, float and int as input types"
                )
            elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]):
                raise ValueError(
                    "TCN only supports 2-dim or 3-dim str code as input types"
                )
            elif (input_info["type"] in [float, int]) and (
                input_info["dim"] not in [2, 3]
            ):
                raise ValueError(
                    "TCN only supports 2-dim or 3-dim float and int as input types"
                )
            # for code based input, we need the type
            # for float/int based input, we need the type and input_dim
            self.add_feature_transform_layer(feature_key, input_info)

        self.tcn = nn.ModuleDict()
        for feature_key in feature_keys:
            self.tcn[feature_key] = TCNLayer(
                input_dim=embedding_dim, num_channels=num_channels, **kwargs
            )
        output_size = self.get_output_size(self.label_tokenizer)
        self.fc = nn.Linear(len(self.feature_keys) * self.num_channels, output_size)
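
    # Sizing sketch (assuming the defaults above): each feature key gets its
    # own TCNLayer whose final hidden state has num_channels dims, so with
    # 4 feature keys and num_channels=128 the concatenated patient_emb has
    # 4 * 128 = 512 dims, which self.fc maps to output_size.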
    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: a tensor representing the raw model outputs.
        """
        patient_emb = []
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            dim_, type_ = input_info["dim"], input_info["type"]

            # for case 1: [code1, code2, code3, ...]
            if (dim_ == 2) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_2d(
                    kwargs[feature_key]
                )
                # (patient, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, event)
                mask = torch.any(x != 0, dim=2)

            # for case 2: [[code1, code2], [code3, ...], ...]
            elif (dim_ == 3) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_3d(
                    kwargs[feature_key]
                )
                # (patient, visit, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, visit, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                # (patient, visit)
                mask = torch.any(x != 0, dim=2)

            # for case 3: [[1.5, 2.0, 0.0], ...]
            elif (dim_ == 2) and (type_ in [float, int]):
                x, mask = self.padding2d(kwargs[feature_key])
                # (patient, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, event, embedding_dim)
                x = self.linear_layers[feature_key](x)
                # (patient, event)
                mask = mask.bool().to(self.device)

            # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
            elif (dim_ == 3) and (type_ in [float, int]):
                x, mask = self.padding3d(kwargs[feature_key])
                # (patient, visit, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, visit, values)
                x = torch.sum(x, dim=2)
                # (patient, visit, embedding_dim)
                x = self.linear_layers[feature_key](x)
                # (patient, visit)
                mask = mask[:, :, 0]
                mask = mask.bool().to(self.device)

            else:
                raise NotImplementedError

            # keep only the last valid hidden state of the feature-specific TCN
            x, _ = self.tcn[feature_key](x, mask)
            patient_emb.append(x)

        # (patient, features * hidden_dim)
        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
if __name__ == "__main__":
    from pyhealth.datasets import SampleEHRDataset

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            # "single_vector": [1, 2, 3],
            "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
            "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
            "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
            "list_list_vectors": [
                [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
                [[7.7, 8.5, 9.4]],
            ],
            "label": 1,
        },
        {
            "patient_id": "patient-0",
            "visit_id": "visit-1",
            # "single_vector": [1, 5, 8],
            "list_codes": [
                "55154191800",
                "551541928",
                "55154192800",
                "705182798",
                "70518279800",
            ],
            "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
            "list_list_codes": [["A04A", "B035", "C129"]],
            "list_list_vectors": [
                [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
            ],
            "label": 0,
        },
    ]

    # dataset
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")

    # data loader
    from pyhealth.datasets import get_dataloader

    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model
    model = TCN(
        dataset=dataset,
        feature_keys=[
            "list_codes",
            "list_vectors",
            "list_list_codes",
            "list_list_vectors",
        ],
        label_key="label",
        mode="binary",
    )

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()