pyhealth.interpret.methods.attention_rollout
Overview
Attention Rollout provides token-level relevance scores for Transformer models in PyHealth. It quantifies how attention propagates information across layers by composing the per-layer attention matrices (with a residual-connection correction), yielding a single importance score per input token (e.g. diagnosis codes, procedure codes, medications) for a given patient sample.
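Written as formulas, with $A_\ell$ the head-fused attention matrix of layer $\ell$ in an $L$-layer model, the composition looks as follows. This restates the algorithm steps in the API reference; the row-sum argument assumes softmax attention, whose rows are non-negative and sum to 1, and the $2 \times 2$ matrix is a made-up toy value, not model output.

```latex
\hat{A}_\ell = \tfrac{1}{2}\,(A_\ell + I), \qquad
R = \hat{A}_L \, \hat{A}_{L-1} \cdots \hat{A}_1

% Rows keep summing to 1 (\mathbf{1} is the all-ones vector):
A_\ell \mathbf{1} = \mathbf{1}
\;\Rightarrow\;
\hat{A}_\ell \mathbf{1} = \tfrac{1}{2}\,(A_\ell \mathbf{1} + \mathbf{1}) = \mathbf{1}
\;\Rightarrow\;
R \, \mathbf{1} = \mathbf{1}

% Toy two-token example with A_1 = A_2 = A:
A = \begin{pmatrix} 0.2 & 0.8 \\ 0.6 & 0.4 \end{pmatrix}
\;\Rightarrow\;
\hat{A} = \begin{pmatrix} 0.6 & 0.4 \\ 0.3 & 0.7 \end{pmatrix},
\qquad
\hat{A}\hat{A} = \begin{pmatrix} 0.48 & 0.52 \\ 0.39 & 0.61 \end{pmatrix}
```

The residual term raises the diagonal (each token's weight on itself) from 0.2 and 0.4 to 0.6 and 0.7, and every row of the two-layer product still sums to 1.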
Unlike CheferRelevance, which is
gradient-weighted and class-specific, attention rollout is forward-pass
only, gradient-free, and class-agnostic: it explains how information
flows through the attention mechanism independent of any target class. It serves
as the standard baseline that gradient-based attention methods are compared
against, and complements Chefer rather than replacing it.
This method is particularly useful for:
Clinical decision support: Understanding which medical codes the model attended to when making a particular prediction
Model debugging: Identifying whether the model attends to clinically meaningful features
Feature importance: Ranking tokens by how much attention flows to them
Trust and transparency: Providing interpretable, class-agnostic explanations for model predictions
The implementation follows the paper by Abnar & Zuidema (2020): “Quantifying Attention Flow in Transformers” (https://arxiv.org/abs/2005.00928).
Key Features
Multi-modal support: Works with multiple feature types (conditions, procedures, drugs, labs, etc.)
Gradient-free: Computed from a single forward pass; no backward pass is used in the attribution math
Class-agnostic: Independent of the predicted/target class (target_class_idx is accepted but ignored)
Layer-wise composition: Composes per-layer attention as rollout = Â_L @ ... @ Â_1 with the residual correction Â = 0.5 * (A + I)
Distribution over tokens: Because each softmax attention matrix A is row-stochastic, each Â is too, and so is their product; per-token relevance sums to 1 (before the input-shape expansion)
Model-agnostic by duck-typing: Works with any model exposing the attention-readout methods set_attention_hooks, get_attention_layers and get_relevance_tensor (currently Transformer and StageAttentionNet), not just one named model
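The duck-typed compatibility check can be pictured roughly as below. This is a minimal sketch, not PyHealth's actual `__init__` code: the helper `check_attention_readout` and the two toy classes are hypothetical; only the three method names and the `TypeError` on incompatible models come from this page.

```python
# Hypothetical sketch of a duck-typed compatibility check (not PyHealth's code).
REQUIRED_METHODS = (
    "set_attention_hooks",
    "get_attention_layers",
    "get_relevance_tensor",
)


def check_attention_readout(model) -> None:
    """Raise TypeError unless `model` has all three attention-readout methods."""
    missing = [
        name for name in REQUIRED_METHODS
        if not callable(getattr(model, name, None))
    ]
    if missing:
        raise TypeError(
            f"{type(model).__name__} lacks attention-readout methods: "
            + ", ".join(missing)
        )


class ToyAttentionModel:
    """Hypothetical model: defining the three methods is enough to pass."""

    def set_attention_hooks(self, *args, **kwargs):
        pass

    def get_attention_layers(self, *args, **kwargs):
        return []

    def get_relevance_tensor(self, *args, **kwargs):
        return None


class ToyMLP:
    """Hypothetical model with no attention readout."""


check_attention_readout(ToyAttentionModel())  # no error

try:
    check_attention_readout(ToyMLP())
except TypeError as err:
    print(err)
```

Checking for methods rather than a base class is what lets unrelated model classes qualify without sharing a named interface.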
Usage Notes
Batch size: For interpretability, use batch_size=1 to get per-sample explanations.
Do not wrap in torch.no_grad(): Although rollout is gradient-free in its math, the shared attention-readout plumbing registers a gradient hook on the attention tensors during the forward pass, so calling attribute(**batch) inside torch.no_grad() raises a RuntimeError. Call it under the default (grad-enabled) context; no backward pass is performed.
Model compatibility: Works with any model that exposes set_attention_hooks, get_attention_layers and get_relevance_tensor; it is not restricted to the Transformer. Incompatible models raise TypeError at construction.
Class specification: target_class_idx is accepted for API compatibility but ignored, since rollout is class-agnostic.
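The RuntimeError comes from PyTorch itself: a gradient hook can only be registered on a tensor that requires grad, and tensors computed under torch.no_grad() never do. The pure-PyTorch snippet below reproduces that mechanism without PyHealth; `weights` and `attn` are stand-ins for model parameters and an attention map, not names PyHealth exposes.

```python
import torch

# Stand-in "attention map" computed from a parameter that requires grad.
weights = torch.randn(4, 4, requires_grad=True)

# Default (grad-enabled) context: the result requires grad, so a hook attaches.
attn = torch.softmax(weights, dim=-1)
handle = attn.register_hook(lambda grad: grad)
handle.remove()
print("grad-enabled, requires_grad =", attn.requires_grad)

# Under torch.no_grad() the same computation does not require grad,
# and registering a hook on it raises RuntimeError.
with torch.no_grad():
    attn_no_grad = torch.softmax(weights, dim=-1)
try:
    attn_no_grad.register_hook(lambda grad: grad)
except RuntimeError as err:
    print("no_grad:", err)
```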
Quick Start
from pyhealth.models import Transformer
from pyhealth.interpret.methods import AttentionRollout
from pyhealth.datasets import get_dataloader
# Assume you have a trained Transformer model and dataset splits
# (sample_dataset, test_dataset) prepared beforehand
model = Transformer(dataset=sample_dataset)  # pass other hyperparameters as needed
# ... train the model ...
# Create interpretability object
rollout = AttentionRollout(model)
# Get a test sample (batch_size=1)
test_loader = get_dataloader(test_dataset, batch_size=1, shuffle=False)
batch = next(iter(test_loader))
# Compute attributions (target_class_idx is accepted but ignored).
# Do not wrap this call in torch.no_grad().
scores = rollout.attribute(**batch)
# Analyze results
for feature_key, attribution in scores.items():
    print(f"{feature_key}: {attribution.shape}")
    # Flatten so nested features work too (indices are then flat positions),
    # and cap k so features with fewer than 5 tokens do not make topk fail
    sample_scores = attribution[0].flatten()
    top_tokens = sample_scores.topk(min(5, sample_scores.numel())).indices
    print(f"  Top {len(top_tokens)} most relevant tokens: {top_tokens}")
API Reference
- class pyhealth.interpret.methods.AttentionRollout(model, head_fusion='mean')
Bases: BaseInterpreter
Attention rollout for transformer interpretability.
Implements the canonical attention rollout method of Abnar & Zuidema, “Quantifying Attention Flow in Transformers” (2020), https://arxiv.org/abs/2005.00928.
Unlike CheferRelevance, which is gradient-weighted and class-specific, rollout is forward-pass only, gradient-free, and class-agnostic: it quantifies how attention propagates information across layers, independent of any target class. It serves as the standard baseline that gradient-based attention methods are compared against.
Note
“Gradient-free” refers to the attribution math: no backward pass is run and no gradients enter the rollout computation. It does not mean the call is safe inside torch.no_grad(). The shared attention-readout plumbing registers a gradient hook on the attention tensors during the forward pass, so running attribute(**batch) under torch.no_grad() raises a RuntimeError. Call it under the default (grad-enabled) context.
This interpreter works with any model that exposes the attention-readout methods set_attention_hooks, get_attention_layers, and get_relevance_tensor (currently Transformer and StageAttentionNet). Compatibility is checked by duck-typing in __init__ rather than by requiring a named interface, since these methods are general attention readout and not specific to any one method.
The algorithm, per feature key:
Enable attention hooks via model.set_attention_hooks(True) and run a single forward pass (no backward pass).
Retrieve per-layer attention maps via model.get_attention_layers(), discarding the gradient element of each (attn_map, attn_grad) pair.
Fuse heads (mean) to get one [batch, seq, seq] matrix per layer.
Account for residual connections: A_hat = 0.5 * (A + I).
Compose layers by matrix product: rollout = A_hat_L @ ... @ A_hat_1.
Reduce to per-token scores via model.get_relevance_tensor(), then expand to raw input value shapes.
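The head-fusion, residual-correction, and composition steps can be sketched in plain PyTorch as below. This is a minimal sketch under stated assumptions, not PyHealth's implementation: `attention_rollout` is a hypothetical helper that takes per-layer attention tensors shaped [batch, heads, seq, seq], ordered from first layer to last, and the reduction done by model.get_relevance_tensor() is replaced here by reading the row of an assumed CLS token at position 0.

```python
from typing import Sequence

import torch


def attention_rollout(attentions: Sequence[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: compose per-layer attention maps into a rollout matrix.

    attentions: tensors of shape [batch, heads, seq, seq], first layer first.
    Returns a [batch, seq, seq] tensor R = A_hat_L @ ... @ A_hat_1.
    """
    if not attentions:
        raise ValueError("need at least one attention layer")
    rollout = None
    for attn in attentions:
        fused = attn.mean(dim=1)  # head fusion (mean): [batch, seq, seq]
        eye = torch.eye(fused.size(-1), dtype=fused.dtype, device=fused.device)
        a_hat = 0.5 * (fused + eye)  # residual correction
        # Later layers multiply on the left: R = A_hat_l @ R.
        rollout = a_hat if rollout is None else a_hat @ rollout
    return rollout


# Toy input: 3 layers, batch 2, 4 heads, 5 tokens; softmax makes rows sum to 1.
torch.manual_seed(0)
layers = [torch.softmax(torch.randn(2, 4, 5, 5), dim=-1) for _ in range(3)]
R = attention_rollout(layers)

# Assumption: token 0 plays the role of CLS; its row gives per-token relevance.
cls_relevance = R[:, 0, :]
print(cls_relevance.sum(dim=-1))  # each entry is 1 up to floating-point error
```

Matrix products do not commute, so the order matters: the last layer's A_hat is applied outermost, matching rollout = A_hat_L @ ... @ A_hat_1.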
Because each A_hat is row-stochastic, so is their product; the per-token relevance therefore forms a distribution over tokens (sums to 1 before the input-shape expansion).
- Parameters:
Example
>>> from pyhealth.datasets import create_sample_dataset, get_dataloader
>>> from pyhealth.models import Transformer
>>> from pyhealth.interpret.methods import AttentionRollout
>>>
>>> samples = [
...     {
...         "patient_id": "p0",
...         "visit_id": "v0",
...         "conditions": ["A05B", "A05C", "A06A"],
...         "procedures": ["P01", "P02"],
...         "label": 1,
...     },
...     {
...         "patient_id": "p0",
...         "visit_id": "v1",
...         "conditions": ["A05B"],
...         "procedures": ["P01"],
...         "label": 0,
...     },
... ]
>>> dataset = create_sample_dataset(
...     samples=samples,
...     input_schema={"conditions": "sequence", "procedures": "sequence"},
...     output_schema={"label": "binary"},
...     dataset_name="ehr_example",
... )
>>> model = Transformer(dataset=dataset)
>>> # ... train the model ...
>>>
>>> interpreter = AttentionRollout(model)
>>> batch = next(iter(get_dataloader(dataset, batch_size=2)))
>>>
>>> attributions = interpreter.attribute(**batch)
>>> # Returns dict: {"conditions": tensor, "procedures": tensor}
>>> print(attributions["conditions"].shape)  # [batch, num_tokens]
>>>
>>> # target_class_idx is accepted but ignored (rollout is class-agnostic)
>>> same = interpreter.attribute(target_class_idx=1, **batch)
- attribute(target_class_idx=None, **data)
Compute class-agnostic attention rollout attributions.
- Parameters:
- Returns:
- A dict keyed by the model’s feature keys.
Each value holds the rollout relevance for that feature: the CLS-token row of the composed attention-rollout matrix, reduced to one score per token by model.get_relevance_tensor() and then expanded to the raw input value shape by _map_to_input_shapes. For flat sequence features this is [batch, num_tokens]; for nested sequences the per-visit score is replicated across the codes within each visit. Scores are non-negative and, before the input-shape expansion, sum to 1 across tokens (a consequence of composing row-stochastic matrices).
- Return type:
Dict[str, torch.Tensor]
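For nested sequence features, the replication described above can be illustrated with plain tensor broadcasting. This is a hypothetical sketch of the documented behavior, not the code of _map_to_input_shapes; the shapes [batch, num_visits] for the per-visit scores and [batch, num_visits, num_codes] for a padded nested input are assumptions for illustration. It also shows why the sum-to-1 property holds only before the expansion.

```python
import torch

# Hypothetical per-visit rollout scores for one patient with 3 visits.
visit_scores = torch.tensor([[0.5, 0.3, 0.2]])  # [batch=1, num_visits=3]

# Hypothetical padded nested input shape: 4 code slots per visit.
batch, num_visits, num_codes = 1, 3, 4

# Replicate each visit's score across the code slots of that visit.
code_scores = visit_scores.unsqueeze(-1).expand(batch, num_visits, num_codes)

print(code_scores.shape)
print(visit_scores.sum())  # sums to 1 before expansion (up to float rounding)
print(code_scores.sum())   # about 4 after expansion: each score counted num_codes times
```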
Note
Do not call this method inside a torch.no_grad() context. Even though rollout uses no gradients, enabling attention hooks registers a gradient hook during the forward pass, which requires grad-enabled tensors and otherwise raises a RuntimeError.
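If attribute has to be called from code that is already inside torch.no_grad() (for example, an evaluation loop), one option is to restore gradient tracking locally with torch.enable_grad(). The snippet below demonstrates the PyTorch behavior with a stand-in tensor only; wrapping the attribute(**batch) call the same way is a suggestion that assumes the model's parameters still require grad, not a documented PyHealth pattern.

```python
import torch

weights = torch.randn(4, 4, requires_grad=True)  # stand-in for model parameters

with torch.no_grad():  # e.g. an outer evaluation loop
    # ... other no-grad work ...
    with torch.enable_grad():  # locally restore grad mode
        # In practice (assumption): scores = rollout.attribute(**batch)
        attn = torch.softmax(weights, dim=-1)
        handle = attn.register_hook(lambda grad: grad)  # succeeds here
        handle.remove()

print(attn.requires_grad)
```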