Source code for pyhealth.models.retain

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.processors import (
    DeepNestedFloatsProcessor,
    DeepNestedSequenceProcessor,
    MultiHotProcessor,
    NestedFloatsProcessor,
    NestedSequenceProcessor,
    SequenceProcessor,
    TensorProcessor,
    TimeseriesProcessor,
)

from .embedding import EmbeddingModel


class RETAINLayer(nn.Module):
    """RETAIN layer.

    Paper: Edward Choi et al. RETAIN: An Interpretable Predictive Model for
    Healthcare using Reverse Time Attention Mechanism. NIPS 2016.

    This layer is used in the RETAIN model. But it can also be used as a
    standalone layer.

    Two GRUs read each sequence in reverse time order. The first produces a
    scalar attention weight per step (alpha, softmax-normalised over the valid
    steps), the second a per-dimension gate (beta, tanh). The context vector is
    the sum over steps of ``alpha * beta * step_embedding``.

    Args:
        feature_size: the hidden feature size.
        dropout: dropout rate. Default is 0.5.

    Examples:
        >>> from pyhealth.models import RETAINLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = RETAINLayer(64)
        >>> c = layer(input)
        >>> c.shape
        torch.Size([3, 64])
    """

    def __init__(
        self,
        feature_size: int,
        dropout: float = 0.5,
    ):
        super(RETAINLayer, self).__init__()
        self.feature_size = feature_size
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        self.alpha_gru = nn.GRU(feature_size, feature_size, batch_first=True)
        self.beta_gru = nn.GRU(feature_size, feature_size, batch_first=True)

        self.alpha_li = nn.Linear(feature_size, 1)
        self.beta_li = nn.Linear(feature_size, feature_size)

    @staticmethod
    def reverse_x(input, lengths):
        """Reverses the first ``lengths[i]`` steps of every sequence.

        Args:
            input: a tensor whose first two dimensions are
                [batch size, sequence len].
            lengths: per-sequence number of valid steps (iterable of ints or a
                1-D integer tensor).

        Returns:
            A tensor shaped like ``input`` in which row ``i`` holds its first
            ``lengths[i]`` steps in reverse order and zeros everywhere else.
        """
        # Start from zeros (not an uninitialised buffer) so positions past each
        # sequence's length are deterministic and can never carry NaN/inf.
        reversed_input = torch.zeros_like(input)
        for i, length in enumerate(lengths):
            reversed_input[i, :length] = input[i, :length].flip(dims=[0])
        return reversed_input

    @staticmethod
    def _padding_mask(lengths, total_length: int, device) -> torch.Tensor:
        """Returns a [batch size, total_length] bool tensor, True on padding."""
        lengths = torch.as_tensor(lengths, device=device)
        steps = torch.arange(total_length, device=device)
        return steps.unsqueeze(0) >= lengths.unsqueeze(1)

    def compute_alpha(self, rx, lengths, total_length: int):
        """Computes alpha attention.

        Args:
            rx: the time-reversed input, [batch size, sequence len, feature_size].
            lengths: per-sequence number of valid steps (on CPU, as required by
                ``pack_padded_sequence``).
            total_length: sequence length to pad the GRU output back to.

        Returns:
            A tensor of shape [batch size, total_length, 1] that sums to 1 over
            the valid steps of each sequence and is 0 on padded steps.
        """
        packed = rnn_utils.pack_padded_sequence(
            rx, lengths, batch_first=True, enforce_sorted=False
        )
        g, _ = self.alpha_gru(packed)
        g, _ = rnn_utils.pad_packed_sequence(
            g, batch_first=True, total_length=total_length
        )
        scores = self.alpha_li(g)
        # Padded steps come back from pad_packed_sequence as zeros, so their
        # score would be the linear layer's bias. Exclude them from the softmax
        # so the result does not depend on how much padding the batch carries.
        padding = self._padding_mask(lengths, total_length, scores.device)
        scores = scores.masked_fill(padding.unsqueeze(-1), float("-inf"))
        attn_alpha = torch.softmax(scores, dim=1)
        return attn_alpha

    def compute_beta(self, rx, lengths, total_length: int):
        """Computes beta attention.

        Args:
            rx: the time-reversed input, [batch size, sequence len, feature_size].
            lengths: per-sequence number of valid steps (on CPU, as required by
                ``pack_padded_sequence``).
            total_length: sequence length to pad the GRU output back to.

        Returns:
            A tensor of shape [batch size, total_length, feature_size] with
            values in (-1, 1).
        """
        packed = rnn_utils.pack_padded_sequence(
            rx, lengths, batch_first=True, enforce_sorted=False
        )
        h, _ = self.beta_gru(packed)
        h, _ = rnn_utils.pad_packed_sequence(
            h, batch_first=True, total_length=total_length
        )
        attn_beta = torch.tanh(self.beta_li(h))
        return attn_beta

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, feature_size].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid. When omitted every
                step is treated as valid.

        Returns:
            c: a tensor of shape [batch size, feature_size] representing the
                context vector.
        """
        # rnn will only apply dropout between layers
        x = self.dropout_layer(x)
        batch_size = x.size(0)
        total_length = x.size(1)  # capture before packing so pad_packed restores it

        if mask is None:
            lengths = torch.full(
                size=(batch_size,), fill_value=total_length, dtype=torch.int64
            )
        else:
            # Length = position of the last valid step, not the number of valid
            # steps: with a gap in the mask (e.g. an empty first visit) a plain
            # sum would cut off the trailing real steps.
            valid = mask.int().ne(0)
            steps = torch.arange(1, valid.size(-1) + 1, device=valid.device)
            lengths = (valid.long() * steps).max(dim=-1).values
            # min=1 prevents a zero-length crash in the GRU; max keeps lengths
            # within the sequence dimension of x.
            lengths = lengths.clamp(min=1, max=total_length).cpu()

        rx = self.reverse_x(x, lengths)
        attn_alpha = self.compute_alpha(rx, lengths, total_length)
        attn_beta = self.compute_beta(rx, lengths, total_length)
        # The attention at reversed position t was produced while reading
        # rx[:, t], so it must weight rx[:, t]. Summing over time is order
        # independent, so no un-reversal is needed. rx is zero on padding.
        c = attn_alpha * attn_beta * rx  # (patient, sequence len, feature_size)
        c = torch.sum(c, dim=1)  # (patient, feature_size)
        return c
class RETAIN(BaseModel):
    """RETAIN model.

    Paper: Edward Choi et al. RETAIN: An Interpretable Predictive Model for
    Healthcare using Reverse Time Attention Mechanism. NIPS 2016.

    This model uses separate RETAIN layers for different features and applies
    reverse time attention to capture temporal dependencies. It uses the
    unified EmbeddingModel for handling various input types.

    The model supports various input types through processors:
        - SequenceProcessor: Code sequences (e.g., diagnosis codes)
        - NestedSequenceProcessor: Nested code sequences (visit histories)
        - TimeseriesProcessor: Time series features
        - NestedFloatsProcessor: Nested numerical sequences

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        **kwargs: other parameters for the RETAIN layer. ``feature_size`` is
            rejected because it is fixed to ``embedding_dim``.

    Raises:
        ValueError: if ``feature_size`` is passed in ``kwargs``.
        AssertionError: if the dataset does not have exactly one label key.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "conditions": [["A", "B"], ["C"]],
        ...         "procedures": [["P1"], ["P2", "P3"]],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "conditions": [["D"], ["E", "F"]],
        ...         "procedures": [["P4"]],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={
        ...         "conditions": "nested_sequence",
        ...         "procedures": "nested_sequence",
        ...     },
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test"
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> model = RETAIN(dataset=dataset)
        >>>
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        **kwargs,
    ):
        super(RETAIN, self).__init__(
            dataset=dataset,
        )
        self.embedding_dim = embedding_dim

        # validate kwargs for RETAIN layer
        if "feature_size" in kwargs:
            raise ValueError("feature_size is determined by embedding_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        # Use EmbeddingModel for unified embedding handling
        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Create RETAIN layers for each feature
        self.retain = nn.ModuleDict()
        for feature_key in self.feature_keys:
            self.retain[feature_key] = RETAINLayer(feature_size=embedding_dim, **kwargs)

        output_size = self.get_output_size()
        num_features = len(self.feature_keys)
        self.fc = nn.Linear(num_features * self.embedding_dim, output_size)

    @staticmethod
    def _sequence_mask(x: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Returns a float [batch, seq_len] validity mask for ``x``.

        Prefers the mask supplied by the embedding model. Falls back to
        treating all-zero embedding vectors as padding when no mask is
        available or its shape does not line up with ``x``.

        Args:
            x: embedded feature of shape [batch, seq_len, embedding_dim].
            mask: mask reported by the embedding model for this feature, or
                None.

        Returns:
            A float tensor of shape [batch, seq_len]; non-zero marks valid.
        """
        if mask is not None:
            # Collapse any per-event dimensions: a step is valid if anything
            # inside it is valid.
            while mask.dim() > 2:
                mask = mask.any(dim=-1)
            if mask.shape == x.shape[:2]:
                return mask.float()
        # Heuristic fallback: an all-zero embedding vector is taken as padding.
        return (x.abs().sum(dim=-1) > 0).float()

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the following keys:
                - loss: a scalar tensor representing the loss.
                - y_prob: a tensor representing the predicted probabilities.
                - y_true: a tensor representing the true labels.
                - logit: a tensor representing the logits.
                - embed (optional): patient embeddings if requested.

        Raises:
            ValueError: if an embedded feature is not 2D, 3D or 4D.
        """
        patient_emb = []
        # Ask the embedding model for its masks as well, so padding is
        # identified the same way MultimodalRETAIN does it rather than guessed
        # from the embedded values.
        embedded, emb_masks = self.embedding_model(kwargs, output_mask=True)

        for feature_key in self.feature_keys:
            x = embedded[feature_key]

            # Handle different input dimensions
            if x.dim() == 4:
                # (batch, visits, events, embedding_dim): sum across events
                # within a visit to get (batch, visits, embedding_dim)
                x = torch.sum(x, dim=2)
            elif x.dim() == 3:
                # (batch, seq_len, embedding_dim) - already correct format
                pass
            elif x.dim() == 2:
                # shouldn't happen for RETAIN but handle it by adding a
                # sequence dimension: (batch, 1, embedding_dim)
                x = x.unsqueeze(1)
            else:
                raise ValueError(
                    f"Unexpected tensor shape {x.shape} for feature "
                    f"{feature_key}"
                )

            # (batch_size, num_visits)
            mask = self._sequence_mask(x, emb_masks.get(feature_key))
            x = self.retain[feature_key](x, mask)
            patient_emb.append(x)

        # (patient, num_features * embedding_dim)
        patient_emb = torch.cat(patient_emb, dim=1)
        logits = self.fc(patient_emb)

        # obtain y_true, loss, y_prob
        y_true = kwargs[self.label_key].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
if __name__ == "__main__":
    # Smoke test: build a two-sample toy dataset, run one batch through RETAIN
    # and check that the loss can be back-propagated.
    from pyhealth.datasets import create_sample_dataset, get_dataloader

    # Three nested-sequence features; the first sample has an empty first
    # visit in "drugs_hist".
    demo_samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            "conditions": [["A", "B"], ["C", "D", "E"]],
            "procedures": [["P1"], ["P2", "P3"]],
            "drugs_hist": [[], ["D1", "D2"]],
            "label": 1,
        },
        {
            "patient_id": "patient-0",
            "visit_id": "visit-1",
            "conditions": [["F"], ["G", "H"]],
            "procedures": [["P4", "P5"], ["P6"]],
            "drugs_hist": [["D3"], ["D4", "D5"]],
            "label": 0,
        },
    ]
    feature_schema = {
        "conditions": "nested_sequence",
        "procedures": "nested_sequence",
        "drugs_hist": "nested_sequence",
    }

    demo_dataset = create_sample_dataset(
        samples=demo_samples,
        input_schema=feature_schema,
        output_schema={"label": "binary"},
        dataset_name="test",
    )
    loader = get_dataloader(demo_dataset, batch_size=2, shuffle=True)

    retain_model = RETAIN(dataset=demo_dataset)

    # One forward pass on the first batch, then a backward pass on its loss.
    batch = next(iter(loader))
    outputs = retain_model(**batch)
    print(outputs)
    outputs["loss"].backward()
class MultimodalRETAIN(BaseModel):
    """RETAIN variant that accepts sequential and non-sequential features.

    Every feature is first embedded by a shared ``EmbeddingModel``. What
    happens next depends on the type of the feature's input processor:

    - Sequential (SequenceProcessor, NestedSequenceProcessor,
      DeepNestedSequenceProcessor, NestedFloatsProcessor,
      DeepNestedFloatsProcessor, TimeseriesProcessor): the embedding is fed to
      a per-feature ``RETAINLayer``, which yields one vector per patient.
    - Anything else (e.g. MultiHotProcessor, TensorProcessor): the embedding
      is used as is, mean-pooled over dimension 1 until it is 2D.

    The per-feature vectors are concatenated and mapped to the output size by
    a single linear layer.

    Args:
        dataset (SampleDataset): the dataset to train the model. Its input
            processors decide how each feature is routed.
        embedding_dim (int): the embedding dimension. Default is 128.
        **kwargs: extra arguments forwarded to every ``RETAINLayer``
            (e.g., dropout). ``feature_size`` is not accepted.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "conditions": [["A", "B"], ["C"]],  # nested sequence
        ...         "demographics": ["asian", "male"],  # multi-hot
        ...         "vitals": [110.0, 75.0, 98.2],  # tensor
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-1",
        ...         "visit_id": "visit-1",
        ...         "conditions": [["D"], ["E", "F"]],  # nested sequence
        ...         "demographics": ["white", "female"],  # multi-hot
        ...         "vitals": [120.0, 80.0, 98.6],  # tensor
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={
        ...         "conditions": "nested_sequence",
        ...         "demographics": "multi_hot",
        ...         "vitals": "tensor",
        ...     },
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test"
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> model = MultimodalRETAIN(dataset=dataset)
        >>>
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
    """

    def __init__(self, dataset: SampleDataset, embedding_dim: int = 128, **kwargs):
        super(MultimodalRETAIN, self).__init__(dataset=dataset)
        self.embedding_dim = embedding_dim

        # feature_size is tied to embedding_dim, so callers may not override it
        if "feature_size" in kwargs:
            raise ValueError("feature_size is determined by embedding_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Route each feature by processor type; only sequential features get
        # their own RETAIN layer.
        self.sequential_features = []
        self.non_sequential_features = []
        self.retain = nn.ModuleDict()
        for feature_key in self.feature_keys:
            if not self._is_sequential_processor(
                dataset.input_processors[feature_key]
            ):
                self.non_sequential_features.append(feature_key)
                continue
            self.sequential_features.append(feature_key)
            self.retain[feature_key] = RETAINLayer(
                feature_size=embedding_dim, **kwargs
            )

        # Every feature contributes one embedding_dim-sized vector.
        num_sequential = len(self.sequential_features)
        num_static = len(self.non_sequential_features)
        final_dim = num_sequential * embedding_dim + num_static * embedding_dim

        output_size = self.get_output_size()
        self.fc = nn.Linear(final_dim, output_size)

    def _is_sequential_processor(self, processor) -> bool:
        """Tell whether a processor's feature should go through RETAINLayer.

        Note: processor types that are not listed here (the StageNet
        processors, for instance) are treated as non-sequential.

        Args:
            processor: The processor instance to check.

        Returns:
            bool: True if ``processor`` is an instance of one of the
                sequence, nested-sequence, nested-floats or timeseries
                processors, False otherwise.
        """
        sequential_types = (
            SequenceProcessor,
            NestedSequenceProcessor,
            DeepNestedSequenceProcessor,
            NestedFloatsProcessor,
            DeepNestedFloatsProcessor,
            TimeseriesProcessor,
        )
        return isinstance(processor, sequential_types)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation handling mixed modalities.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the following keys:
                - loss: a scalar tensor representing the loss.
                - y_prob: a tensor representing the predicted probabilities.
                - y_true: a tensor representing the true labels.
                - logit: a tensor representing the logits.
                - embed (optional): the concatenated patient embeddings,
                  present when ``embed`` is passed as a truthy keyword.

        Raises:
            ValueError: if a sequential feature's embedding is not 2D, 3D
                or 4D.
        """
        embedded, emb_masks = self.embedding_model(kwargs, output_mask=True)
        patient_emb = []

        # Sequential features: bring to (batch, seq_len, embedding_dim), then
        # apply the feature's RETAIN layer.
        for feature_key in self.sequential_features:
            x = embedded[feature_key]
            ndim = x.dim()
            if ndim not in (2, 3, 4):
                raise ValueError(
                    f"Unexpected tensor shape {x.shape} for feature {feature_key}"
                )
            if ndim == 4:
                # (batch, visits, events, embedding_dim): sum events per visit
                x = torch.sum(x, dim=2)
            elif ndim == 2:
                # not expected for sequential features; add a sequence dim
                x = x.unsqueeze(1)

            # Mask as reported by the embedding model, collapsed to
            # (batch, seq_len) when it carries extra trailing dimensions.
            mask = emb_masks.get(feature_key)
            if mask is not None:
                while mask.dim() > 2:
                    mask = mask.any(dim=-1)
                mask = mask.float()

            patient_emb.append(self.retain[feature_key](x, mask))

        # Non-sequential features: use the embedding directly, mean-pooling
        # dimension 1 until it is (batch, embedding_dim).
        for feature_key in self.non_sequential_features:
            x = embedded[feature_key]
            while x.dim() > 2:
                x = x.mean(dim=1)
            patient_emb.append(x)

        # One vector per feature, concatenated along the feature dimension.
        patient_emb = torch.cat(patient_emb, dim=1)
        logits = self.fc(patient_emb)

        y_true = kwargs[self.label_key].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results