# Source code for pyhealth.models.transformer

# Author: Yongda Fan
# NetID: yongdaf2
# Description: Transformer model implementation for PyHealth 2.0

import math
import warnings
from typing import Any, Dict, Optional, Tuple, Union, cast

import torch
from torch import nn

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.models.embedding import EmbeddingModel
from pyhealth.interpret.api import CheferInterpretable

# VALID_OPERATION_LEVEL = ["visit", "event"]


class Attention(nn.Module):
    """Scaled dot-product attention helper."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dropout: Optional[nn.Module] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention outputs.

        Args:
            query: Query tensor ``[batch, heads, len_q, dim]``.
            key: Key tensor ``[batch, heads, len_k, dim]``.
            value: Value tensor ``[batch, heads, len_v, dim]``.
            mask: Optional mask broadcastable to the score matrix
                ``[batch, heads, len_q, len_k]``. Entries equal to 0 mark
                positions that must not be attended to.
            dropout: Optional dropout module applied to attention weights.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Attention-applied values and the
            attention weight matrix (after dropout, if any).

        Example:
            Called inside :class:`MultiHeadedAttention` for each head.
        """

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            # Fill masked logits with the most negative *finite* value of the
            # score dtype rather than -inf or a hard-coded constant:
            #   * softmax still yields exact zeros on masked positions, so no
            #     second masked_fill is needed after the softmax;
            #   * a row whose keys are all masked softmaxes to a uniform
            #     distribution instead of NaN (which -inf would produce);
            #   * the fill value always fits the dtype (-1e9 is not
            #     representable in float16).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        p_attn = self.softmax(scores)
        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    """Multi-head attention wrapper used by the Transformer block.

    Besides computing attention, the module can optionally keep the last
    attention weights and their gradients so that attention-based
    interpretability methods can read them back after a backward pass.
    """

    def __init__(self, h: int, d_model: int, dropout: float = 0.1):
        """Initialize the attention module.

        Args:
            h: Number of attention heads.
            d_model: Dimensionality of the model embedding. Must be divisible
                by ``h``.
            dropout: Dropout probability applied to attention weights.
        """

        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0, "d_model must be divisible by h"

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        # Query, key and value projections (in that order).
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(3)]
        )
        self.output_linear = nn.Linear(d_model, d_model, bias=False)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

        # Populated only by forward passes run with ``register_hook=True``.
        self.attn_gradients = None
        self.attn_map = None

    def __repr__(self) -> str:
        return (
            f"MultiHeadedAttention(heads={self.h}, d_model={self.h * self.d_k}, "
            f"dropout={self.dropout.p})"
        )

    def set_activation_hooks(self, hooks) -> None:
        """Deprecated: retained for backward compatibility; no-op."""
        return None

    # helper functions for interpretability
    def get_attn_map(self) -> Optional[torch.Tensor]:
        """Return the last computed attention weights."""

        return self.attn_map

    def get_attn_grad(self) -> Optional[torch.Tensor]:
        """Return gradients captured from attention weights."""

        return self.attn_gradients

    def save_attn_grad(self, attn_grad: torch.Tensor) -> None:
        """Hook callback that stores attention gradients."""

        self.attn_gradients = attn_grad.detach()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        register_hook: bool = False,
    ) -> torch.Tensor:
        """Run multi-head attention with optional gradient capture.

        Args:
            query: Query tensor ``[batch, len_q, hidden]`` or similar.
            key: Key tensor aligned with ``query``.
            value: Value tensor aligned with ``query``.
            mask: Optional boolean mask ``[batch, len_q, len_k]``.
            register_hook: True to store the attention map and attach a
                backward hook saving its gradients.

        Returns:
            torch.Tensor: Attention mixed representation ``[batch, len_q, hidden]``.
        """

        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            proj(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for proj, x in zip(self.linear_layers, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        if mask is not None:
            # Add a head axis so the mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # Gradients captured by an earlier pass do not belong to this pass's
        # attention map; drop them so they are neither paired with the wrong
        # map nor kept in memory.
        self.attn_gradients = None
        if register_hook:
            # Only store attn_map and hook during interpretability passes.
            # ``.detach()`` returns a tensor that shares storage with ``attn``
            # but is cut from the autograd graph, so holding on to it does not
            # keep the graph alive after ``.backward()``.
            self.attn_map = attn.detach()
            # Hooks can only be registered on tensors that require grad
            # (not the case under ``torch.no_grad()``, for instance).
            if attn.requires_grad:
                attn.register_hook(self.save_attn_grad)
        else:
            self.attn_map = None

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)


class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward sub-network with a GELU non-linearity."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """Build the ``d_model -> d_ff -> d_model`` projection stack.

        Args:
            d_model: Size of the input and output features.
            d_ff: Size of the hidden (intermediate) features.
            dropout: Dropout probability applied after the activation.
        """

        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Transform ``x`` and zero out positions whose mask row is empty.

        Args:
            x: Input features; the last dimension must be ``d_model``.
            mask: Optional mask. It is summed over its last dimension and
                every position whose sum is not positive is set to zero in
                the output.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """

        hidden = self.activation(self.w_1(x))
        out = self.w_2(self.dropout(hidden))
        if mask is None:
            return out

        # A position is kept when its mask row has at least one positive entry.
        keep = mask.sum(dim=-1) > 0
        out[~keep] = 0
        return out


class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, size: int, dropout: float):
        """Create the layer norm and dropout used around a sublayer.

        Args:
            size: Number of features normalized by the layer norm.
            dropout: Dropout probability applied to the sublayer output.
        """

        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        sublayer,
    ) -> torch.Tensor:
        """Normalize ``x``, run ``sublayer`` on it and add the result to ``x``.

        Args:
            x: Input tensor whose last dimension equals ``size``.
            sublayer: Callable mapping the normalized tensor to a tensor that
                can be added to ``x``.

        Returns:
            torch.Tensor: ``x`` plus the dropped-out sublayer output.
        """

        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch


class TransformerBlock(nn.Module):
    """Transformer block.

    MultiHeadedAttention + PositionwiseFeedForward + SublayerConnection

    Args:
        hidden: hidden size of transformer.
        attn_heads: head sizes of multi-head attention.
        dropout: dropout rate used by the feed-forward network, both residual
            connections and the block output.
    """

    def __init__(self, hidden: int, attn_heads: int, dropout: float):
        super(TransformerBlock, self).__init__()
        # NOTE: ``dropout`` is not forwarded here, so the dropout applied to
        # the attention weights is MultiHeadedAttention's own default.
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=4 * hidden, dropout=dropout
        )
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def set_activation_hooks(self, hooks) -> None:
        """Deprecated compatibility stub; no-op."""
        return None

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        register_hook: bool = False,
    ) -> torch.Tensor:
        """Forward propagation.

        Args:
            x: [batch_size, seq_len, hidden]
            mask: [batch_size, seq_len, seq_len]
            register_hook: True to make the attention module keep its
                attention map and capture its gradients.

        Returns:
            A tensor of shape [batch_size, seq_len, hidden]
        """
        # Self-attention sub-layer: query, key and value are all the
        # (normalized) input.
        x = self.input_sublayer(
            x,
            lambda _x: self.attention(
                _x, _x, _x, mask=mask, register_hook=register_hook
            ),
        )
        # Feed-forward sub-layer; the mask lets it zero out padded positions.
        x = self.output_sublayer(x, lambda _x: self.feed_forward(_x, mask=mask))
        return self.dropout(x)
class TransformerLayer(nn.Module):
    """Transformer layer.

    Paper: Ashish Vaswani et al. Attention is all you need. NIPS 2017.

    This layer is used in the Transformer model. But it can also be used
    as a standalone layer.

    Args:
        feature_size: the hidden feature size.
        heads: the number of attention heads. Default is 1.
        dropout: dropout rate. Default is 0.5.
        num_layers: number of transformer layers. Default is 1.

    Examples:
        >>> import torch
        >>> from pyhealth.models import TransformerLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = TransformerLayer(64)
        >>> emb, cls_emb = layer(input)
        >>> emb.shape
        torch.Size([3, 128, 64])
        >>> cls_emb.shape
        torch.Size([3, 64])
    """

    def __init__(
        self,
        feature_size: int,
        heads: int = 1,
        dropout: float = 0.5,
        num_layers: int = 1,
    ):
        super(TransformerLayer, self).__init__()
        self.transformer = nn.ModuleList(
            [TransformerBlock(feature_size, heads, dropout) for _ in range(num_layers)]
        )

    def set_activation_hooks(self, hooks) -> None:
        """Deprecated compatibility stub; no-op."""
        return None

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        register_hook: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, feature_size].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid.
            register_hook: True to save the attention maps and their gradients
                in every block. Default is False.

        Returns:
            emb: a tensor of shape [batch size, sequence len, feature_size],
                containing the output features for each time step.
            cls_emb: a tensor of shape [batch size, feature_size], containing
                the output features for the first time step.
        """
        if mask is not None:
            # Outer product of the per-position mask with itself: entry
            # (b, i, j) is set only when positions i and j are both valid.
            mask = torch.einsum("ab,ac->abc", mask, mask)
        for transformer in self.transformer:
            x = transformer(x, mask, register_hook)
        emb = x
        cls_emb = x[:, 0, :]
        return emb, cls_emb
class Transformer(BaseModel, CheferInterpretable):
    """Transformer model for PyHealth 2.0 datasets.

    Each feature stream is embedded with :class:`EmbeddingModel` and encoded by
    an independent :class:`TransformerLayer`. The resulting [CLS]-style
    embeddings are concatenated and passed to a classification head.

    Args:
        dataset (SampleDataset): dataset providing processed inputs.
        embedding_dim (int): shared embedding dimension.
        heads (int): number of attention heads per transformer block.
        dropout (float): dropout rate applied inside transformer blocks.
        num_layers (int): number of transformer blocks per feature stream.
        max_seq_len (int): sequences longer than this are truncated before
            attention is applied.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset, get_dataloader
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "diagnoses": ["A", "B", "C"],
        ...         "procedures": ["X", "Y"],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-1",
        ...         "visit_id": "visit-0",
        ...         "diagnoses": ["D"],
        ...         "procedures": ["Z", "Y"],
        ...         "label": 0,
        ...     },
        ... ]
        >>> input_schema = {"diagnoses": "sequence", "procedures": "sequence"}
        >>> output_schema = {"label": "binary"}
        >>> dataset = create_sample_dataset(
        ...     samples,
        ...     input_schema,
        ...     output_schema,
        ...     dataset_name="demo",
        ... )
        >>> model = Transformer(dataset=dataset)
        >>> loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> batch = next(iter(loader))
        >>> output = model(**batch)
        >>> sorted(output.keys())
        ['logit', 'loss', 'y_prob', 'y_true']
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        heads: int = 1,
        dropout: float = 0.5,
        num_layers: int = 1,
        max_seq_len: int = 1024,
    ):
        super().__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.heads = heads
        self.dropout = dropout
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        # Toggled through ``set_attention_hooks``; read on every forward pass.
        self._attention_hooks_enabled = False

        assert (
            len(self.label_keys) == 1
        ), "Only one label key is supported if Transformer is initialized"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # One independent encoder per feature stream.
        self.transformer: nn.ModuleDict = nn.ModuleDict()
        for feature_key in self.feature_keys:
            self.transformer[feature_key] = TransformerLayer(
                feature_size=embedding_dim,
                heads=heads,
                dropout=dropout,
                num_layers=num_layers,
            )

        output_size = self.get_output_size()
        self.fc = nn.Linear(len(self.feature_keys) * embedding_dim, output_size)

    def _pool_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """Pool nested embeddings to ``[batch, seq_len, hidden]`` format.

        Args:
            x: Tensor emitted by the embedding model.

        Returns:
            torch.Tensor: Sequence-aligned embedding tensor ready for attention.

        Example:
            StageNet categorical inputs may have shape
            ``[batch, seq_len, inner_len, emb]``. We sum over ``inner_len`` to
            obtain per-event representations.
        """
        if x.dim() == 4:
            x = x.sum(dim=2)
        if x.dim() == 2:
            x = x.unsqueeze(1)
        # Truncate to max_seq_len to prevent quadratic memory spikes from
        # outlier-length sequences (attention is O(S^2)).
        if x.size(1) > self.max_seq_len:
            x = x[:, : self.max_seq_len, :]
        return x

    @staticmethod
    def _mask_from_embeddings(x: torch.Tensor) -> torch.Tensor:
        """Infer a boolean mask directly from embedded representations.

        A position is valid when its embedding has at least one non-zero
        entry. Rows without any valid position get their first position
        marked valid so that no sequence is entirely masked.
        """
        mask = torch.any(torch.abs(x) > 0, dim=-1)
        if mask.dim() == 1:
            mask = mask.unsqueeze(1)
        invalid_rows = ~mask.any(dim=1)
        if invalid_rows.any():
            mask[invalid_rows, 0] = True
        return mask.bool()

    def forward_from_embedding(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward pass starting from feature embeddings.

        This method bypasses the embedding layers and processes
        pre-embedded features. This is useful for interpretability
        methods like Integrated Gradients that need to interpolate
        in embedding space.

        Args:
            **kwargs: keyword arguments for the model. The keys must
                contain all the feature keys and the label key.

                It is expected to contain the following semantic tensors:
                    - "value": the embedded feature tensor of shape
                      [batch, seq_len, embedding_dim] or
                      [batch, embedding_dim].
                    - "mask" (optional): the mask tensor of shape
                      [batch, seq_len]. If not in the processor schema,
                      it can be provided as the last element of the
                      feature tuple. If not provided, masks will be
                      generated from the embedded values (non-zero
                      entries are treated as valid).

                The label key should contain the true labels for loss
                computation.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the final loss.
                y_prob: a tensor of predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: the raw logits before activation.
                embed: (if embed=True in kwargs) the patient embedding.

        Raises:
            ValueError: if a feature's processor schema has no "value" entry.
        """
        # Attention maps/gradients are captured only while the
        # interpretability flag is switched on.
        register_hook = self._attention_hooks_enabled

        patient_emb = []
        for feature_key in self.feature_keys:
            processor = self.dataset.input_processors[feature_key]
            schema = processor.schema()
            feature = kwargs[feature_key]

            if isinstance(feature, torch.Tensor):
                feature = (feature,)

            value = feature[schema.index("value")] if "value" in schema else None
            mask = feature[schema.index("mask")] if "mask" in schema else None
            # A mask that is not part of the schema may be appended as an
            # extra trailing element of the feature tuple.
            if len(feature) == len(schema) + 1 and mask is None:
                mask = feature[-1]

            if value is None:
                raise ValueError(
                    f"Feature '{feature_key}' must contain 'value' "
                    f"in the schema."
                )

            value = value.to(self.device)
            value = self._pool_embedding(value)

            if mask is not None:
                mask = mask.to(self.device).bool()
                # ``_pool_embedding`` cuts the sequence axis at max_seq_len;
                # cut the supplied mask the same way so both stay aligned.
                if mask.dim() >= 2 and mask.size(1) > self.max_seq_len:
                    mask = mask[:, : self.max_seq_len]
                # A mask with a trailing axis is collapsed to one flag per
                # sequence position.
                if mask.dim() == value.dim():
                    mask = mask.any(dim=-1)
            else:
                mask = self._mask_from_embeddings(value).to(self.device)

            _, cls_emb = self.transformer[feature_key](value, mask, register_hook)
            patient_emb.append(cls_emb)

        patient_emb = torch.cat(patient_emb, dim=1)
        logits = self.fc(patient_emb)

        y_prob = self.prepare_y_prob(logits)
        results = {
            "logit": logits,
            "y_prob": y_prob,
        }

        if self.label_key in kwargs:
            y_true = cast(torch.Tensor, kwargs[self.label_key]).to(self.device)
            loss = self.get_loss_function()(logits, y_true)
            results["loss"] = loss
            results["y_true"] = y_true

        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results

    def forward(
        self,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: keyword arguments for the model. The keys must
                contain all the feature keys and the label key. Feature
                keys should contain tensors or tuples of tensors
                following the processor schema. The label key should
                contain the true labels for loss computation.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the final loss.
                y_prob: a tensor of predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: the raw logits before activation.
                embed: (if embed=True in kwargs) the patient embedding.

        Raises:
            ValueError: if a feature's processor schema has no "value" entry.
        """
        for feature_key in self.feature_keys:
            feature = kwargs[feature_key]
            if isinstance(feature, torch.Tensor):
                feature = (feature,)
            else:
                # Normalize lists (or other sequences) so the tuple
                # concatenation below works.
                feature = tuple(feature)

            schema = self.dataset.input_processors[feature_key].schema()

            value = feature[schema.index("value")] if "value" in schema else None
            mask = feature[schema.index("mask")] if "mask" in schema else None

            if value is None:
                raise ValueError(
                    f"Feature '{feature_key}' must contain 'value' "
                    f"in the schema."
                )

            value = value.to(self.device)
            if mask is not None:
                mask = mask.to(self.device)
                value = self.embedding_model(
                    {feature_key: value}, masks={feature_key: mask}
                )[feature_key]
            else:
                value = self.embedding_model({feature_key: value})[feature_key]

            # Swap the raw value for its embedding, leaving every other
            # element of the feature tuple in place.
            i = schema.index("value")
            kwargs[feature_key] = feature[:i] + (value,) + feature[i + 1:]

        return self.forward_from_embedding(**kwargs)

    def get_embedding_model(self) -> nn.Module | None:
        """Get the embedding model.

        Returns:
            nn.Module: The embedding model used to embed raw features.
        """
        return self.embedding_model

    # ------------------------------------------------------------------
    # CheferInterpretable interface
    # ------------------------------------------------------------------

    def set_attention_hooks(self, enabled: bool) -> None:
        """Turn capture of attention maps and gradients on or off.

        Args:
            enabled: When True, subsequent forward passes ask every attention
                module to keep its attention map and register a gradient hook.
        """
        self._attention_hooks_enabled = enabled

    def get_attention_layers(
        self,
    ) -> dict[str, list[tuple[torch.Tensor, torch.Tensor]]]:
        """Collect the stored attention maps and gradients.

        Returns:
            A dictionary mapping each feature key to a list with one
            ``(attention_map, attention_gradient)`` pair per transformer
            block, in block order. Entries are ``None`` when nothing has
            been captured.
        """
        return {  # type: ignore[return-value]
            key: [
                (
                    cast(TransformerBlock, blk).attention.get_attn_map(),
                    cast(TransformerBlock, blk).attention.get_attn_grad(),
                )
                for blk in cast(TransformerLayer, self.transformer[key]).transformer
            ]
            for key in self.feature_keys
        }

    def get_relevance_tensor(
        self,
        R: dict[str, torch.Tensor],
        **data: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> dict[str, torch.Tensor]:
        """Extract the relevance row of the CLS position for each feature.

        Args:
            R: relevance tensors keyed by feature key.
            **data: unused by this model.

        Returns:
            A dictionary with the same keys as ``R`` whose values are
            ``R[key][:, 0]``, i.e. ``[batch, attention_seq_len]``.
        """
        # CLS token is at index 0 for all feature keys; take its attention row.
        return {key: r[:, 0] for key, r in R.items()}
if __name__ == "__main__":
    # Smoke test: build a tiny in-memory dataset, run one forward pass and
    # back-propagate the loss.
    from pyhealth.datasets import create_sample_dataset, get_dataloader

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            "diagnoses": ["A", "B", "C"],
            "procedures": ["X", "Y"],
            "label": 1,
        },
        {
            "patient_id": "patient-1",
            "visit_id": "visit-0",
            "diagnoses": ["D", "E"],
            "procedures": ["Z"],
            "label": 0,
        },
    ]

    input_schema = {
        "diagnoses": "sequence",
        "procedures": "sequence",
    }
    output_schema = {"label": "binary"}

    dataset = create_sample_dataset(
        samples=samples,
        input_schema=input_schema,
        output_schema=output_schema,
        dataset_name="test",
    )

    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    model = Transformer(dataset=dataset, embedding_dim=64, heads=2, num_layers=2)

    data_batch = next(iter(train_loader))

    result = model(**data_batch)
    print(result)

    result["loss"].backward()