pyhealth.interpret.methods.attention_rollout#

Overview#

Attention Rollout provides token-level relevance scores for Transformer models in PyHealth. It quantifies how attention propagates information across layers by composing the per-layer attention matrices (with a residual-connection correction), yielding a single importance score per input token (e.g. diagnosis codes, procedure codes, medications) for a given patient sample.

Unlike CheferRelevance, which is gradient-weighted and class-specific, attention rollout is forward-pass only, gradient-free, and class-agnostic: it explains how information flows through the attention mechanism independent of any target class. It serves as the standard baseline that gradient-based attention methods are compared against, and complements Chefer rather than replacing it.

This method is particularly useful for:

  • Clinical decision support: Seeing which medical codes the model's attention concentrates on for a given patient sample

  • Model debugging: Identifying whether the model attends to clinically meaningful features

  • Feature importance: Ranking tokens by how much attention flows to them

  • Trust and transparency: Providing interpretable, class-agnostic explanations for model predictions

The implementation follows the paper by Abnar & Zuidema (2020): “Quantifying Attention Flow in Transformers” (https://arxiv.org/abs/2005.00928).

Key Features#

  • Multi-modal support: Works with multiple feature types (conditions, procedures, drugs, labs, etc.)

  • Gradient-free: Computed from a single forward pass; no backward pass is used in the attribution math

  • Class-agnostic: Independent of the predicted/target class (target_class_idx is accepted but ignored)

  • Layer-wise composition: Composes per-layer attention as rollout = Â_L @ ... @ Â_1 with the residual correction Â = 0.5 * (A + I)

  • Distribution over tokens: Because each Â is row-stochastic, so is their product; per-token relevance sums to 1 (before the input-shape expansion)

  • Model-agnostic by duck-typing: Works with any model exposing the attention-readout methods set_attention_hooks, get_attention_layers and get_relevance_tensor (currently Transformer and StageAttentionNet), not just one named model
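
The row-stochastic property claimed in the list above follows in two lines. The derivation below is a sketch under the assumption that each fused attention matrix A_l has non-negative entries and rows that sum to 1 (as softmax attention does); the symbol \mathbf{1} for the all-ones vector is introduced here for illustration.

```latex
\hat{A}_\ell = \tfrac{1}{2}\,(A_\ell + I), \qquad
R = \hat{A}_L \,\hat{A}_{L-1} \cdots \hat{A}_1 .

% Each corrected layer keeps unit row sums:
\hat{A}_\ell \mathbf{1}
  = \tfrac{1}{2}\,(A_\ell \mathbf{1} + I \mathbf{1})
  = \tfrac{1}{2}\,(\mathbf{1} + \mathbf{1})
  = \mathbf{1}.

% Applying this from the right, one factor at a time:
R\,\mathbf{1}
  = \hat{A}_L \cdots \hat{A}_2 \,(\hat{A}_1 \mathbf{1})
  = \hat{A}_L \cdots \hat{A}_2 \,\mathbf{1}
  = \cdots
  = \mathbf{1}.
```

Since sums and products of non-negative entries stay non-negative, every row of R is a probability distribution over input tokens.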

Usage Notes#

  1. Batch size: For interpretability, use batch_size=1 to get per-sample explanations.

  2. Do not wrap in torch.no_grad(): Although rollout is gradient-free in its math, the shared attention-readout plumbing registers a gradient hook on the attention tensors during the forward pass, so calling attribute(**batch) inside torch.no_grad() raises a RuntimeError. Call it under the default (grad-enabled) context; no backward pass is performed.

  3. Model compatibility: Works with any model that exposes set_attention_hooks, get_attention_layers and get_relevance_tensor — not restricted to the Transformer. Incompatible models raise TypeError at construction.

  4. Class specification: target_class_idx is accepted for API compatibility but ignored, since rollout is class-agnostic.
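
Note 2 reflects general PyTorch behaviour: a tensor computed under torch.no_grad() does not require grad, and Tensor.register_hook rejects such tensors. The sketch below reproduces the mechanism with plain tensors. It does not use a PyHealth model and is not PyHealth's hook code; the name weights is a hypothetical stand-in for model parameters. It also shows torch.enable_grad() as one possible way to re-enable gradient tracking locally inside a surrounding no_grad evaluation loop, on the assumption that the model's parameters still require grad.

```python
import torch

# Hypothetical stand-in for trainable model parameters.
weights = torch.randn(4, 4, requires_grad=True)

with torch.no_grad():
    # Inside no_grad, results do not require grad ...
    attn = torch.softmax(weights, dim=-1)
    try:
        # ... so registering a gradient hook on them fails.
        attn.register_hook(lambda grad: grad)
        hook_failed = False
    except RuntimeError as err:
        hook_failed = True
        print(f"under no_grad: {err}")

    # Re-enabling grad locally makes the same operation succeed.
    with torch.enable_grad():
        attn = torch.softmax(weights, dim=-1)
        handle = attn.register_hook(lambda grad: grad)
        hook_registered = attn.requires_grad
        handle.remove()  # clean up the illustrative hook
```

In an evaluation loop that already runs under torch.no_grad(), wrapping only the attribute(**batch) call in torch.enable_grad() may therefore be enough; moving the call outside the no_grad block is the simpler option.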

Quick Start#

from pyhealth.models import Transformer
from pyhealth.interpret.methods import AttentionRollout
from pyhealth.datasets import get_dataloader

# Assume you have a trained transformer model and dataset
model = Transformer(dataset=sample_dataset, ...)
# ... train the model ...

# Create interpretability object
rollout = AttentionRollout(model)

# Get a test sample (batch_size=1)
test_loader = get_dataloader(test_dataset, batch_size=1, shuffle=False)
batch = next(iter(test_loader))

# Compute attributions (target_class_idx is accepted but ignored)
scores = rollout.attribute(**batch)

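# Analyze results
for feature_key, attribution in scores.items():
    print(f"{feature_key}: {attribution.shape}")
    # Flatten so nested inputs also work; indices refer to the flattened scores
    sample_scores = attribution[0].flatten()
    k = min(5, sample_scores.numel())  # topk raises if k exceeds the token count
    top_tokens = sample_scores.topk(k).indices
    print(f"  Top {k} most relevant tokens: {top_tokens}")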

API Reference#

class pyhealth.interpret.methods.AttentionRollout(model, head_fusion='mean')[source]#

Bases: BaseInterpreter

Attention rollout for transformer interpretability.

Implements the canonical attention rollout method of Abnar & Zuidema, “Quantifying Attention Flow in Transformers” (2020), https://arxiv.org/abs/2005.00928.

Unlike CheferRelevance, which is gradient-weighted and class-specific, rollout is forward-pass only, gradient-free, and class-agnostic: it quantifies how attention propagates information across layers, independent of any target class. It serves as the standard baseline that gradient-based attention methods are compared against.

Note

“Gradient-free” refers to the attribution math: no backward pass is run and no gradients enter the rollout computation. It does not mean the call is safe inside torch.no_grad(). The shared attention-readout plumbing registers a gradient hook on the attention tensors during the forward pass, so running attribute(**batch) under torch.no_grad() raises a RuntimeError. Call it under the default (grad-enabled) context.

This interpreter works with any model that exposes the attention-readout methods set_attention_hooks, get_attention_layers, and get_relevance_tensor (currently Transformer and StageAttentionNet). Compatibility is checked by duck-typing in __init__ rather than by requiring a named interface, since these methods are general attention readout and not specific to any one method.
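
A duck-typing check of the kind described above could look roughly like the following. This is a minimal sketch, not PyHealth's actual __init__ code: the helper check_attention_readout, the constant REQUIRED_METHODS, and the two toy classes are hypothetical names introduced for illustration, and the toy method signatures are placeholders.

```python
# The three attention-readout method names come from the documentation above.
REQUIRED_METHODS = (
    "set_attention_hooks",
    "get_attention_layers",
    "get_relevance_tensor",
)


def check_attention_readout(model):
    """Raise TypeError if `model` lacks any attention-readout method."""
    missing = [
        name
        for name in REQUIRED_METHODS
        if not callable(getattr(model, name, None))
    ]
    if missing:
        raise TypeError(
            f"{type(model).__name__} is missing attention-readout "
            f"methods: {', '.join(missing)}"
        )


class ToyCompatible:
    """Hypothetical model exposing all three methods (bodies are stubs)."""

    def set_attention_hooks(self, *args, **kwargs):
        pass

    def get_attention_layers(self, *args, **kwargs):
        return {}

    def get_relevance_tensor(self, *args, **kwargs):
        return {}


class ToyIncompatible:
    """Hypothetical model with no attention-readout methods."""


check_attention_readout(ToyCompatible())  # passes silently

try:
    check_attention_readout(ToyIncompatible())
except TypeError as err:
    print(err)
```

Checking for callables by name, rather than for a base class, is what lets any model with these three methods be accepted regardless of its class hierarchy.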

The algorithm, per feature key:

  1. Enable attention hooks via model.set_attention_hooks(True) and run a single forward pass (no backward pass).

  2. Retrieve per-layer attention maps via model.get_attention_layers(), discarding the gradient element of each (attn_map, attn_grad) pair.

  3. Fuse heads (mean) to get one [batch, seq, seq] matrix per layer.

  4. Account for residual connections: A_hat = 0.5 * (A + I).

  5. Compose layers by matrix product: rollout = A_hat_L @ ... @ A_hat_1.

  6. Reduce to per-token scores via model.get_relevance_tensor(), then expand to raw input value shapes.

Because each A_hat is row-stochastic, so is their product; the per-token relevance therefore forms a distribution over tokens (sums to 1 before the input-shape expansion).
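
Steps 3 to 5 and the row-stochastic property can be sketched without any dependencies. The code below uses plain Python lists on random attention maps so the arithmetic stays visible; it is an illustration of the rollout math, not PyHealth's implementation. All helper names are hypothetical, and reading the scores from row 0 assumes the CLS token sits at position 0.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    """Product of two square matrices stored as nested lists."""
    n = len(a)
    return [
        [sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


def fuse_heads_mean(heads):
    """Step 3: average the per-head [seq, seq] matrices of one layer."""
    n = len(heads[0])
    return [
        [sum(h[i][j] for h in heads) / len(heads) for j in range(n)]
        for i in range(n)
    ]


def add_residual(attn):
    """Step 4: A_hat = 0.5 * (A + I)."""
    n = len(attn)
    return [
        [0.5 * (attn[i][j] + (1.0 if i == j else 0.0)) for j in range(n)]
        for i in range(n)
    ]


def attention_rollout(layers):
    """Step 5: rollout = A_hat_L @ ... @ A_hat_1 (layers given first to last)."""
    rollout = None
    for heads in layers:
        a_hat = add_residual(fuse_heads_mean(heads))
        # Later layers multiply from the left.
        rollout = a_hat if rollout is None else matmul(a_hat, rollout)
    return rollout


def random_attention(seq_len):
    """Random row-stochastic matrix standing in for one attention head."""
    return [
        softmax([random.gauss(0.0, 1.0) for _ in range(seq_len)])
        for _ in range(seq_len)
    ]


random.seed(0)
# 3 layers, 2 heads each, 4 tokens (sizes chosen arbitrarily).
layers = [[random_attention(4) for _ in range(2)] for _ in range(3)]
rollout = attention_rollout(layers)

cls_scores = rollout[0]  # assumption: CLS token at position 0
print("per-token scores:", [round(s, 4) for s in cls_scores])
print("sum of scores:", sum(cls_scores))  # equals 1 up to floating-point error
```

Every row of the result sums to 1, not only the CLS row, which is why the per-token relevance can be read as a distribution over tokens.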

Parameters:
  • model (BaseModel) – A trained PyHealth model exposing the attention-readout methods listed above.

  • head_fusion (str) – How to combine attention heads into a single matrix per layer. Currently only "mean" is supported (the canonical choice from the paper). Defaults to "mean".

Example

>>> from pyhealth.datasets import create_sample_dataset, get_dataloader
>>> from pyhealth.models import Transformer
>>> from pyhealth.interpret.methods import AttentionRollout
>>>
>>> samples = [
...     {
...         "patient_id": "p0",
...         "visit_id": "v0",
...         "conditions": ["A05B", "A05C", "A06A"],
...         "procedures": ["P01", "P02"],
...         "label": 1,
...     },
...     {
...         "patient_id": "p0",
...         "visit_id": "v1",
...         "conditions": ["A05B"],
...         "procedures": ["P01"],
...         "label": 0,
...     },
... ]
>>> dataset = create_sample_dataset(
...     samples=samples,
...     input_schema={"conditions": "sequence", "procedures": "sequence"},
...     output_schema={"label": "binary"},
...     dataset_name="ehr_example",
... )
>>> model = Transformer(dataset=dataset)
>>> # ... train the model ...
>>>
>>> interpreter = AttentionRollout(model)
>>> batch = next(iter(get_dataloader(dataset, batch_size=2)))
>>>
>>> attributions = interpreter.attribute(**batch)
>>> # Returns dict: {"conditions": tensor, "procedures": tensor}
>>> print(attributions["conditions"].shape)  # [batch, num_tokens]
>>>
>>> # target_class_idx is accepted but ignored (rollout is class-agnostic)
>>> same = interpreter.attribute(target_class_idx=1, **batch)

attribute(target_class_idx=None, **data)[source]#

Compute class-agnostic attention rollout attributions.

Parameters:
  • target_class_idx (Optional[int]) – Accepted for API compatibility with class-specific interpreters. Attention rollout is class-agnostic, so this argument is ignored.

  • **data – Batch input passed directly to the model.

Returns:

A dict keyed by the model’s feature keys.

Each value holds the rollout relevance for that feature — the CLS-token row of the composed attention-rollout matrix, reduced to one score per token by model.get_relevance_tensor() and then expanded to the raw input value shape by _map_to_input_shapes. For flat sequence features this is [batch, num_tokens]; for nested sequences the per-visit score is replicated across the codes within each visit. Scores are non-negative and, before the input-shape expansion, sum to 1 across tokens (a consequence of composing row-stochastic matrices).

Return type:

Dict[str, torch.Tensor]

Note

Do not call this method inside a torch.no_grad() context. Even though rollout uses no gradients, enabling attention hooks registers a gradient hook during the forward pass, which requires grad-enabled tensors and otherwise raises a RuntimeError.