"""Attention rollout for transformer interpretability.
This module implements the canonical attention rollout method, a
forward-pass-only, gradient-free, class-agnostic attention-flow baseline.
It complements the gradient-weighted, class-specific
:class:`~pyhealth.interpret.methods.CheferRelevance`.
Paper:
Abnar, Samira, and Willem Zuidema.
"Quantifying Attention Flow in Transformers."
Proceedings of the 58th Annual Meeting of the Association for
Computational Linguistics (ACL), 2020.
https://arxiv.org/abs/2005.00928
"""
from typing import Dict, Optional
import torch
from pyhealth.models.base_model import BaseModel
from .base_interpreter import BaseInterpreter
class AttentionRollout(BaseInterpreter):
    """Attention rollout for transformer interpretability.

    Implements the canonical attention rollout method of Abnar & Zuidema,
    "Quantifying Attention Flow in Transformers" (2020),
    https://arxiv.org/abs/2005.00928.

    Unlike :class:`~pyhealth.interpret.methods.CheferRelevance`, which is
    gradient-weighted and class-specific, rollout is **forward-pass only**,
    **gradient-free**, and **class-agnostic**: it quantifies how attention
    propagates information across layers, independent of any target class.
    It serves as the standard baseline that gradient-based attention methods
    are compared against.

    .. note::
        "Gradient-free" refers to the attribution **math**: no backward pass
        is run and no gradients enter the rollout computation. It does **not**
        mean the call is safe inside ``torch.no_grad()``. The shared
        attention-readout plumbing registers a gradient hook on the attention
        tensors during the forward pass, so running ``attribute(**batch)``
        under ``torch.no_grad()`` raises a ``RuntimeError``. Call it under the
        default (grad-enabled) context.

    This interpreter works with any model that exposes the attention-readout
    methods ``set_attention_hooks``, ``get_attention_layers``, and
    ``get_relevance_tensor`` (currently :class:`~pyhealth.models.Transformer`
    and :class:`~pyhealth.models.StageAttentionNet`). Compatibility is checked
    by duck-typing in ``__init__`` rather than by requiring a named interface,
    since these methods are general attention readout and not specific to any
    one method.

    The algorithm, per feature key:

    1. Enable attention hooks via ``model.set_attention_hooks(True)`` and run a
       single forward pass (no backward pass).
    2. Retrieve per-layer attention maps via ``model.get_attention_layers()``,
       discarding the gradient element of each ``(attn_map, attn_grad)`` pair.
    3. Fuse heads (mean) to get one ``[batch, seq, seq]`` matrix per layer.
    4. Account for residual connections: ``A_hat = 0.5 * (A + I)``.
    5. Compose layers by matrix product: ``rollout = A_hat_L @ ... @ A_hat_1``.
    6. Reduce to per-token scores via ``model.get_relevance_tensor()``, then
       expand to raw input value shapes.

    Because each ``A_hat`` is row-stochastic, so is their product; the
    per-token relevance therefore forms a distribution over tokens (sums to 1
    before the input-shape expansion).

    Args:
        model (BaseModel): A trained PyHealth model exposing the attention-
            readout methods listed above.
        head_fusion (str): How to combine attention heads into a single matrix
            per layer. Currently only ``"mean"`` is supported (the canonical
            choice from the paper). Defaults to ``"mean"``.

    Example:
        >>> from pyhealth.datasets import create_sample_dataset, get_dataloader
        >>> from pyhealth.models import Transformer
        >>> from pyhealth.interpret.methods import AttentionRollout
        >>>
        >>> samples = [
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v0",
        ...         "conditions": ["A05B", "A05C", "A06A"],
        ...         "procedures": ["P01", "P02"],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v1",
        ...         "conditions": ["A05B"],
        ...         "procedures": ["P01"],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"conditions": "sequence", "procedures": "sequence"},
        ...     output_schema={"label": "binary"},
        ...     dataset_name="ehr_example",
        ... )
        >>> model = Transformer(dataset=dataset)
        >>> # ... train the model ...
        >>>
        >>> interpreter = AttentionRollout(model)
        >>> batch = next(iter(get_dataloader(dataset, batch_size=2)))
        >>>
        >>> attributions = interpreter.attribute(**batch)
        >>> # Returns dict: {"conditions": tensor, "procedures": tensor}
        >>> print(attributions["conditions"].shape)  # [batch, num_tokens]
        >>>
        >>> # target_class_idx is accepted but ignored (rollout is class-agnostic)
        >>> same = interpreter.attribute(target_class_idx=1, **batch)
    """

    def __init__(self, model: BaseModel, head_fusion: str = "mean"):
        """Validate the configuration and bind the interpreter to ``model``.

        Args:
            model: Model to interpret. It must expose
                ``set_attention_hooks``, ``get_attention_layers`` and
                ``get_relevance_tensor``.
            head_fusion: Head-fusion strategy; only ``"mean"`` is accepted.

        Raises:
            ValueError: If ``head_fusion`` is not ``"mean"``.
            TypeError: If ``model`` lacks any of the required methods.
        """
        if head_fusion != "mean":
            raise ValueError(
                f"Unsupported head_fusion='{head_fusion}'. "
                "Currently supported values: mean."
            )

        # Duck-typed capability check: any model with these attention-readout
        # methods is accepted, regardless of its class hierarchy.
        required_methods = [
            "set_attention_hooks",
            "get_attention_layers",
            "get_relevance_tensor",
        ]
        missing_methods = [m for m in required_methods if not hasattr(model, m)]
        if missing_methods:
            raise TypeError(
                "AttentionRollout requires a model that exposes the attention "
                "interpretability methods: "
                f"{', '.join(required_methods)}. "
                f"Missing: {', '.join(missing_methods)}."
            )

        super().__init__(model)
        self.head_fusion = head_fusion

    def attribute(
        self,
        target_class_idx: Optional[int] = None,
        **data,
    ) -> Dict[str, torch.Tensor]:
        """Compute class-agnostic attention rollout attributions.

        Args:
            target_class_idx: Accepted for API compatibility with class-specific
                interpreters. Attention rollout is class-agnostic, so this
                argument is ignored.
            **data: Batch input passed directly to the model.

        Returns:
            Dict[str, torch.Tensor]: A dict keyed by the model's feature keys.
            Each value holds the rollout relevance for that feature — the
            CLS-token row of the composed attention-rollout matrix, reduced
            to one score per token by ``model.get_relevance_tensor()`` and
            then expanded to the raw input value shape by
            ``_map_to_input_shapes``. For flat sequence features this is
            ``[batch, num_tokens]``; for nested sequences the per-visit
            score is replicated across the codes within each visit.
            Scores are non-negative and, before the input-shape expansion,
            sum to 1 across tokens (a consequence of composing
            row-stochastic matrices).

        Raises:
            RuntimeError: If a feature has no attention layers, or if any
                captured attention map is ``None``.
            ValueError: If ``self.head_fusion`` is not a supported strategy.

        Note:
            Do not call this method inside a ``torch.no_grad()`` context. Even
            though rollout uses no gradients, enabling attention hooks registers
            a gradient hook during the forward pass, which requires grad-enabled
            tensors and otherwise raises a ``RuntimeError``.
        """
        # Hooks are only needed while the forward pass runs; switch them back
        # off even if the model raises so the model is not left in hook mode.
        self.model.set_attention_hooks(True)
        try:
            self.model(**data)
        finally:
            self.model.set_attention_hooks(False)

        attention_layers = self.model.get_attention_layers()

        rollouts: Dict[str, torch.Tensor] = {}
        for feature_key, layers in attention_layers.items():
            rollouts[feature_key] = self._compose_rollout(feature_key, layers)

        attributions = self.model.get_relevance_tensor(rollouts, **data)
        return self._map_to_input_shapes(attributions, data)

    def _compose_rollout(self, feature_key, layers) -> torch.Tensor:
        """Compose one feature's per-layer attention into a rollout matrix.

        Args:
            feature_key: Feature name, used only in error messages.
            layers: Iterable of ``(attn_map, attn_grad)`` pairs in layer
                order; the gradient element is ignored.

        Returns:
            The product ``A_hat_L @ ... @ A_hat_1`` of the head-fused,
            residual-adjusted attention matrices.

        Raises:
            RuntimeError: If ``layers`` is empty or an attention map is
                ``None``.
        """
        rollout = None
        for attn_map, _ in layers:
            if attn_map is None:
                raise RuntimeError(
                    "AttentionRollout expected attention maps to be captured "
                    f"for feature '{feature_key}', but found None."
                )
            attn = self._add_residual(self._fuse_heads(attn_map))
            # The first layer seeds the product directly: multiplying by an
            # explicit identity matrix would be a wasted allocation and matmul.
            rollout = attn if rollout is None else torch.bmm(attn, rollout)

        if rollout is None:
            raise RuntimeError(
                "AttentionRollout expected at least one attention layer "
                f"for feature '{feature_key}', but found none."
            )
        return rollout

    def _fuse_heads(self, attn_map: torch.Tensor) -> torch.Tensor:
        """Fuse attention heads from [batch, heads, seq, seq] to [batch, seq, seq].

        Raises:
            ValueError: If ``self.head_fusion`` is not a supported strategy
                (e.g. the attribute was reassigned after construction).
        """
        if self.head_fusion == "mean":
            return attn_map.mean(dim=1)
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an opaque error in the residual step.
        raise ValueError(
            f"Unsupported head_fusion='{self.head_fusion}'. "
            "Currently supported values: mean."
        )

    def _map_to_input_shapes(
        self,
        attributions: Dict[str, torch.Tensor],
        data: dict,
    ) -> Dict[str, torch.Tensor]:
        """Expand attributions to match raw input value shapes.

        For nested sequences the attention operates on a pooled
        (visit-level) sequence, but downstream consumers (e.g. ablation
        metrics) expect attributions to match the raw input value shape.
        Per-visit relevance scores are replicated across all codes
        within each visit.

        Args:
            attributions: Per-feature attribution tensors returned by
                ``model.get_relevance_tensor()``.
            data: Original ``**data`` kwargs from the dataloader batch.

        Returns:
            Attributions expanded to raw input value shapes where needed.
            Features that are absent from ``data``, have no ``"value"``
            field, or already have at least as many dimensions as the raw
            value are returned unchanged.
        """
        result: Dict[str, torch.Tensor] = {}
        for key, attr in attributions.items():
            feature = data.get(key)

            # Locate the raw value tensor: either the feature itself, or the
            # "value" slot of a processor-defined tuple.
            val = None
            if isinstance(feature, torch.Tensor):
                val = feature
            elif feature is not None:
                schema = self.model.dataset.input_processors[key].schema()
                if "value" in schema:
                    val = feature[schema.index("value")]

            if val is not None and val.dim() > attr.dim():
                # Append trailing singleton dims, then broadcast so each score
                # is repeated along the extra (e.g. per-code) axes.
                for _ in range(val.dim() - attr.dim()):
                    attr = attr.unsqueeze(-1)
                attr = attr.expand_as(val)

            result[key] = attr
        return result

    @staticmethod
    def _add_residual(attn: torch.Tensor) -> torch.Tensor:
        """Add the canonical rollout residual connection: ``0.5 * (A + I)``.

        The result is row-stochastic whenever ``A`` is (e.g. when ``A`` is a
        softmax output): every row of ``A`` and of ``I`` sums to 1, so every
        row of their average does too.

        Args:
            attn: Head-fused attention of shape ``[batch, seq, seq]``.

        Returns:
            Tensor of the same shape with the identity mixed in.
        """
        # The 3-way unpack doubles as a rank check on ``attn``.
        _, seq_len, _ = attn.shape
        identity = torch.eye(
            seq_len,
            device=attn.device,
            dtype=attn.dtype,
        ).unsqueeze(0)
        return 0.5 * (attn + identity)