Source code for pyhealth.models.stagenet

from typing import Dict, Optional, Tuple, cast

import torch
import torch.nn as nn

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit
from pyhealth.interpret.api import Interpretable

from .embedding import EmbeddingModel


class StageNetLayer(nn.Module):
    """StageNet layer.

    Paper: Stagenet: Stage-aware neural networks for health risk prediction.
    WWW 2020.

    This layer is used in the StageNet model. But it can also be used as a
    standalone layer.

    Args:
        input_dim: dynamic feature size.
        chunk_size: the chunk size for the StageNet layer. Default is 128.
        conv_size: the size of the convolutional kernel. Default is 10.
        levels: the number of levels for the StageNet layer.
            levels * chunk_size = hidden_dim in the RNN. Smaller chunk size
            and more levels can capture more detailed patient status
            variations. Default is 3.
        dropconnect: the dropout rate for the dropconnect. Default is 0.3.
        dropout: the dropout rate for the dropout. Default is 0.3.
        dropres: the dropout rate for the residual connection. Default is 0.3.

    Examples:
        >>> from pyhealth.models import StageNetLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = StageNetLayer(64)
        >>> c, _, _ = layer(input)
        >>> c.shape
        torch.Size([3, 384])
    """

    def __init__(
        self,
        input_dim: int,
        chunk_size: int = 128,
        conv_size: int = 10,
        levels: int = 3,
        dropconnect: float = 0.3,
        dropout: float = 0.3,
        dropres: float = 0.3,
    ):
        super().__init__()

        self.dropout = dropout
        self.dropconnect = dropconnect
        self.dropres = dropres
        self.input_dim = input_dim
        self.hidden_dim = chunk_size * levels
        self.conv_dim = self.hidden_dim
        self.conv_size = conv_size
        self.levels = levels
        self.chunk_size = chunk_size

        # Each projection emits 4 * hidden_dim gate pre-activations plus
        # 2 * levels master-gate logits. The "+ 1" on the input side is the
        # inter-visit time interval concatenated in `step`.
        gate_dim = int(self.hidden_dim * 4 + levels * 2)
        self.kernel = nn.Linear(int(input_dim + 1), gate_dim)
        nn.init.xavier_uniform_(self.kernel.weight)
        nn.init.zeros_(self.kernel.bias)
        self.recurrent_kernel = nn.Linear(int(self.hidden_dim + 1), gate_dim)
        nn.init.orthogonal_(self.recurrent_kernel.weight)
        nn.init.zeros_(self.recurrent_kernel.bias)

        # Bottleneck (hidden_dim -> hidden_dim // 6 -> hidden_dim) used to
        # re-calibrate the convolved window in `forward`.
        self.nn_scale = nn.Linear(int(self.hidden_dim), int(self.hidden_dim // 6))
        self.nn_rescale = nn.Linear(int(self.hidden_dim // 6), int(self.hidden_dim))
        self.nn_conv = nn.Conv1d(
            int(self.hidden_dim), int(self.conv_dim), int(conv_size), 1
        )

        # Non-linearities exposed as modules for easy swapping.
        # Nonlinearities are plain modules; interpretability wrappers are
        # applied externally (e.g., DeepLIFT/GIM) by temporarily replacing
        # these modules.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.softmax_dim1 = nn.Softmax(dim=1)

        if self.dropconnect:
            self.nn_dropconnect = nn.Dropout(p=dropconnect)
            self.nn_dropconnect_r = nn.Dropout(p=dropconnect)
        if self.dropout:
            self.nn_dropout = nn.Dropout(p=dropout)
        # Always created: `forward` applies it whenever dropres > 0, regardless
        # of whether `dropout` is enabled.
        self.nn_dropres = nn.Dropout(p=dropres)

    def cumax(self, x, mode="l2r"):
        """Cumulative softmax along the last dimension.

        Args:
            x: input tensor.
            mode: "l2r" accumulates from the first to the last position,
                "r2l" accumulates from the last to the first. Any other
                value returns ``x`` unchanged.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if mode == "l2r":
            return torch.cumsum(self.softmax(x), dim=-1)
        if mode == "r2l":
            # Flip, accumulate left-to-right, then flip back.
            reversed_x = torch.flip(x, [-1])
            accumulated = torch.cumsum(self.softmax(reversed_x), dim=-1)
            return torch.flip(accumulated, [-1])
        return x

    def step(self, inputs, c_last, h_last, interval, device):
        """Run one recurrent step.

        Args:
            inputs: features of the current time step; the last dimension
                must be ``input_dim``.
            c_last: previous cell state, reshaped internally to
                [-1, levels, chunk_size].
            h_last: previous hidden state; the last dimension must be
                ``hidden_dim``.
            interval: time interval of the current step (one value per row);
                it is unsqueezed and concatenated to both projections.
            device: device the step inputs are moved to.

        Returns:
            A tuple ``(out, c_out, h_out)`` where ``out`` is ``h_out``
            concatenated with the forget and input master gates
            (``hidden_dim + 2 * levels`` features), and ``c_out`` / ``h_out``
            are the new cell / hidden states flattened to ``hidden_dim``
            features.
        """
        x_in = inputs.to(device=device)
        h_prev = h_last.to(device=device)
        # Integrate inter-visit time intervals
        interval = interval.unsqueeze(-1).to(device=device)

        x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1))
        x_out2 = self.recurrent_kernel(torch.cat((h_prev, interval), dim=-1))
        if self.dropconnect:
            x_out1 = self.nn_dropconnect(x_out1)
            x_out2 = self.nn_dropconnect_r(x_out2)
        x_out = x_out1 + x_out2

        levels = self.levels
        # The first 2 * levels columns are the master-gate logits.
        f_master_gate = self.cumax(x_out[:, :levels], "l2r").unsqueeze(2)
        i_master_gate = self.cumax(x_out[:, levels : levels * 2], "r2l").unsqueeze(2)

        # The remainder holds the four per-chunk gate pre-activations.
        x_out = x_out[:, levels * 2 :].reshape(-1, levels * 4, self.chunk_size)
        f_gate = self.sigmoid(x_out[:, :levels])
        i_gate = self.sigmoid(x_out[:, levels : levels * 2])
        o_gate = self.sigmoid(x_out[:, levels * 2 : levels * 3])
        c_in = self.tanh(x_out[:, levels * 3 :])
        c_prev = c_last.reshape(-1, levels, self.chunk_size).to(device=device)

        overlap = f_master_gate * i_master_gate
        c_out = (
            overlap * (f_gate * c_prev + i_gate * c_in)
            + (f_master_gate - overlap) * c_prev
            + (i_master_gate - overlap) * c_in
        )
        h_out = o_gate * self.tanh(c_out)

        c_out = c_out.reshape(-1, self.hidden_dim)
        h_out = h_out.reshape(-1, self.hidden_dim)
        out = torch.cat([h_out, f_master_gate[..., 0], i_master_gate[..., 0]], 1)
        return out, c_out, h_out

    def forward(
        self,
        x: torch.Tensor,
        time: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input_dim].
            time: an optional tensor reshapeable to [batch size, sequence len]
                holding the time interval of each step. If omitted, every
                interval is 1.
            mask: an optional tensor of shape [batch size, sequence len],
                where 1 indicates valid and 0 indicates invalid. It is passed
                to ``get_last_visit`` to select the last output.

        Returns:
            last_output: a tensor of shape [batch size, chunk_size*levels]
                representing the patient embedding.
            outputs: a tensor of shape [batch size, sequence len,
                chunk_size*levels] representing the patient at each time step.
            distance: a tensor of shape [sequence len, batch size] with the
                per-step stage variation (1 - mean of the forget master gate).
        """
        # rnn will only apply dropout between layers
        batch_size, time_step, _ = x.size()
        device = x.device
        dtype = x.dtype

        if time is None:
            time = torch.ones(batch_size, time_step, device=device)
        # Match x's dtype so concatenating the interval in `step` does not
        # promote the projection inputs to a dtype the weights do not have.
        time = time.reshape(batch_size, time_step).to(device=device, dtype=dtype)

        c_out = torch.zeros(batch_size, self.hidden_dim, device=device, dtype=dtype)
        h_out = torch.zeros(batch_size, self.hidden_dim, device=device, dtype=dtype)

        # Sliding windows over the last `conv_size` steps, allocated directly
        # on the target device (oldest entry first).
        h_window = torch.zeros(
            self.conv_size, batch_size, self.hidden_dim, device=device, dtype=dtype
        )
        dis_window = torch.zeros(
            self.conv_size, batch_size, device=device, dtype=dtype
        )

        h = []
        origin_h = []
        distance = []
        for t in range(time_step):
            out, c_out, h_out = self.step(x[:, t, :], c_out, h_out, time[:, t], device)
            step_hidden = out[..., : self.hidden_dim]
            cur_distance = 1 - torch.mean(
                out[..., self.hidden_dim : self.hidden_dim + self.levels], -1
            )
            origin_h.append(step_hidden)
            distance.append(cur_distance)

            # Drop the oldest entry and append the current step.
            h_window = torch.cat((h_window[1:], step_hidden.unsqueeze(0)), 0)
            dis_window = torch.cat((dis_window[1:], cur_distance.unsqueeze(0)), 0)

            # Re-weighted convolution operation
            local_dis = torch.cumsum(dis_window.permute(1, 0), dim=1)
            local_dis = self.softmax_dim1(local_dis)
            local_h = h_window.permute(1, 2, 0) * local_dis.unsqueeze(1)

            # Re-calibrate Progression patterns
            local_theme = torch.mean(local_h, dim=-1)
            local_theme = self.relu(self.nn_scale(local_theme))
            local_theme = self.sigmoid(self.nn_rescale(local_theme))

            local_h = self.nn_conv(local_h).squeeze(-1)
            h.append(local_theme * local_h)

        origin_h = torch.stack(origin_h).permute(1, 0, 2)
        rnn_outputs = torch.stack(h).permute(1, 0, 2)
        if self.dropres > 0.0:
            origin_h = self.nn_dropres(origin_h)
        rnn_outputs = rnn_outputs + origin_h
        rnn_outputs = rnn_outputs.contiguous().view(-1, rnn_outputs.size(-1))
        if self.dropout > 0.0:
            rnn_outputs = self.nn_dropout(rnn_outputs)

        output = rnn_outputs.contiguous().view(batch_size, time_step, self.hidden_dim)
        last_output = get_last_visit(output, mask)

        return last_output, output, torch.stack(distance)
class StageNet(BaseModel, Interpretable):
    """StageNet model.

    Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health
    risk prediction. WWW 2020.

    This model uses the StageNetProcessor which expects inputs in the format:
        {"value": [...], "time": [...]}

    The processor handles various input types:
        - Code sequences (with/without time intervals)
        - Nested code sequences (with/without time intervals)
        - Numeric feature vectors (with/without time intervals)

    Time intervals are optional and represent inter-event delays. If not
    provided, all events are treated as having uniform time intervals.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        chunk_size: the chunk size for the StageNet layer. Default is 128.
        levels: the number of levels for the StageNet layer.
            levels * chunk_size = hidden_dim in the RNN. Smaller chunk_size
            and more levels can capture more detailed patient status
            variations. Default is 3.
        **kwargs: other parameters for the StageNet layer.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "codes": {
        ...             "value": ["505800458", "50580045810", "50580045811"],
        ...             "time": [0.0, 2.0, 1.3],
        ...         },
        ...         "procedures": {
        ...             "value": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
        ...             "time": [0.0, 1.5],
        ...         },
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "codes": {
        ...             "value": ["55154191800", "551541928", "55154192800"],
        ...             "time": [0.0, 2.0, 1.3],
        ...         },
        ...         "procedures": {
        ...             "value": [["A04A", "B035", "C129"]],
        ...             "time": [0.0],
        ...         },
        ...         "label": 0,
        ...     },
        ... ]
        >>>
        >>> # dataset
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={
        ...         "codes": "stagenet",
        ...         "procedures": "stagenet",
        ...     },
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test"
        ... )
        >>>
        >>> # data loader
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> # model
        >>> model = StageNet(dataset=dataset)
        >>>
        >>> # data batch
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> # try the model
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        chunk_size: int = 128,
        levels: int = 3,
        **kwargs,
    ):
        super().__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.levels = levels

        # validate kwargs for StageNet layer
        if "input_dim" in kwargs:
            raise ValueError("input_dim is determined by embedding_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        # Use EmbeddingModel for unified embedding handling
        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Create StageNet layers for each feature
        self.stagenet = nn.ModuleDict()
        for feature_key in self.feature_keys:
            self.stagenet[feature_key] = StageNetLayer(
                input_dim=embedding_dim,
                chunk_size=self.chunk_size,
                levels=self.levels,
                **kwargs,
            )

        # The per-feature patient embeddings are concatenated before the
        # final classifier, hence the len(feature_keys) factor.
        output_size = self.get_output_size()
        self.fc = nn.Linear(
            len(self.feature_keys) * self.chunk_size * self.levels, output_size
        )

    def _unpack_feature(
        self,
        feature_key: str,
        feature: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Tuple[
        Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
    ]:
        """Split one feature entry into ``(value, time, mask)`` by schema."""
        schema = self.dataset.input_processors[feature_key].schema()
        if "value" not in schema:
            raise ValueError(
                f"Feature '{feature_key}' must contain 'value' in the schema."
            )
        if isinstance(feature, torch.Tensor):
            # Backward compatibility: a bare tensor is the values themselves,
            # without time and mask. It must not be indexed by schema
            # position, because "value" is not necessarily at index 0.
            return feature, None, None

        value = feature[schema.index("value")]
        time = feature[schema.index("time")] if "time" in schema else None
        mask = feature[schema.index("mask")] if "mask" in schema else None
        if mask is None and len(feature) == len(schema) + 1:
            # An optional mask can be provided as the last element if not
            # included in the schema
            mask = feature[-1]
        return value, time, mask

    def forward_from_embedding(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward pass starting from feature embeddings.

        This method bypasses the embedding layers but still performs
        temporal processing through StageNet layers. This is useful for
        interpretability methods like Integrated Gradients that need to
        interpolate in embedding space.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

                It is expected to contain the following semantic tensors:
                    - "value": the embedded feature tensor of shape
                      [batch, seq_len, embedding_dim] or
                      [batch, seq_len, inner_len, embedding_dim] for nested
                      sequences.
                    - "time" (optional): the time intervals tensor of shape
                      [batch, seq_len]. If not provided, uniform intervals
                      will be assumed.
                    - "mask" (optional): the mask tensor of shape
                      [batch, seq_len] or [batch, seq_len, inner_len] for
                      nested sequences. If not in the processor schema, it
                      can be provided as the last element of the feature
                      tuple. If not provided, masks will be generated from
                      the embedded values (non-zero entries are treated as
                      valid).

                A feature may also be passed as a single tensor, in which
                case it is treated as the values without time and mask.

                The label key should contain the true labels for loss
                computation.

        Returns:
            A dictionary with the following keys:
                logit: the raw logits before activation.
                y_prob: a tensor of predicted probabilities.
                loss: (if the label key is in kwargs) a scalar tensor
                    representing the final loss.
                y_true: (if the label key is in kwargs) a tensor
                    representing the true labels.
                embed: (if embed=True in kwargs) the patient embedding.

        Raises:
            ValueError: if a feature has no "value" entry.
        """
        # Function-scope import: `warnings` is not imported at file level.
        import warnings

        patient_emb = []
        for feature_key in self.feature_keys:
            processor = self.dataset.input_processors[feature_key]
            value, time, mask = self._unpack_feature(feature_key, kwargs[feature_key])

            if value is None:
                raise ValueError(
                    f"Feature '{feature_key}' must contain 'value' in the schema."
                )
            value = value.to(self.device)

            if time is None:
                warnings.warn(
                    f"Feature '{feature_key}' does not have time "
                    f"intervals. StageNet's temporal modeling "
                    f"capabilities will be limited. Consider using "
                    f"StageNet format with time intervals for "
                    f"better performance.",
                    UserWarning,
                )
            else:
                time = time.to(self.device)
                # Ensure time is 2D [batch, seq_len]
                if time.dim() == 1:
                    time = time.unsqueeze(0)

            if mask is None:
                warnings.warn(
                    f"Feature '{feature_key}' does not have mask "
                    f"information. Default mask will be created from "
                    f"embedded values. But it may not be accurate.",
                )
                mask = (value.abs().sum(dim=-1) != 0).int()
            else:
                # Move the mask first so every branch below yields a mask on
                # the model device.
                mask = mask.to(self.device)
                if not processor.is_token() and value.dim() == mask.dim():
                    # for continuous features, if mask is provided,
                    # we need to collapse the feature dimension.
                    mask = mask.any(dim=-1).int()

            if value.dim() == 4:
                # Nested sequences: [batch, seq_len, inner_len, embedding_dim]
                value = value.sum(dim=2)  # Sum pool over inner dimension
                if mask.dim() > 2:
                    # Update mask for nested sequences; a [batch, seq_len]
                    # mask is already at visit level.
                    mask = mask.any(dim=2).int()

            # Pass through StageNet layer with embedded features
            last_output, _, _ = self.stagenet[feature_key](
                value, time=time, mask=mask
            )
            patient_emb.append(last_output)

        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "logit": logits,
            "y_prob": y_prob,
        }

        # obtain y_true, loss, y_prob
        if self.label_key in kwargs:
            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
            results["loss"] = self.get_loss_function()(logits, y_true)
            results["y_true"] = y_true

        # Optionally return embeddings
        if kwargs.get("embed", False):
            results["embed"] = patient_emb

        return results

    def forward(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...]
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Embeds the "value" entry of every feature and delegates the rest of
        the computation to ``forward_from_embedding``.

        Args:
            **kwargs: keyword arguments for the model.

                The keys must contain all the feature keys and the label key.

                Feature keys should contain tuples of tensors (time, values)
                from temporal processors. But the feature keys can also
                contain just the values without time at the cost of degraded
                performance.

                The label key should contain the true labels for loss
                computation.

        Returns:
            A dictionary with the following keys:
                logit: the raw logits before activation.
                y_prob: a tensor of predicted probabilities.
                loss: (if the label key is in kwargs) a scalar tensor
                    representing the final loss.
                y_true: (if the label key is in kwargs) a tensor
                    representing the true labels.
                embed: (if embed=True in kwargs) the patient embedding.

        Raises:
            ValueError: if a feature has no "value" entry.
        """
        for feature_key in self.feature_keys:
            feature = kwargs[feature_key]
            value, _, _ = self._unpack_feature(feature_key, feature)
            if value is None:
                raise ValueError(
                    f"Feature '{feature_key}' must contain 'value' in the schema."
                )

            value = value.to(self.device)
            embedded = self.embedding_model({feature_key: value})[feature_key]

            if isinstance(feature, torch.Tensor):
                # Backward compatibility: values without time stay a bare
                # tensor, which forward_from_embedding treats as the values.
                kwargs[feature_key] = embedded
            else:
                # Swap the raw values for their embedding, keeping the other
                # tuple entries (time, mask, ...) in place.
                schema = self.dataset.input_processors[feature_key].schema()
                i = schema.index("value")
                kwargs[feature_key] = feature[:i] + (embedded,) + feature[i + 1 :]

        return self.forward_from_embedding(**kwargs)

    def get_embedding_model(self) -> nn.Module | None:
        """Return the embedding model used to embed raw feature values."""
        return self.embedding_model