Source code for pyhealth.models.stagenet_mha

from typing import Dict, Optional, Tuple, cast
import warnings

import torch
import torch.nn as nn

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit
from .transformer import MultiHeadedAttention
from pyhealth.interpret.api import CheferInterpretable

from .embedding import EmbeddingModel


class StageNetAttentionLayer(nn.Module):
    """StageNetAttention layer.

    Paper: Stagenet: Stage-aware neural networks for health risk prediction.
    WWW 2020.

    This layer is used in the StageAttentionNet model. But it can also be
    used as a standalone layer.

    Args:
        input_dim: dynamic feature size.
        chunk_size: the chunk size for the StageNet layer. Default is 128.
        conv_size: the size of the convolutional kernel. Default is 10.
        levels: the number of levels for the StageNet layer.
            levels * chunk_size = hidden_dim in the RNN. Smaller chunk size
            and more levels can capture more detailed patient status
            variations. Default is 3.
        dropconnect: the dropout rate for the dropconnect. Default is 0.3.
        dropout: the dropout rate for the dropout. Default is 0.3.
        dropres: the dropout rate for the residual connection. Default is 0.3.
        num_heads: number of heads in the multi-head attention inserted
            between the SA-LSTM and the stage-adaptive CNN. Default is 8.
        attn_dropout: dropout rate applied to attention weights.
            Default is 0.1.

    Raises:
        ValueError: if ``chunk_size * levels`` is not divisible by
            ``num_heads``.

    Examples:
        >>> from pyhealth.models import StageNetAttentionLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = StageNetAttentionLayer(64)
        >>> c, _, _ = layer(input)
        >>> c.shape
        torch.Size([3, 384])
    """

    def __init__(
        self,
        input_dim: int,
        chunk_size: int = 128,
        conv_size: int = 10,
        levels: int = 3,
        dropconnect: float = 0.3,
        dropout: float = 0.3,
        dropres: float = 0.3,
        num_heads: int = 8,
        attn_dropout: float = 0.1,
    ):
        super(StageNetAttentionLayer, self).__init__()

        self.dropout = dropout
        self.dropconnect = dropconnect
        self.dropres = dropres
        self.input_dim = input_dim
        self.hidden_dim = chunk_size * levels
        self.conv_dim = self.hidden_dim
        self.conv_size = conv_size
        self.levels = levels
        self.chunk_size = chunk_size

        # Both kernels emit 4 * hidden_dim gate pre-activations plus
        # 2 * levels master-gate logits; the "+ 1" input column carries the
        # inter-visit time interval concatenated in step().
        self.kernel = nn.Linear(
            int(input_dim + 1), int(self.hidden_dim * 4 + levels * 2)
        )
        nn.init.xavier_uniform_(self.kernel.weight)
        nn.init.zeros_(self.kernel.bias)
        self.recurrent_kernel = nn.Linear(
            int(self.hidden_dim + 1), int(self.hidden_dim * 4 + levels * 2)
        )
        nn.init.orthogonal_(self.recurrent_kernel.weight)
        nn.init.zeros_(self.recurrent_kernel.bias)

        self.nn_scale = nn.Linear(int(self.hidden_dim), int(self.hidden_dim // 6))
        self.nn_rescale = nn.Linear(int(self.hidden_dim // 6), int(self.hidden_dim))
        self.nn_conv = nn.Conv1d(
            int(self.hidden_dim), int(self.conv_dim), int(conv_size), 1
        )

        # Non-linearities are defined as modules so they can be swapped
        # wholesale (e.g., for DeepLIFT/GIM instrumentation).
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.softmax_dim1 = nn.Softmax(dim=1)

        if self.hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({num_heads})"
            )

        # Use the Transformer-style attention to capture per-head maps + grads
        self.mha = MultiHeadedAttention(
            h=num_heads, d_model=self.hidden_dim, dropout=attn_dropout
        )
        self.attn_norm = nn.LayerNorm(self.hidden_dim)
        self.attn_gradients = None
        self.attn_map = None

        if self.dropconnect:
            self.nn_dropconnect = nn.Dropout(p=dropconnect)
            self.nn_dropconnect_r = nn.Dropout(p=dropconnect)
        if self.dropout:
            self.nn_dropout = nn.Dropout(p=dropout)
        # forward() applies the residual dropout whenever dropres > 0, so the
        # module must exist based on dropres itself; tying it to `dropout`
        # made dropout=0 with dropres>0 fail with AttributeError.
        if self.dropres:
            self.nn_dropres = nn.Dropout(p=dropres)

    def get_attn_map(self) -> Optional[torch.Tensor]:
        """Return the last attention weight map from the MHA block.

        Delegates to ``self.mha.get_attn_map()`` when the attention module
        exposes it; otherwise returns the locally cached ``attn_map``.
        """
        if hasattr(self.mha, "get_attn_map"):
            return self.mha.get_attn_map()
        return self.attn_map

    def get_attn_grad(self) -> Optional[torch.Tensor]:
        """Return gradients captured from the attention weights.

        Delegates to ``self.mha.get_attn_grad()`` when the attention module
        exposes it; otherwise returns the locally stored ``attn_gradients``.
        """
        if hasattr(self.mha, "get_attn_grad"):
            return self.mha.get_attn_grad()
        return self.attn_gradients

    def save_attn_grad(self, attn_grad: torch.Tensor) -> None:
        """Hook callback that stores attention gradients."""
        self.attn_gradients = attn_grad

    def cumax(self, x, mode="l2r"):
        """Cumulative softmax along the last dimension.

        Args:
            x: tensor of logits.
            mode: ``"l2r"`` accumulates left-to-right (values rise to 1 at
                the last index); ``"r2l"`` accumulates right-to-left (values
                are 1 at the first index). Any other value returns ``x``
                unchanged.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if mode == "l2r":
            x = self.softmax(x)
            x = torch.cumsum(x, dim=-1)
            return x
        elif mode == "r2l":
            x = torch.flip(x, [-1])
            x = self.softmax(x)
            x = torch.cumsum(x, dim=-1)
            return torch.flip(x, [-1])
        else:
            return x

    def step(self, inputs, c_last, h_last, interval, device):
        """Run one recurrent step of the stage-aware LSTM cell.

        Args:
            inputs: features of the current time step, [batch, input_dim].
            c_last: previous cell state, [batch, hidden_dim].
            h_last: previous hidden state, [batch, hidden_dim].
            interval: time interval of the current step, [batch].
            device: device on which the computation is carried out.

        Returns:
            out: ``h_out`` concatenated with the forget and input master
                gates, [batch, hidden_dim + 2 * levels].
            c_out: new cell state, [batch, hidden_dim].
            h_out: new hidden state, [batch, hidden_dim].
        """
        x_in = inputs.to(device=device)

        # Integrate inter-visit time intervals
        interval = interval.unsqueeze(-1).to(device=device)
        x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1)).to(device)
        x_out2 = self.recurrent_kernel(
            torch.cat((h_last.to(device=device), interval), dim=-1)
        )

        if self.dropconnect:
            x_out1 = self.nn_dropconnect(x_out1)
            x_out2 = self.nn_dropconnect_r(x_out2)
        x_out = x_out1 + x_out2

        # The first 2 * levels columns are the master-gate logits (one value
        # per level); the remainder holds the per-chunk gate pre-activations.
        f_master_gate = self.cumax(x_out[:, : self.levels], "l2r")
        f_master_gate = f_master_gate.unsqueeze(2).to(device=device)
        i_master_gate = self.cumax(x_out[:, self.levels : self.levels * 2], "r2l")
        i_master_gate = i_master_gate.unsqueeze(2)

        x_out = x_out[:, self.levels * 2 :]
        x_out = x_out.reshape(-1, self.levels * 4, self.chunk_size)
        f_gate = self.sigmoid(x_out[:, : self.levels]).to(device=device)
        i_gate = self.sigmoid(x_out[:, self.levels : self.levels * 2]).to(
            device=device
        )
        o_gate = self.sigmoid(x_out[:, self.levels * 2 : self.levels * 3])
        c_in = self.tanh(x_out[:, self.levels * 3 :]).to(device=device)
        c_last = c_last.reshape(-1, self.levels, self.chunk_size).to(device=device)

        # Where both master gates are active the regular LSTM update is used;
        # elsewhere the state is either carried over or overwritten.
        overlap = (f_master_gate * i_master_gate).to(device=device)
        c_out = (
            overlap * (f_gate * c_last + i_gate * c_in)
            + (f_master_gate - overlap) * c_last
            + (i_master_gate - overlap) * c_in
        )
        h_out = o_gate * self.tanh(c_out)
        c_out = c_out.reshape(-1, self.hidden_dim)
        h_out = h_out.reshape(-1, self.hidden_dim)
        out = torch.cat([h_out, f_master_gate[..., 0], i_master_gate[..., 0]], 1)
        return out, c_out, h_out

    def forward(
        self,
        x: torch.Tensor,
        time: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        register_hook: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input_dim].
            time: an optional tensor of time intervals, reshaped to
                [batch size, sequence len]. When omitted, all intervals
                are set to 1.
            mask: an optional tensor of shape [batch size, sequence len],
                where 1 indicates valid and 0 indicates invalid.
            register_hook: whether to register a backward hook on attention
                weights for gradient inspection.

        Returns:
            last_output: a tensor of shape [batch size, chunk_size*levels]
                representing the patient embedding.
            output: a tensor of shape [batch size, sequence len,
                chunk_size*levels] representing the patient at each time
                step.
            distance: a tensor of shape [sequence len, batch size] holding
                one minus the mean forget master gate of every step.
        """
        # rnn will only apply dropout between layers
        batch_size, time_step, _ = x.size()
        device = x.device
        if time is None:
            time = torch.ones(batch_size, time_step, device=device)
        time = time.reshape(batch_size, time_step)
        c_out = torch.zeros(batch_size, self.hidden_dim, device=device)
        h_out = torch.zeros(batch_size, self.hidden_dim, device=device)

        hidden_states = []
        distance = []
        for t in range(time_step):
            out, c_out, h_out = self.step(x[:, t, :], c_out, h_out, time[:, t], device)
            cur_distance = 1 - torch.mean(
                out[..., self.hidden_dim : self.hidden_dim + self.levels], -1
            )
            hidden_states.append(out[..., : self.hidden_dim])
            distance.append(cur_distance)

        # shape: [time, batch, hidden_dim]
        hidden_seq = torch.stack(hidden_states)
        distance = torch.stack(distance)

        key_padding_mask = None
        if mask is not None:
            key_padding_mask = (mask == 0).to(device=device, dtype=torch.bool)

        # Capture per-head attention weights for interpretability
        # Prepare mask for Transformer-style attention: 1=keep, 0=mask
        attn_mask = None
        if key_padding_mask is not None:
            valid = (~key_padding_mask).float()  # [batch, time]
            attn_mask = valid.unsqueeze(1) * valid.unsqueeze(2)  # [batch, time, time]

        seq_for_mha = hidden_seq.permute(1, 0, 2)  # [batch, time, hidden]
        attn_output = self.mha(
            seq_for_mha,
            seq_for_mha,
            seq_for_mha,
            mask=attn_mask,
            register_hook=register_hook,
        )
        self.attn_map = self.get_attn_map()
        self.attn_gradients = None  # will be populated after backward if hooked
        attn_output = attn_output.transpose(0, 1)  # back to [time, batch, hidden]
        attn_output = self.attn_norm(attn_output + hidden_seq)

        # Sliding windows over the last `conv_size` steps, zero-filled until
        # enough steps have been seen.
        tmp_h = torch.zeros(
            (self.conv_size, batch_size, self.hidden_dim), device=device
        )
        tmp_dis = torch.zeros((self.conv_size, batch_size), device=device)
        conv_outputs = []
        for t in range(time_step):
            cur_h = attn_output[t]
            cur_distance = distance[t]
            tmp_h = torch.cat((tmp_h[1:], cur_h.unsqueeze(0)), 0)
            tmp_dis = torch.cat((tmp_dis[1:], cur_distance.unsqueeze(0)), 0)

            # Re-weighted convolution operation
            local_dis = tmp_dis.permute(1, 0)
            local_dis = torch.cumsum(local_dis, dim=1)
            local_dis = self.softmax_dim1(local_dis)
            local_h = tmp_h.permute(1, 2, 0)
            local_h = local_h * local_dis.unsqueeze(1)

            # Re-calibrate Progression patterns
            local_theme = torch.mean(local_h, dim=-1)
            local_theme = self.nn_scale(local_theme)
            local_theme = self.relu(local_theme)
            local_theme = self.nn_rescale(local_theme)
            local_theme = self.sigmoid(local_theme)

            local_h = self.nn_conv(local_h).squeeze(-1)
            local_h = local_theme * local_h
            conv_outputs.append(local_h)

        origin_h = attn_output.permute(1, 0, 2)
        rnn_outputs = torch.stack(conv_outputs).permute(1, 0, 2)
        if self.dropres > 0.0:
            origin_h = self.nn_dropres(origin_h)
        rnn_outputs = rnn_outputs + origin_h
        rnn_outputs = rnn_outputs.contiguous().view(-1, rnn_outputs.size(-1))
        if self.dropout > 0.0:
            rnn_outputs = self.nn_dropout(rnn_outputs)

        output = rnn_outputs.contiguous().view(batch_size, time_step, self.hidden_dim)
        last_output = get_last_visit(output, mask)

        return last_output, output, distance
class StageAttentionNet(BaseModel, CheferInterpretable):
    """StageAttentionNet model.

    Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health
    risk prediction. WWW 2020. But with Multi-Head Attention (MHA) between
    the SA-LSTM and the SA-CNN.

    This model applies to the dataset which expects inputs in the format:
        {"value": [...], "time": [...]}

    The processor handles various input types:
        - Code sequences (with/without time intervals)
        - Nested code sequences (with/without time intervals)
        - Numeric feature vectors (with/without time intervals)

    Time intervals are optional and represent inter-event delays. If not
    provided, all events are treated as having uniform time intervals.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        chunk_size: the chunk size for the StageNetAttentionLayer.
            Default is 128.
        levels: the number of levels for the StageNetAttentionLayer.
            levels * chunk_size = hidden_dim in the RNN. Smaller chunk_size
            and more levels can capture more detailed patient status
            variations. Default is 3.
        **kwargs: other parameters for the StageNetAttentionLayer.

    Raises:
        ValueError: if ``input_dim`` is passed in ``kwargs``.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "codes": {
        ...             "value": ["505800458", "50580045810", "50580045811"],
        ...             "time": [0.0, 2.0, 1.3],
        ...         },
        ...         "procedures": {
        ...             "value": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
        ...             "time": [0.0, 1.5],
        ...         },
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "codes": {
        ...             "value": ["55154191800", "551541928", "55154192800"],
        ...             "time": [0.0, 2.0, 1.3],
        ...         },
        ...         "procedures": {
        ...             "value": [["A04A", "B035", "C129"]],
        ...             "time": [0.0],
        ...         },
        ...         "label": 0,
        ...     },
        ... ]
        >>>
        >>> # dataset
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={
        ...         "codes": "stagenet",
        ...         "procedures": "stagenet",
        ...     },
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test"
        ... )
        >>>
        >>> # data loader
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> # model
        >>> model = StageAttentionNet(dataset=dataset)
        >>>
        >>> # data batch
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> # try the model
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        chunk_size: int = 128,
        levels: int = 3,
        **kwargs,
    ):
        super(StageAttentionNet, self).__init__(
            dataset=dataset,
        )
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.levels = levels
        self._attention_hooks_enabled = False

        # validate kwargs for StageNet layer
        if "input_dim" in kwargs:
            raise ValueError("input_dim is determined by embedding_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        # Use EmbeddingModel for unified embedding handling
        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Create StageNet layers for each feature
        self.stagenet = nn.ModuleDict()
        for feature_key in self.feature_keys:
            self.stagenet[feature_key] = StageNetAttentionLayer(
                input_dim=embedding_dim,
                chunk_size=self.chunk_size,
                levels=self.levels,
                **kwargs,
            )

        output_size = self.get_output_size()
        self.fc = nn.Linear(
            len(self.feature_keys) * self.chunk_size * self.levels, output_size
        )

    def _unpack_feature(
        self,
        feature_key: str,
        feature: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Split one feature entry into ``(value, time, mask)`` via its schema.

        A bare tensor is interpreted as values only (no time, no mask).
        Wrapping it in a 1-tuple and indexing by schema position would pick
        the wrong slot or raise ``IndexError`` whenever "value" is not the
        first schema entry. For tuples, a trailing extra element is taken as
        the mask when the schema itself has no "mask" slot.
        """
        schema = self.dataset.input_processors[feature_key].schema()
        if isinstance(feature, torch.Tensor):
            return (feature if "value" in schema else None), None, None

        value = feature[schema.index("value")] if "value" in schema else None
        time = feature[schema.index("time")] if "time" in schema else None
        mask = feature[schema.index("mask")] if "mask" in schema else None
        if len(feature) == len(schema) + 1 and mask is None:
            # An optional mask can be provided as the last element
            # if not included in the schema
            mask = feature[-1]
        return value, time, mask

    def forward_from_embedding(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward pass starting from feature embeddings.

        This method bypasses the embedding layers but still performs temporal
        processing through StageNet layers. This is useful for
        interpretability methods like Integrated Gradients that need to
        interpolate in embedding space.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

                It is expected to contain the following semantic tensors:
                    - "value": the embedded feature tensor of shape
                      [batch, seq_len, embedding_dim] or
                      [batch, seq_len, inner_len, embedding_dim] for nested
                      sequences.
                    - "time" (optional): the time intervals tensor of shape
                      [batch, seq_len]. If not provided, uniform intervals
                      will be assumed.
                    - "mask" (optional): the mask tensor of shape
                      [batch, seq_len] or [batch, seq_len, inner_len] for
                      nested sequences. If not in the processor schema, it
                      can be provided as the last element of the feature
                      tuple. If not provided, masks will be generated from
                      the embedded values (non-zero entries are treated as
                      valid).

                A feature given as a bare tensor is treated as values
                without time and mask.

                The label key should contain the true labels for loss
                computation. A truthy ``register_attn_hook`` entry enables
                attention-gradient hooks for this call, in addition to the
                flag set through ``set_attention_hooks``.

        Returns:
            A dictionary with the following keys:
                logit: the raw logits before activation.
                y_prob: a tensor of predicted probabilities.
                loss: a scalar tensor representing the final loss (only when
                    the label key is present).
                y_true: a tensor representing the true labels (only when the
                    label key is present).
                embed: (if embed=True in kwargs) the patient embedding.

        Raises:
            ValueError: if a feature has no "value" entry.
        """
        # Support both the flag-based API and legacy kwarg-based API
        register_attn_hook = self._attention_hooks_enabled or bool(
            kwargs.get("register_attn_hook", False)
        )

        patient_emb = []

        for feature_key in self.feature_keys:
            processor = self.dataset.input_processors[feature_key]
            value, time, mask = self._unpack_feature(feature_key, kwargs[feature_key])

            if value is None:
                raise ValueError(
                    f"Feature '{feature_key}' must contain 'value' "
                    f"in the schema."
                )
            value = value.to(self.device)

            if time is None:
                warnings.warn(
                    f"Feature '{feature_key}' does not have time "
                    f"intervals. StageNet's temporal modeling "
                    f"capabilities will be limited. Consider using "
                    f"StageNet format with time intervals for "
                    f"better performance.",
                    UserWarning,
                )
            else:
                time = time.to(self.device)
                # Ensure time is 2D [batch, seq_len]
                if time.dim() == 1:
                    time = time.unsqueeze(0)

            if mask is None:
                warnings.warn(
                    f"Feature '{feature_key}' does not have mask "
                    f"information. Default mask will be created from "
                    f"embedded values. But it may not be accurate.",
                )
                mask = (value.abs().sum(dim=-1) != 0).int()
            else:
                # Move first so every supplied mask ends up on the model
                # device, including the continuous-feature case below.
                mask = mask.to(self.device)
                if not processor.is_token() and value.dim() == mask.dim():
                    # for continuous features, if mask is provided,
                    # we need to collapse the feature dimension.
                    mask = mask.any(dim=-1).int()

            if value.dim() == 4:
                # Nested sequences: [batch, seq_len, inner_len, embedding_dim]
                value = value.sum(dim=2)  # Sum pool over inner dimension
                mask = mask.any(dim=2).int()  # Update mask for nested sequences

            # Pass through StageNet layer with embedded features
            last_output, _, _ = self.stagenet[feature_key](
                value, time=time, mask=mask, register_hook=register_attn_hook
            )

            patient_emb.append(last_output)

        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)

        y_prob = self.prepare_y_prob(logits)
        results = {
            "logit": logits,
            "y_prob": y_prob,
        }

        # obtain y_true, loss, y_prob
        if self.label_key in kwargs:
            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
            loss = self.get_loss_function()(logits, y_true)
            results["loss"] = loss
            results["y_true"] = y_true

        # Optionally return embeddings
        if kwargs.get("embed", False):
            results["embed"] = patient_emb

        return results

    def forward(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Embeds the "value" entry of every feature with the embedding model
        and delegates the rest to ``forward_from_embedding``.

        Args:
            **kwargs: keyword arguments for the model.

                The keys must contain all the feature keys and the label key.

                Feature keys should contain tuples of tensors (time, values)
                from temporal processors. But the feature keys can also
                contain just the values without time at the cost of degraded
                performance.

                The label key should contain the true labels for loss
                computation.

        Returns:
            A dictionary with the following keys:
                logit: the raw logits before activation.
                y_prob: a tensor of predicted probabilities.
                loss: a scalar tensor representing the final loss (only when
                    the label key is present).
                y_true: a tensor representing the true labels (only when the
                    label key is present).

        Raises:
            ValueError: if a feature has no "value" entry.
        """
        for feature_key in self.feature_keys:
            feature = kwargs[feature_key]
            schema = self.dataset.input_processors[feature_key].schema()
            # Backward compatibility: a bare tensor is treated as values
            # without time, independent of where "value" sits in the schema.
            is_bare = isinstance(feature, torch.Tensor)

            if "value" not in schema:
                value = None
            elif is_bare:
                value = feature
            else:
                value = feature[schema.index("value")]

            if value is None:
                raise ValueError(
                    f"Feature '{feature_key}' must contain 'value' "
                    f"in the schema."
                )

            value = value.to(self.device)
            value = self.embedding_model({feature_key: value})[feature_key]

            if is_bare:
                kwargs[feature_key] = value
            else:
                i = schema.index("value")
                kwargs[feature_key] = (
                    tuple(feature[:i]) + (value,) + tuple(feature[i + 1:])
                )

        return self.forward_from_embedding(**kwargs)

    def get_embedding_model(self) -> nn.Module | None:
        """Return the embedding model used by ``forward``."""
        return self.embedding_model

    # ------------------------------------------------------------------
    # CheferInterpretable interface
    # ------------------------------------------------------------------

    def set_attention_hooks(self, enabled: bool) -> None:
        """Enable or disable attention-gradient hooks for later passes."""
        self._attention_hooks_enabled = enabled

    def get_attention_layers(
        self,
    ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]:
        """Return, per feature key, a one-element list of (attn map, attn grad)."""
        return {  # type: ignore[return-value]
            key: [
                (
                    cast(StageNetAttentionLayer, self.stagenet[key]).get_attn_map(),
                    cast(StageNetAttentionLayer, self.stagenet[key]).get_attn_grad(),
                )
            ]
            for key in self.feature_keys
        }

    def get_relevance_tensor(
        self,
        R: dict[str, torch.Tensor],
        **data: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> dict[str, torch.Tensor]:
        """Select, per feature, the relevance row of the last valid timestep.

        Args:
            R: relevance tensors keyed by feature key; each is indexed as
                ``r[batch_index, timestep]``.
            **data: the batch the relevance was computed for; used only to
                recover each feature's mask.

        Returns:
            A dictionary mapping every feature key to ``r`` indexed at the
            last valid timestep of each sample. When neither mask nor value
            is available, the last position of ``r`` is used.
        """
        # StageAttentionNet uses get_last_visit (last valid timestep)
        # instead of a CLS token. Derive the mask from **data using
        # the same logic as forward_from_embedding.
        result = {}
        for key in self.feature_keys:
            r = R[key]
            batch_size = r.shape[0]
            device = r.device
            rows = torch.arange(batch_size, device=device)

            processor = self.dataset.input_processors[key]
            value, _, mask = self._unpack_feature(key, data[key])

            if mask is None and value is None:
                # Cannot determine mask; fall back to last position
                last_idx = torch.full(
                    (batch_size,),
                    r.shape[1] - 1,
                    device=device,
                    dtype=torch.long,
                )
                result[key] = r[rows, last_idx]
                continue

            if mask is None:
                v = value.to(device)
                mask = (v.abs().sum(dim=-1) != 0).int()
            else:
                mask = mask.to(device)
                if (
                    not processor.is_token()
                    and value is not None
                    and value.dim() == mask.dim()
                ):
                    mask = mask.any(dim=-1).int()

            if mask.dim() == 3:
                # Nested sequences: collapse inner dimension
                mask = mask.any(dim=2).int()

            # Index of the last valid step; clamp guards fully-masked rows.
            last_idx = mask.sum(dim=1).long() - 1
            last_idx = last_idx.clamp(min=0)
            result[key] = r[rows, last_idx]  # presumably [batch, attention_seq_len]

        return result