"""Chefer's gradient-weighted attention relevance propagation.
This module implements the Chefer et al. relevance propagation method for
explaining transformer-family model predictions. It relies on the
:class:`~pyhealth.interpret.api.CheferInterpretable` interface — any model
that implements that interface is automatically supported.
Paper:
Chefer, Hila, Shir Gur, and Lior Wolf.
"Generic Attention-model Explainability for Interpreting Bi-Modal and
Encoder-Decoder Transformers."
Proceedings of the IEEE/CVF International Conference on Computer Vision
(ICCV), 2021.
"""
from typing import Dict, Optional, cast
import torch
import torch.nn.functional as F
from pyhealth.interpret.api import CheferInterpretable
from pyhealth.models.base_model import BaseModel
from pyhealth.interpret.api import CheferInterpretable
from .base_interpreter import BaseInterpreter
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
def apply_self_attention_rules(
    R_ss: torch.Tensor, cam_ss: torch.Tensor
) -> torch.Tensor:
    """Compute one layer's relevance contribution under Chefer's rule.

    This returns only the batched matrix product ``cam_ss @ R_ss``. Adding
    that product to the running relevance matrix is left to the caller.

    Args:
        R_ss: Relevance matrix [batch, seq_len, seq_len].
        cam_ss: Attention weight matrix [batch, seq_len, seq_len].

    Returns:
        The product ``cam_ss @ R_ss``.
    """
    return torch.matmul(cam_ss, R_ss)
def avg_heads(cam: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Average gradient-weighted attention scores across heads.

    The attention weights are multiplied element-wise by their gradients and
    negative products are clamped to zero. When both inputs have fewer than
    four dimensions there is no head axis to reduce, so the clamped product
    is returned as is; otherwise it is averaged over dimension 1.

    Args:
        cam: Attention weights [batch, heads, seq_len, seq_len] or
            [batch, seq_len, seq_len].
        grad: Gradients w.r.t. attention weights. Same shape as cam.

    Returns:
        Gradient-weighted attention [batch, seq_len, seq_len].
    """
    weighted = (grad * cam).clamp(min=0)
    if cam.dim() < 4 and grad.dim() < 4:
        return weighted
    # ``mean`` already allocates a fresh tensor, so no extra copy is needed.
    return weighted.mean(dim=1)
# ---------------------------------------------------------------------------
# Main interpreter
# ---------------------------------------------------------------------------
class CheferRelevance(BaseInterpreter):
    """Chefer's gradient-weighted attention method for transformer interpretability.

    This interpreter works with **any** model that implements the
    :class:`~pyhealth.interpret.api.CheferInterpretable` interface, which
    currently includes:

    * :class:`~pyhealth.models.Transformer`
    * :class:`~pyhealth.models.StageAttentionNet`

    The algorithm:

    1. Enable attention hooks via ``model.set_attention_hooks(True)``.
    2. Forward pass → capture attention maps and register gradient hooks.
    3. Backward pass from a one-hot target class.
    4. Retrieve ``(attn_map, attn_grad)`` pairs via
       ``model.get_attention_layers()``.
    5. Propagate relevance: ``R += clamp(attn * grad, min=0) @ R``.
    6. Reduce ``R`` to per-token vectors via ``model.get_relevance_tensor()``.

    Steps 1, 4 and 6 are delegated to the model through the
    ``CheferInterpretable`` interface, making this class fully
    model-agnostic.

    Args:
        model (BaseModel): A trained PyHealth model that implements
            :class:`~pyhealth.interpret.api.CheferInterpretable`.

    Example:
        >>> from pyhealth.datasets import create_sample_dataset, get_dataloader
        >>> from pyhealth.models import Transformer
        >>> from pyhealth.interpret.methods import CheferRelevance
        >>>
        >>> samples = [
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v0",
        ...         "conditions": ["A05B", "A05C", "A06A"],
        ...         "procedures": ["P01", "P02"],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v1",
        ...         "conditions": ["A05B"],
        ...         "procedures": ["P01"],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"conditions": "sequence", "procedures": "sequence"},
        ...     output_schema={"label": "binary"},
        ...     dataset_name="ehr_example",
        ... )
        >>> model = Transformer(dataset=dataset)
        >>> # ... train the model ...
        >>>
        >>> interpreter = CheferRelevance(model)
        >>> batch = next(iter(get_dataloader(dataset, batch_size=2)))
        >>>
        >>> # Default: attribute to predicted class
        >>> attributions = interpreter.attribute(**batch)
        >>> # Returns dict: {"conditions": tensor, "procedures": tensor}
        >>> print(attributions["conditions"].shape)  # [batch, num_tokens]
        >>>
        >>> # Optional: attribute to a specific class (e.g., class 1)
        >>> attributions = interpreter.attribute(target_class_idx=1, **batch)
    """

    def __init__(self, model: BaseModel):
        """Store the model after checking it supports the Chefer interface.

        Args:
            model: Model to explain; must be an instance of
                ``CheferInterpretable``.

        Raises:
            ValueError: If ``model`` does not implement
                ``CheferInterpretable``.
        """
        super().__init__(model)
        if not isinstance(model, CheferInterpretable):
            raise ValueError("Model must implement CheferInterpretable interface")
        self.model = model

    def attribute(
        self,
        target_class_idx: Optional[int] = None,
        **data,
    ) -> Dict[str, torch.Tensor]:
        """Compute relevance scores for each input token.

        Args:
            target_class_idx: Target class index to compute attribution for.
                If None (default), uses the argmax of model output.
                For binary classification (single logit output), this is
                a no-op because there is only one output.
            **data: Input data from dataloader batch containing feature
                keys and label key.

        Returns:
            Dict[str, torch.Tensor]: Dictionary keyed by feature keys,
                where each tensor has shape ``[batch, seq_len]`` with
                per-token attribution scores.
        """
        # --- 1. Forward with attention hooks enabled ---
        self.model.set_attention_hooks(True)
        try:
            logits = self.model(**data)["logit"]
        finally:
            # Always switch the hooks off again, even if the forward pass raises.
            self.model.set_attention_hooks(False)

        # --- 2. Backward from target class ---
        # NOTE(review): _resolve_target_indices is not defined in this class;
        # presumably it is inherited from BaseInterpreter — confirm.
        target_indices = self._resolve_target_indices(logits, target_class_idx)
        one_hot = F.one_hot(
            target_indices.detach().clone(), logits.size(1)
        ).float()
        one_hot = one_hot.requires_grad_(True)
        # Summing the masked logits yields a scalar to differentiate.
        scalar = torch.sum(one_hot.to(logits.device) * logits)
        self.model.zero_grad()
        scalar.backward(retain_graph=True)

        # --- 3. Retrieve (attn_map, attn_grad) pairs per feature key ---
        attention_layers = self.model.get_attention_layers()
        batch_size = logits.shape[0]
        device = logits.device

        # --- 4. Relevance propagation per feature key ---
        R_dict: Dict[str, torch.Tensor] = {}
        for key, layers in attention_layers.items():
            # Sequence length is read from the first layer's attention map.
            num_tokens = layers[0][0].shape[-1]
            # Relevance starts as one identity matrix per batch element.
            R = (
                torch.eye(num_tokens, device=device)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
            for cam, grad in layers:
                cam = avg_heads(cam, grad)
                # The propagated term is detached from the autograd graph.
                R = R + apply_self_attention_rules(R, cam).detach()
            R_dict[key] = R

        # --- 5. Reduce R matrices to per-token vectors ---
        attributions = self.model.get_relevance_tensor(R_dict, **data)

        # --- 6. Expand to match raw input shapes (nested sequences) ---
        return self._map_to_input_shapes(attributions, data)

    # ------------------------------------------------------------------
    # Shape mapping
    # ------------------------------------------------------------------
    def _map_to_input_shapes(
        self,
        attributions: Dict[str, torch.Tensor],
        data: dict,
    ) -> Dict[str, torch.Tensor]:
        """Expand attributions to match raw input value shapes.

        For nested sequences the attention operates on a pooled
        (visit-level) sequence, but downstream consumers (e.g. ablation
        metrics) expect attributions to match the raw input value shape.
        Per-visit relevance scores are replicated across all codes
        within each visit.

        Args:
            attributions: Per-feature attribution tensors returned by
                ``model.get_relevance_tensor()``.
            data: Original ``**data`` kwargs from the dataloader batch.

        Returns:
            Attributions expanded to raw input value shapes where needed.
            Entries whose key is missing from ``data``, or whose value
            tensor has no more dimensions than the attribution, are
            returned unchanged.
        """
        result: Dict[str, torch.Tensor] = {}
        for key, attr in attributions.items():
            feature = data.get(key)
            if feature is not None:
                if isinstance(feature, torch.Tensor):
                    val = feature
                else:
                    # Non-tensor features are indexed by the position of
                    # "value" in the processor's schema, when present.
                    schema = self.model.dataset.input_processors[key].schema()
                    val = (
                        feature[schema.index("value")]
                        if "value" in schema
                        else None
                    )
                if val is not None and val.dim() > attr.dim():
                    # Append trailing singleton axes, then broadcast so each
                    # score is repeated along the extra dimensions.
                    for _ in range(val.dim() - attr.dim()):
                        attr = attr.unsqueeze(-1)
                    attr = attr.expand_as(val)
            result[key] = attr
        return result

    # ------------------------------------------------------------------
    # Backward compatibility aliases
    # ------------------------------------------------------------------
    def get_relevance_matrix(self, **data):
        """Alias for attribute(). Deprecated."""
        return self.attribute(**data)
# ======================================================================
# LEGACY REFERENCE IMPLEMENTATIONS
# ======================================================================
# The functions below are the original model-specific implementations
# that existed before the CheferInterpretable API was introduced. They
# are kept here ONLY as a reference for future developers and are NOT
# called by any production code. They may be removed in a future
# release.
#
# For ViT models, _reference_attribute_vit is the only implementation
# until ViT models implement CheferInterpretable.
# ======================================================================
def _reference_attribute_transformer(
    model,
    class_index=None,
    **data,
) -> Dict[str, torch.Tensor]:
    """[REFERENCE ONLY] Original Transformer-specific Chefer attribution.

    This was the body of ``CheferRelevance._attribute_transformer()``
    before the CheferInterpretable API was introduced. It accesses
    model internals (``model.transformer[key].transformer``) directly.

    Args:
        model: Model exposing ``feature_keys`` and
            ``transformer[key].transformer`` (an iterable of blocks with an
            ``attention`` attribute).
        class_index: Class to explain, as an int or tensor. When None the
            argmax of the logits is used.
        **data: Keyword arguments forwarded to the model's forward pass.

    Returns:
        Dict mapping each feature key to row 0 of its relevance matrix
        (``R[:, 0]``).
    """
    # Passed through to the model's forward pass; presumably this makes the
    # attention blocks record their maps and gradients — confirm.
    data["register_hook"] = True
    logits = model(**data)["logit"]
    device = logits.device
    num_classes = logits.size()[1]

    if class_index is None:
        class_index = torch.argmax(logits, dim=-1)
    index_tensor = (
        class_index.detach().clone()
        if isinstance(class_index, torch.Tensor)
        else torch.tensor(class_index)
    )
    selector = F.one_hot(index_tensor, num_classes).float().requires_grad_(True)

    # Backpropagate the summed target-class logits.
    target_score = torch.sum(selector.to(device) * logits)
    model.zero_grad()
    target_score.backward(retain_graph=True)

    feature_keys = model.feature_keys

    # Token count per feature; the last block visited determines the value.
    token_counts = {}
    for key in feature_keys:
        for block in model.transformer[key].transformer:
            token_counts[key] = block.attention.get_attn_map().shape[-1]

    batch_size = logits.shape[0]
    relevance = {}
    for key in feature_keys:
        identity = torch.eye(token_counts[key]).unsqueeze(0)
        R = identity.repeat(batch_size, 1, 1).to(device)
        for block in model.transformer[key].transformer:
            attn_grad = block.attention.get_attn_grad()
            attn_map = block.attention.get_attn_map()
            weighted = avg_heads(attn_map, attn_grad)
            R = R + apply_self_attention_rules(R, weighted).detach()
        relevance[key] = R[:, 0]
    return relevance
def _reference_attribute_stageattn(
    model,
    class_index=None,
    **data,
) -> Dict[str, torch.Tensor]:
    """[REFERENCE ONLY] Original StageAttentionNet-specific Chefer attribution.

    This was the body of ``CheferRelevance._attribute_stageattn()``
    before the CheferInterpretable API was introduced. It accesses
    model internals (``model.stagenet[key]``, ``model.embedding_model``)
    directly.

    Args:
        model: Model exposing ``feature_keys``, ``stagenet[key]`` and
            ``embedding_model``.
        class_index: Class to explain, as an int or tensor. When None the
            argmax of the logits is used.
        **data: Keyword arguments forwarded to the model's forward pass;
            must contain an entry for every feature key.

    Returns:
        Dict mapping each feature key to the relevance row selected at
        index ``mask.sum(dim=1) - 1`` for every batch element.
    """
    # Passed through to the model's forward pass; presumably this enables
    # attention hook registration — confirm.
    data["register_attn_hook"] = True
    logits = model(**data)["logit"]
    device = logits.device
    num_classes = logits.size()[1]

    if class_index is None:
        class_index = torch.argmax(logits, dim=-1)
    index_tensor = (
        class_index.detach().clone()
        if isinstance(class_index, torch.Tensor)
        else torch.tensor(class_index)
    )
    selector = F.one_hot(index_tensor, num_classes).float().requires_grad_(True)

    # Backpropagate the summed target-class logits.
    target_score = torch.sum(selector.to(device) * logits)
    model.zero_grad()
    target_score.backward(retain_graph=True)

    batch_size = logits.shape[0]
    relevance = {}
    for key in model.feature_keys:
        stage_layer = model.stagenet[key]
        attn_map = stage_layer.get_attn_map()
        attn_grad = stage_layer.get_attn_grad()

        seq_len = attn_map.shape[-1]
        identity = torch.eye(seq_len).unsqueeze(0)
        R = identity.repeat(batch_size, 1, 1).to(device)
        weighted = avg_heads(attn_map, attn_grad)
        R = R + apply_self_attention_rules(R, weighted).detach()

        # A 2-tuple feature carries the values in its second slot.
        raw = data[key]
        x_val = raw[1] if isinstance(raw, tuple) and len(raw) == 2 else raw

        emb = model.embedding_model({key: x_val})[key]
        if emb.dim() == 4:
            emb = emb.sum(dim=2)

        # Positions whose embedding sums to a non-zero value count as valid.
        valid = (emb.sum(dim=-1) != 0).long().to(device)
        last_idx = valid.sum(dim=1) - 1
        rows = torch.arange(batch_size, device=device)
        relevance[key] = R[rows, last_idx]
    return relevance
def _reference_attribute_vit(
    model,
    interpolate: bool = True,
    class_index=None,
    **data,
) -> Dict[str, torch.Tensor]:
    """[REFERENCE ONLY] Original ViT-specific Chefer attribution.

    ViT models do not yet implement CheferInterpretable. This code
    shows the ViT-specific flow that will be needed when ViT support is
    added to the unified API.

    Args:
        model: Model exposing ``feature_keys``, ``device``,
            ``forward_with_attention``, ``get_attention_gradients`` and
            ``get_num_patches``.
        interpolate: If True, bilinearly resize the patch-level map to
            ``(input_size, input_size)``.
        class_index: Class to explain: an int, a 0-dim tensor (applied to
            every sample in the batch) or a per-sample index tensor. When
            None the argmax of the logits is used.
        **data: Batch kwargs; must contain the model's first feature key.

    Returns:
        Dict with a single entry mapping the first feature key to the
        attribution map.

    Raises:
        ValueError: If the first feature key is missing from ``data``.
    """
    feature_key = model.feature_keys[0]
    x = data.get(feature_key)
    if x is None:
        raise ValueError(
            f"Expected feature key '{feature_key}' in data. "
            f"Available keys: {list(data.keys())}"
        )
    x = x.to(model.device)
    # NOTE(review): only the last dimension is read, so inputs are
    # presumably square images — confirm.
    input_size = x.shape[-1]

    model.zero_grad()
    logits, attention_maps = model.forward_with_attention(x, register_hook=True)

    target_class = class_index
    if target_class is None:
        target_class = logits.argmax(dim=-1)

    one_hot = torch.zeros_like(logits)
    if isinstance(target_class, int):
        one_hot[:, target_class] = 1
    else:
        if target_class.dim() == 0:
            # A scalar tensor names one class for the whole batch, matching
            # the int branch above. Repeating it gives every row of the mask
            # its entry; a length-1 index would only mark row 0.
            target_class = target_class.unsqueeze(0).repeat(logits.shape[0])
        one_hot.scatter_(1, target_class.unsqueeze(1), 1)
    one_hot = one_hot.requires_grad_(True)
    (logits * one_hot).sum().backward(retain_graph=True)

    attention_gradients = model.get_attention_gradients()
    batch_size = attention_maps[0].shape[0]
    num_tokens = attention_maps[0].shape[-1]
    device = attention_maps[0].device

    # Relevance starts as one identity matrix per batch element.
    R = torch.eye(num_tokens, device=device)
    R = R.unsqueeze(0).expand(batch_size, -1, -1).clone()
    for attn, grad in zip(attention_maps, attention_gradients):
        cam = avg_heads(attn, grad)
        R = R + apply_self_attention_rules(R.detach(), cam.detach())

    # Row 0 without its first column; presumably token 0 is the class
    # token and the remaining tokens are image patches — confirm.
    patches_attr = R[:, 0, 1:]
    h_patches, w_patches = model.get_num_patches(input_size)
    attr_map = patches_attr.reshape(batch_size, 1, h_patches, w_patches)
    if interpolate:
        attr_map = F.interpolate(
            attr_map,
            size=(input_size, input_size),
            mode="bilinear",
            align_corners=False,
        )
    return {feature_key: attr_map}