# Source code for pyhealth.models.rnn

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.processors import (
    DeepNestedFloatsProcessor,
    DeepNestedSequenceProcessor,
    MultiHotProcessor,
    NestedFloatsProcessor,
    NestedSequenceProcessor,
    SequenceProcessor,
    StageNetProcessor,
    StageNetTensorProcessor,
    TensorProcessor,
    TimeseriesProcessor,
)

from .embedding import EmbeddingModel


class RNNLayer(nn.Module):
    """Recurrent neural network layer.

    This layer wraps the PyTorch RNN layer with masking and dropout support.
    It is used in the RNN model. But it can also be used as a standalone layer.

    Args:
        input_size: input feature size.
        hidden_size: hidden feature size.
        rnn_type: type of rnn, one of "RNN", "LSTM", "GRU". Default is "GRU".
            The name is looked up on ``torch.nn``.
        num_layers: number of recurrent layers. Default is 1.
        dropout: dropout rate. Dropout is always applied to the layer input;
            when num_layers > 1 it is also passed to the PyTorch RNN, which
            applies it between stacked layers. Default is 0.5.
        bidirectional: whether to use bidirectional recurrent layers. If True,
            a fully-connected layer is applied to the concatenation of the
            forward and backward hidden states to reduce the dimension to
            hidden_size. Default is False.

    Examples:
        >>> from pyhealth.models import RNNLayer
        >>> input = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
        >>> layer = RNNLayer(5, 64)
        >>> outputs, last_outputs = layer(input)
        >>> outputs.shape
        torch.Size([3, 128, 64])
        >>> last_outputs.shape
        torch.Size([3, 64])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        rnn_type: str = "GRU",
        num_layers: int = 1,
        dropout: float = 0.5,
        bidirectional: bool = False,
    ):
        super(RNNLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout
        self.bidirectional = bidirectional

        self.dropout_layer = nn.Dropout(dropout)
        self.num_directions = 2 if bidirectional else 1
        rnn_module = getattr(nn, rnn_type)
        self.rnn = rnn_module(
            input_size,
            hidden_size,
            num_layers=num_layers,
            # PyTorch only uses inter-layer dropout and warns when it is set
            # with a single layer, so disable it in that case.
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True,
        )
        if bidirectional:
            self.down_projection = nn.Linear(hidden_size * 2, hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input size].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid. The number of valid
                steps per sample is taken as ``mask.sum(-1)``. Samples with
                all-zero masks are clamped to length 1 to prevent
                pack_padded_sequence from receiving zero-length sequences.

        Returns:
            outputs: a tensor of shape [batch size, sequence len, hidden size],
                containing the output features for each time step. The time
                dimension always equals ``x.size(1)``. Positions past a
                sample's length are zero-padded before the optional
                bidirectional down-projection.
            last_outputs: a tensor of shape [batch size, hidden size],
                containing the output features for the last valid time step.
        """
        # pytorch's rnn will only apply dropout between layers
        x = self.dropout_layer(x)
        batch_size, seq_len = x.size(0), x.size(1)
        if mask is None:
            lengths = torch.full(
                size=(batch_size,), fill_value=seq_len, dtype=torch.int64
            )
        else:
            # pack_padded_sequence expects the lengths on the CPU.
            lengths = torch.sum(mask.int(), dim=-1).cpu()
        # Clamp lengths to at least 1 to handle empty sequences,
        # matching TCNLayer.
        lengths = torch.clamp(lengths, min=1)

        # Ensure tensor is contiguous for cuDNN compatibility
        x = x.contiguous()
        packed = rnn_utils.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        packed_outputs, _ = self.rnn(packed)
        # total_length keeps the time axis equal to the input's. Without it,
        # pad_packed_sequence trims to max(lengths), which breaks the
        # documented [batch, sequence len, hidden] output shape whenever every
        # sample in the batch is shorter than x.size(1).
        outputs, _ = rnn_utils.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=seq_len
        )
        # Ensure outputs are contiguous after unpacking
        outputs = outputs.contiguous()

        batch_index = torch.arange(batch_size)
        if not self.bidirectional:
            last_outputs = outputs[batch_index, lengths - 1, :]
            return outputs, last_outputs

        # Split the concatenated directions: [..., 0, :] is forward and
        # [..., 1, :] is backward.
        outputs = outputs.view(batch_size, seq_len, 2, -1)
        # The forward direction finishes at the last valid step, while the
        # backward direction finishes at step 0.
        f_last_outputs = outputs[batch_index, lengths - 1, 0, :]
        b_last_outputs = outputs[:, 0, 1, :]
        last_outputs = torch.cat([f_last_outputs, b_last_outputs], dim=-1)
        # Ensure view result is contiguous for cuDNN
        outputs = outputs.view(batch_size, seq_len, -1).contiguous()
        last_outputs = self.down_projection(last_outputs)
        outputs = self.down_projection(outputs)
        return outputs, last_outputs
class RNN(BaseModel):
    """Recurrent neural network model.

    This model applies a separate RNN layer for each feature, and then
    concatenates the final hidden states of each RNN layer. The concatenated
    hidden states are then fed into a fully connected layer to make
    predictions.

    Note:
        We use separate RNN layers for different feature_keys.
        Currently, we support two types of input formats:
            - Sequence of codes (e.g., diagnosis codes, procedure codes)
                - Input format: (batch_size, sequence_length)
                - Each code is embedded into a vector and RNN is applied on
                  the sequence
            - Timeseries values (e.g., lab tests, vital signs)
                - Input format: (batch_size, sequence_length, num_features)
                - Each timestep contains a fixed number of measurements
                - RNN is applied directly on the timeseries data
        Embedded features with four dimensions (nested sequences of codes)
        are sum-pooled over the innermost axis, so the RNN runs over visits.
        Two-dimensional embeddings are treated as length-1 sequences.

    Args:
        dataset (SampleDataset): the dataset to train the model. It is used to
            query certain information such as the set of all tokens.
        embedding_dim (int): the embedding dimension. Default is 128.
        hidden_dim (int): the hidden dimension. Default is 128.
        **kwargs: other parameters for the RNN layer (e.g., rnn_type,
            num_layers, dropout, bidirectional).

    Raises:
        ValueError: if ``input_size`` or ``hidden_size`` is passed in kwargs;
            they are derived from ``embedding_dim`` and ``hidden_dim``.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "conditions": ["cond-33", "cond-86", "cond-80"],
        ...         "procedures": ["proc-12", "proc-45"],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-1",
        ...         "visit_id": "visit-1",
        ...         "conditions": ["cond-12", "cond-52"],
        ...         "procedures": ["proc-23"],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"conditions": "sequence", "procedures": "sequence"},
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test"
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> model = RNN(dataset=dataset, embedding_dim=128, hidden_dim=64)
        >>>
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        **kwargs
    ):
        super(RNN, self).__init__(
            dataset=dataset,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # validate kwargs for RNN layer
        if "input_size" in kwargs:
            raise ValueError("input_size is determined by embedding_dim")
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported if RNN is initialized"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # One RNN per feature that forward() consumes. Iterating feature_keys,
        # not every input processor, keeps this in sync with fc's input size
        # and with forward(), so no RNN is created whose parameters would
        # never receive gradients.
        self.rnn = nn.ModuleDict()
        for feature_key in self.feature_keys:
            self.rnn[feature_key] = RNNLayer(
                input_size=embedding_dim, hidden_size=hidden_dim, **kwargs
            )

        output_size = self.get_output_size()
        self.fc = nn.Linear(len(self.feature_keys) * self.hidden_dim, output_size)

    def _extract_values_and_masks(self, kwargs):
        """Pull the "value" tensor (and the "mask" tensor, if any) of each feature.

        Each feature in ``kwargs`` is either a bare tensor or a tuple laid out
        according to its input processor's ``schema()``.

        Returns:
            A ``(inputs, masks)`` pair of dicts keyed by feature name. Features
            without a mask are absent from ``masks``.

        Raises:
            ValueError: if a feature's schema has no "value" entry.
        """
        inputs = {}
        masks = {}
        for feature_key in self.feature_keys:
            feature = kwargs[feature_key]
            if isinstance(feature, torch.Tensor):
                feature = (feature,)
            schema = self.dataset.input_processors[feature_key].schema()
            value = feature[schema.index("value")] if "value" in schema else None
            mask = feature[schema.index("mask")] if "mask" in schema else None
            if value is None:
                raise ValueError(f"Feature '{feature_key}' must contain 'value' in the schema.")
            inputs[feature_key] = value
            if mask is not None:
                masks[feature_key] = mask
        return inputs, masks

    def _as_sequence(self, feature_key, x, masks):
        """Reshape one embedded feature to (B, T, D) and derive its step mask.

        Returns:
            ``(x, mask)``, where ``mask`` is an int tensor with one entry per
            time step, or None for length-1 sequences.
        """
        if x.dim() == 4:
            # nested_sequence: (B, num_visits, num_codes, D)
            # @TODO: sum-pooling across codes is a simple baseline. May need
            # to investigate better embeddings for nested codes.
            x = x.sum(dim=2)  # (B, num_visits, D)
            if feature_key in masks:
                # A visit is valid if any of its codes is valid.
                mask = (masks[feature_key].to(self.device).sum(dim=-1) > 0).int()  # (B, V)
            else:
                mask = (torch.abs(x).sum(dim=-1) != 0).int()
        elif x.dim() == 2:
            # One vector per sample: run it as a single-step sequence.
            x = x.unsqueeze(1)
            mask = None
        else:
            # 3D: already (B, T, D)
            if feature_key in masks:
                mask = masks[feature_key].to(self.device).int()
                if mask.dim() == 3:
                    mask = (mask.sum(dim=-1) > 0).int()
            else:
                # Fall back to treating all-zero embeddings as padding.
                mask = (torch.abs(x).sum(dim=-1) != 0).int()
        return x, mask

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the following keys:
                - loss: a scalar tensor representing the loss.
                - y_prob: a tensor representing the predicted probabilities.
                - y_true: a tensor representing the true labels.
                - logit: a tensor representing the logits.
                - embed (optional): a tensor representing the patient
                  embeddings if requested.

        Raises:
            ValueError: if a feature's schema has no "value" entry.
        """
        # EmbeddingModel expects dicts of tensors, so split each feature into
        # its value and (optional) mask first.
        inputs, masks = self._extract_values_and_masks(kwargs)
        embedded = self.embedding_model(inputs, masks=masks)

        patient_emb = []
        for feature_key in self.feature_keys:
            x, mask = self._as_sequence(feature_key, embedded[feature_key], masks)
            _, last_hidden = self.rnn[feature_key](x, mask)
            patient_emb.append(last_hidden)

        patient_emb = torch.cat(patient_emb, dim=1)  # (B, num_features * hidden_dim)
        logits = self.fc(patient_emb)  # (B, output_size)

        # obtain y_true, loss, y_prob
        y_true = kwargs[self.label_key].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits}
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
class MultimodalRNN(BaseModel):
    """RNN model for datasets that mix sequential and non-sequential inputs.

    Every input feature is first embedded with ``EmbeddingModel``. The feature's
    processor type then decides how the embedding is summarized:

    * Sequential processors (SequenceProcessor, NestedSequenceProcessor,
      DeepNestedSequenceProcessor, NestedFloatsProcessor,
      DeepNestedFloatsProcessor, TimeseriesProcessor) get their own
      ``RNNLayer``. The RNN's last hidden state (size ``hidden_dim``)
      represents the feature.
    * All other processors (e.g. MultiHotProcessor, TensorProcessor,
      StageNetProcessor, StageNetTensorProcessor) bypass the RNN. Their
      embedding is mean-pooled over leading axes until it is 2D (size
      ``embedding_dim``).

    The per-feature vectors (sequential ones first, then non-sequential ones)
    are concatenated and mapped to the output by a single linear layer.

    Args:
        dataset (SampleDataset): dataset the model is trained on; it supplies
            the vocabulary and the processor of each input feature.
        embedding_dim (int): embedding size. Default is 128.
        hidden_dim (int): hidden size of each RNN layer. Default is 128.
        **kwargs: forwarded to every ``RNNLayer`` (e.g. rnn_type, num_layers,
            dropout, bidirectional).

    Raises:
        ValueError: if ``input_size`` or ``hidden_size`` is passed in kwargs.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "conditions": ["cond-33", "cond-86"],  # sequential
        ...         "demographics": ["asian", "male"],  # multi-hot
        ...         "vitals": [120.0, 80.0, 98.6],  # tensor
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-1",
        ...         "visit_id": "visit-1",
        ...         "conditions": ["cond-12", "cond-52"],  # sequential
        ...         "demographics": ["white", "female"],  # multi-hot
        ...         "vitals": [110.0, 75.0, 98.2],  # tensor
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={
        ...         "conditions": "sequence",
        ...         "demographics": "multi_hot",
        ...         "vitals": "tensor",
        ...     },
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test"
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> model = MultimodalRNN(dataset=dataset, embedding_dim=128, hidden_dim=64)
        >>>
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        **kwargs
    ):
        super(MultimodalRNN, self).__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # The RNN layer sizes are fixed by this model's own dimensions.
        if "input_size" in kwargs:
            raise ValueError("input_size is determined by embedding_dim")
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Route each feature by processor type; only sequential ones get an RNN.
        self.sequential_features = []
        self.non_sequential_features = []
        self.rnn = nn.ModuleDict()
        for key in self.feature_keys:
            if not self._is_sequential_processor(dataset.input_processors[key]):
                self.non_sequential_features.append(key)
                continue
            self.sequential_features.append(key)
            self.rnn[key] = RNNLayer(
                input_size=embedding_dim, hidden_size=hidden_dim, **kwargs
            )

        # Sequential features contribute hidden_dim each and the others
        # contribute embedding_dim each.
        rnn_width = hidden_dim * len(self.sequential_features)
        direct_width = embedding_dim * len(self.non_sequential_features)
        output_size = self.get_output_size()
        self.fc = nn.Linear(rnn_width + direct_width, output_size)

    def _is_sequential_processor(self, processor) -> bool:
        """Tell whether a processor's features should be run through an RNN.

        StageNetProcessor and StageNetTensorProcessor are deliberately left
        out. They target the StageNet architecture, so this model treats them
        as non-sequential.

        Args:
            processor: the input processor instance to classify.

        Returns:
            bool: True for the sequence/timeseries processor types, else False.
        """
        sequential_types = (
            SequenceProcessor,
            NestedSequenceProcessor,
            DeepNestedSequenceProcessor,
            NestedFloatsProcessor,
            DeepNestedFloatsProcessor,
            TimeseriesProcessor,
        )
        return isinstance(processor, sequential_types)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Run every feature through its branch and predict the label.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: must hold every feature key and the label key. A feature
                is either a tensor or a tuple ordered like its processor's
                ``schema()``. Passing ``embed=True`` also returns the patient
                embedding.

        Returns:
            Dict[str, torch.Tensor]: ``loss``, ``y_prob``, ``y_true`` and
            ``logit``, plus ``embed`` when requested.

        Raises:
            ValueError: if a feature's schema has no "value" entry.
        """
        # Split each feature into its value tensor and optional mask tensor.
        values = {}
        raw_masks = {}
        for key in self.feature_keys:
            packed = kwargs[key]
            if isinstance(packed, torch.Tensor):
                packed = (packed,)
            schema = self.dataset.input_processors[key].schema()
            value = packed[schema.index("value")] if "value" in schema else None
            value_mask = packed[schema.index("mask")] if "mask" in schema else None
            if value is None:
                raise ValueError(f"Feature '{key}' must contain 'value' in the schema.")
            values[key] = value
            if value_mask is not None:
                raw_masks[key] = value_mask

        features_out = []
        embedded, emb_masks = self.embedding_model(
            values, masks=raw_masks, output_mask=True
        )

        # Sequential branch: bring each embedding to (B, T, D) with a
        # matching step mask, then keep the RNN's last hidden state.
        for key in self.sequential_features:
            seq = embedded[key]
            step_mask = emb_masks[key]
            rank = seq.dim()
            if rank == 4:
                # (B, visits, codes, D): sum codes within a visit and run the
                # RNN over visits. Flattening visits*codes instead could yield
                # zero-length sequences when inner lists are empty.
                seq = seq.sum(dim=2)  # (B, visits, D)
                if key in raw_masks:
                    step_mask = (raw_masks[key].to(self.device).sum(dim=-1) > 0).int()  # (B, V)
                else:
                    step_mask = (torch.abs(seq).sum(dim=-1) != 0).int()
            elif rank == 2:
                seq = seq.unsqueeze(1)
                step_mask = None
            elif step_mask is not None:
                # (B, T, D) input; collapse a per-element mask to per-step.
                if step_mask.dim() == 3:
                    step_mask = (step_mask.sum(dim=-1) > 0).int()
                elif step_mask.dim() == 1:
                    # NOTE(review): a 1-D mask becomes (B, 1), which limits
                    # lengths to at most 1 — confirm this is intended.
                    step_mask = step_mask.unsqueeze(1)
            _, final_state = self.rnn[key](seq, step_mask)
            features_out.append(final_state)

        # Non-sequential branch: mean-pool axis 1 until the embedding is 2D.
        for key in self.non_sequential_features:
            pooled = embedded[key]
            for _ in range(pooled.dim() - 2):
                pooled = pooled.mean(dim=1)
            features_out.append(pooled)

        patient_emb = torch.cat(features_out, dim=1)  # (B, final concatenated width)
        logits = self.fc(patient_emb)  # (B, output_size)

        y_true = kwargs[self.label_key].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results