Source code for pyhealth.models.rnn

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel

# VALID_OPERATION_LEVEL = ["visit", "event"]


class RNNLayer(nn.Module):
    """Recurrent neural network layer.

    This layer wraps the PyTorch RNN layer with masking and dropout support.
    It is used in the RNN model, but can also be used as a standalone layer.

    Args:
        input_size: input feature size.
        hidden_size: hidden feature size.
        rnn_type: type of rnn, one of "RNN", "LSTM", "GRU". Default is "GRU".
        num_layers: number of recurrent layers. Default is 1.
        dropout: dropout rate. If non-zero, introduces a Dropout layer before
            each RNN layer. Default is 0.5.
        bidirectional: whether to use bidirectional recurrent layers. If True,
            a fully-connected layer is applied to the concatenation of the
            forward and backward hidden states to reduce the dimension to
            hidden_size. Default is False.

    Examples:
        >>> import torch
        >>> from pyhealth.models import RNNLayer
        >>> input = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
        >>> layer = RNNLayer(5, 64)
        >>> outputs, last_outputs = layer(input)
        >>> outputs.shape
        torch.Size([3, 128, 64])
        >>> last_outputs.shape
        torch.Size([3, 64])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        rnn_type: str = "GRU",
        num_layers: int = 1,
        dropout: float = 0.5,
        bidirectional: bool = False,
    ):
        super(RNNLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        self.dropout_layer = nn.Dropout(dropout)
        self.num_directions = 2 if bidirectional else 1
        rnn_module = getattr(nn, rnn_type)
        self.rnn = rnn_module(
            input_size,
            hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True,
        )
        if bidirectional:
            self.down_projection = nn.Linear(hidden_size * 2, hidden_size)
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input size].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid.

        Returns:
            outputs: a tensor of shape [batch size, sequence len, hidden size],
                containing the output features for each time step.
            last_outputs: a tensor of shape [batch size, hidden size], containing
                the output features for the last time step.
        """
        # PyTorch's RNN only applies dropout between stacked layers, so apply
        # dropout to the input manually here.
        x = self.dropout_layer(x)
        batch_size = x.size(0)
        if mask is None:
            lengths = torch.full(
                size=(batch_size,), fill_value=x.size(1), dtype=torch.int64
            )
        else:
            lengths = torch.sum(mask.int(), dim=-1).cpu()
        x = rnn_utils.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        outputs, _ = self.rnn(x)
        outputs, _ = rnn_utils.pad_packed_sequence(outputs, batch_first=True)
        if not self.bidirectional:
            last_outputs = outputs[torch.arange(batch_size), lengths - 1, :]
            return outputs, last_outputs
        else:
            # split the concatenated forward/backward features
            outputs = outputs.view(batch_size, outputs.shape[1], 2, -1)
            f_last_outputs = outputs[torch.arange(batch_size), lengths - 1, 0, :]
            b_last_outputs = outputs[:, 0, 1, :]
            last_outputs = torch.cat([f_last_outputs, b_last_outputs], dim=-1)
            outputs = outputs.view(batch_size, outputs.shape[1], -1)
            last_outputs = self.down_projection(last_outputs)
            outputs = self.down_projection(outputs)
            return outputs, last_outputs
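
# A minimal standalone sketch (illustrative, not part of the original module)
# showing RNNLayer with an explicit mask: only the first `length` steps of
# each sequence count as valid, so `last_outputs` is taken at each sequence's
# true final step rather than at the padded end. All names below are
# examples, not part of the API.
#
#     layer = RNNLayer(input_size=5, hidden_size=64, rnn_type="LSTM")
#     x = torch.randn(3, 10, 5)           # [batch, seq len, input_size]
#     lengths = torch.tensor([10, 6, 2])  # true length of each sequence
#     mask = torch.arange(10)[None, :] < lengths[:, None]  # [3, 10], bool
#     outputs, last_outputs = layer(x, mask)
#     # outputs: [3, 10, 64]; last_outputs: [3, 64]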

class RNN(BaseModel):
    """Recurrent neural network model.

    This model applies a separate RNN layer for each feature, and then
    concatenates the final hidden states of each RNN layer. The concatenated
    hidden states are then fed into a fully connected layer to make
    predictions.

    Note:
        We use separate rnn layers for different feature_keys.
        Currently, we automatically support different input formats:
            - code based input (need to use the embedding table later)
            - float/int based value input
        We follow the current convention for the rnn model:
            - case 1. [code1, code2, code3, ...]
                - we will assume the codes follow the order; our model will
                  encode each code into a vector and apply rnn on the code level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner bracket follows the order; our model
                  first uses the embedding table to encode each code into a
                  vector, then uses average/mean pooling to get one vector for
                  each inner bracket, and finally runs rnn on the bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the
                  same length; we assume each dimension has the same meaning;
                  we run rnn directly on the inner bracket level, similar to
                  case 1 after the embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the
                  same length; we assume each dimension has the same meaning;
                  we run rnn directly on the inner bracket level, similar to
                  case 2 after the embedding table

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features, e.g.
            ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        pretrained_emb: optional pretrained embedding identifier, passed
            through to BaseModel. Default is None.
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension. Default is 128.
        **kwargs: other parameters for the RNN layer.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ...         ],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import RNN
        >>> model = RNN(
        ...     dataset=dataset,
        ...     feature_keys=[
        ...         "list_codes",
        ...         "list_vectors",
        ...         "list_list_codes",
        ...         "list_list_vectors",
        ...     ],
        ...     label_key="label",
        ...     mode="binary",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.8056, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
            'y_prob': tensor([[0.5906], [0.6620]], grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[1.], [0.]]),
            'logit': tensor([[0.3666], [0.6721]], grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        pretrained_emb: Optional[str] = None,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        **kwargs,
    ):
        super(RNN, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
            pretrained_emb=pretrained_emb,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # validate kwargs for RNN layer
        if "input_size" in kwargs:
            raise ValueError("input_size is determined by embedding_dim")
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")

        # the keys of self.feat_tokenizers only contain the code based inputs
        self.feat_tokenizers = {}
        self.label_tokenizer = self.get_label_tokenizer()
        # the keys of self.embeddings only contain the code based inputs
        self.embeddings = nn.ModuleDict()
        # the keys of self.linear_layers only contain the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        # add feature RNN layers
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "RNN only supports str code, float and int as input types"
                )
            elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]):
                raise ValueError(
                    "RNN only supports 2-dim or 3-dim str code as input types"
                )
            elif (input_info["type"] in [float, int]) and (
                input_info["dim"] not in [2, 3]
            ):
                raise ValueError(
                    "RNN only supports 2-dim or 3-dim float and int as input types"
                )
            # for code based input, we need Type
            # for float/int based input, we need Type, input_dim
            self.add_feature_transform_layer(feature_key, input_info)

        self.rnn = nn.ModuleDict()
        for feature_key in feature_keys:
            self.rnn[feature_key] = RNNLayer(
                input_size=embedding_dim, hidden_size=hidden_dim, **kwargs
            )
        output_size = self.get_output_size(self.label_tokenizer)
        self.fc = nn.Linear(len(self.feature_keys) * self.hidden_dim, output_size)
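
    # Sizing note (added comment): each per-feature RNNLayer emits a
    # hidden_dim vector, and forward() concatenates them, so self.fc expects
    # len(feature_keys) * hidden_dim inputs. For example, with the four
    # feature keys used in the docstring above and the default hidden_dim=128,
    # the concatenated patient embedding is 4 * 128 = 512-dim.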

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: a tensor representing the raw logits.
        """
        patient_emb = []
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            dim_, type_ = input_info["dim"], input_info["type"]

            # for case 1: [code1, code2, code3, ...]
            if (dim_ == 2) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_2d(
                    kwargs[feature_key]
                )
                # (patient, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, event)
                mask = torch.any(x != 0, dim=2)

            # for case 2: [[code1, code2], [code3, ...], ...]
            elif (dim_ == 3) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_3d(
                    kwargs[feature_key]
                )
                # (patient, visit, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, visit, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                # (patient, visit)
                mask = torch.any(x != 0, dim=2)

            # for case 3: [[1.5, 2.0, 0.0], ...]
            elif (dim_ == 2) and (type_ in [float, int]):
                x, mask = self.padding2d(kwargs[feature_key])
                # (patient, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, event, embedding_dim)
                x = self.linear_layers[feature_key](x)
                # (patient, event)
                mask = mask.bool().to(self.device)

            # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
            elif (dim_ == 3) and (type_ in [float, int]):
                x, mask = self.padding3d(kwargs[feature_key])
                # (patient, visit, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                x = self.linear_layers[feature_key](x)
                # (patient, visit)
                mask = mask[:, :, 0]
                mask = mask.bool().to(self.device)

            else:
                raise NotImplementedError

            # transform x to (patient, event, embedding_dim)
            if self.pretrained_emb is not None:
                x = self.linear_layers[feature_key](x)

            _, x = self.rnn[feature_key](x, mask)
            patient_emb.append(x)

        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits}
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
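
# A small usage sketch (illustrative, not part of the original module): the
# forward pass also accepts an optional `embed` flag, which adds the
# concatenated patient embedding to the returned dictionary.
#
#     ret = model(**data_batch, embed=True)
#     ret["embed"].shape  # (batch, len(feature_keys) * hidden_dim)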

if __name__ == "__main__":
    from pyhealth.datasets import SampleEHRDataset

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            # "single_vector": [1, 2, 3],
            "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
            "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
            "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
            "list_list_vectors": [
                [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
                [[7.7, 8.5, 9.4]],
            ],
            "label": 1,
        },
        {
            "patient_id": "patient-0",
            "visit_id": "visit-1",
            # "single_vector": [1, 5, 8],
            "list_codes": [
                "55154191800",
                "551541928",
                "55154192800",
                "705182798",
                "70518279800",
            ],
            "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
            "list_list_codes": [["A04A", "B035", "C129"]],
            "list_list_vectors": [
                [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
            ],
            "label": 0,
        },
    ]

    # dataset
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")

    # data loader
    from pyhealth.datasets import get_dataloader

    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model
    model = RNN(
        dataset=dataset,
        feature_keys=[
            "list_codes",
            "list_vectors",
            "list_list_codes",
            "list_list_vectors",
        ],
        label_key="label",
        mode="binary",
    )

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()
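
    # An extra smoke test (a sketch beyond the original script): the same data
    # with a bidirectional LSTM backend. Extra keyword arguments to RNN are
    # forwarded to each per-feature RNNLayer, as in RNN.__init__ above.
    model_bi = RNN(
        dataset=dataset,
        feature_keys=["list_codes", "list_list_codes"],
        label_key="label",
        mode="binary",
        rnn_type="LSTM",
        bidirectional=True,
    )
    ret_bi = model_bi(**data_batch)
    print(ret_bi["y_prob"].shape)  # (2, 1) for this binary toy batch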