pyhealth.interpret.methods.chefer
Overview
The Chefer interpretability method provides token-level relevance scores for Transformer models in PyHealth. It weights each attention map by its gradient with respect to the model output and propagates the result through the layers to identify which input tokens (e.g., diagnosis codes, procedure codes, medications) most influenced the model’s prediction for a given patient sample.
This method is particularly useful for:
Clinical decision support: Understanding which medical codes drove a particular prediction
Model debugging: Identifying if the model is focusing on clinically meaningful features
Feature importance: Ranking tokens by their contribution to the prediction
Trust and transparency: Providing interpretable explanations for model predictions
The implementation follows the paper by Chefer et al. (2021): “Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers” (https://arxiv.org/abs/2103.15679).
Key Features
Multi-modal support: Works with multiple feature types (conditions, procedures, drugs, labs, etc.)
Gradient-based: Uses attention gradients to compute relevance scores
Layer-wise propagation: Aggregates relevance across transformer layers
Non-negative scores: Returns clamped scores where higher values indicate greater relevance
Usage Notes
Batch size: Use batch_size=1 for the simplest per-sample explanations; attribute() also accepts larger batches and returns scores of shape [batch, seq_len]
Gradients required: Do not call within a torch.no_grad() context
Model compatibility: Works with models that implement the CheferInterpretable interface, such as PyHealth’s Transformer model
Class specification: You can specify a target class or use the predicted class
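Why gradients must stay enabled can be checked in plain PyTorch. The sketch below computes a toy softmax attention map with and without torch.no_grad(); the function name and tensor sizes are illustrative assumptions, not PyHealth internals. Under no_grad the map is detached from autograd, so a backward pass cannot produce the attention gradients the method relies on.

```python
import torch


def attention_tracks_grad(use_no_grad: bool) -> bool:
    """Return whether a toy attention map is connected to autograd.

    Illustrative only: shapes and names are arbitrary, not PyHealth internals.
    """
    q = torch.randn(1, 4, 8, requires_grad=True)  # [batch, tokens, dim]
    k = torch.randn(1, 4, 8, requires_grad=True)
    ctx = torch.no_grad() if use_no_grad else torch.enable_grad()
    with ctx:
        attn = torch.softmax(q @ k.transpose(-1, -2) / 8 ** 0.5, dim=-1)
    return attn.requires_grad


print(attention_tracks_grad(use_no_grad=False))  # True
print(attention_tracks_grad(use_no_grad=True))   # False
```

Calling model.eval() is still fine: it switches layers such as dropout to inference behavior but does not disable gradient tracking.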
Quick Start
from pyhealth.models import Transformer
from pyhealth.interpret.methods import CheferRelevance
from pyhealth.datasets import get_dataloader

# Assume you have a trained Transformer model and datasets (sample_dataset, test_dataset)
model = Transformer(dataset=sample_dataset)  # ... plus any other hyperparameters ...
# ... train the model ...

# Create the interpretability object
relevance = CheferRelevance(model)

# Get a single test sample (batch_size=1)
test_loader = get_dataloader(test_dataset, batch_size=1, shuffle=False)
batch = next(iter(test_loader))

# Compute relevance scores: {feature_key: tensor of shape [batch, seq_len]}
scores = relevance.attribute(**batch)

# Analyze results
for feature_key, relevance_tensor in scores.items():
    print(f"{feature_key}: {relevance_tensor.shape}")
    k = min(5, relevance_tensor.shape[-1])  # topk raises if k exceeds the sequence length
    top_tokens = relevance_tensor[0].topk(k).indices
    print(f"  Top {k} most relevant tokens: {top_tokens}")
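Because attribute() returns one tensor per feature key, ranking tokens across modalities needs a small post-processing step. The helper below is hypothetical (not part of PyHealth) and the example scores are made up; it flattens one sample's scores into (feature_key, token_position, score) tuples and sorts them. Whether scores are on a comparable scale across keys depends on the model, so treat a joint ranking as a rough guide. Mapping positions back to medical codes requires the dataset's vocabulary, which is not shown here.

```python
from typing import Dict, List, Tuple

import torch


def rank_tokens(
    scores: Dict[str, torch.Tensor], sample_idx: int = 0
) -> List[Tuple[str, int, float]]:
    """Hypothetical helper: rank (feature_key, token_position, score) across all keys."""
    ranked = []
    for feature_key, relevance in scores.items():
        for pos, value in enumerate(relevance[sample_idx].tolist()):
            ranked.append((feature_key, pos, value))
    # Highest relevance first; sorted() is stable, so ties keep insertion order.
    return sorted(ranked, key=lambda item: item[2], reverse=True)


# Made-up scores shaped like the documented output: {feature_key: [batch, seq_len]}
example = {
    "conditions": torch.tensor([[0.10, 0.70, 0.05]]),
    "procedures": torch.tensor([[0.40, 0.00]]),
}
for feature_key, pos, value in rank_tokens(example)[:3]:
    print(f"{feature_key}[{pos}] = {value:.2f}")
```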
API Reference
- class pyhealth.interpret.methods.CheferRelevance(model)
Bases: BaseInterpreter
Chefer’s gradient-weighted attention method for transformer interpretability.
This interpreter works with any model that implements the CheferInterpretable interface, such as PyHealth’s Transformer.
The algorithm:
1. Enable attention hooks via model.set_attention_hooks(True).
2. Forward pass → capture attention maps and register gradient hooks.
3. Backward pass from a one-hot target class.
4. Retrieve (attn_map, attn_grad) pairs via model.get_attention_layers().
5. Propagate relevance: R += clamp(attn * grad, min=0) @ R.
6. Reduce R to per-token vectors via model.get_relevance_tensor().
Steps 1, 4 and 6 are delegated to the model through the CheferInterpretable interface, making this class fully model-agnostic.
- Parameters:
model (BaseModel) – A trained PyHealth model that implements
CheferInterpretable.
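The forward, backward, and propagation steps of the algorithm above can be reproduced on a toy example in plain PyTorch. This sketch is not PyHealth's implementation: the two-layer "model", its readout from token 0, and reducing R by taking that token's row are assumptions for illustration. Following Chefer et al. (2021), R starts as the identity and the clamped, gradient-weighted attention is averaged over heads, a dimension the propagation formula above leaves implicit.

```python
import torch
import torch.nn.functional as F


def propagate_relevance(attn_maps, attn_grads):
    """R <- R + mean_h(clamp(A * dA, min=0)) @ R, starting from the identity.

    attn_maps, attn_grads: lists of [batch, heads, n, n] tensors, one per layer.
    """
    batch, _, n, _ = attn_maps[0].shape
    R = torch.eye(n).expand(batch, n, n).clone()
    for attn, grad in zip(attn_maps, attn_grads):
        cam = (attn * grad).clamp(min=0).mean(dim=1)  # gradient-weighted, head-averaged
        R = R + cam @ R
    return R


# Toy 2-layer "model" (assumption): attention maps are leaf tensors so autograd reaches them.
torch.manual_seed(0)
batch, heads, n, n_classes = 1, 2, 4, 3
attn_maps = [
    torch.softmax(torch.randn(batch, heads, n, n), dim=-1).requires_grad_()
    for _ in range(2)
]
h = torch.randn(batch, n, n_classes)  # token embeddings
for attn in attn_maps:                # forward pass through each attention layer
    h = attn.mean(dim=1) @ h
logits = h[:, 0, :]                   # read out from token 0 (CLS-like position)
one_hot = F.one_hot(logits.argmax(dim=-1), n_classes).to(logits.dtype)
(logits * one_hot).sum().backward()   # backward from the predicted class

R = propagate_relevance([a.detach() for a in attn_maps], [a.grad for a in attn_maps])
token_relevance = R[:, 0, :]          # assumed reduction: the readout token's row
print(token_relevance.shape)          # torch.Size([1, 4])
```

Since R starts at the identity and every update adds a non-negative term, all entries stay non-negative, which matches the non-negative scores listed under Key Features.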
Example
>>> from pyhealth.datasets import create_sample_dataset, get_dataloader
>>> from pyhealth.models import Transformer
>>> from pyhealth.interpret.methods import CheferRelevance
>>>
>>> samples = [
...     {
...         "patient_id": "p0",
...         "visit_id": "v0",
...         "conditions": ["A05B", "A05C", "A06A"],
...         "procedures": ["P01", "P02"],
...         "label": 1,
...     },
...     {
...         "patient_id": "p0",
...         "visit_id": "v1",
...         "conditions": ["A05B"],
...         "procedures": ["P01"],
...         "label": 0,
...     },
... ]
>>> dataset = create_sample_dataset(
...     samples=samples,
...     input_schema={"conditions": "sequence", "procedures": "sequence"},
...     output_schema={"label": "binary"},
...     dataset_name="ehr_example",
... )
>>> model = Transformer(dataset=dataset)
>>> # ... train the model ...
>>>
>>> interpreter = CheferRelevance(model)
>>> batch = next(iter(get_dataloader(dataset, batch_size=2)))
>>>
>>> # Default: attribute to the predicted class
>>> attributions = interpreter.attribute(**batch)
>>> # Returns dict: {"conditions": tensor, "procedures": tensor}
>>> print(attributions["conditions"].shape)  # [batch, num_tokens]
>>>
>>> # Optional: attribute to a specific class (e.g., class 1).
>>> # For this binary task (single logit) the index is a no-op; it matters for multiclass models.
>>> attributions = interpreter.attribute(target_class_idx=1, **batch)
- attribute(target_class_idx=None, **data)
Compute relevance scores for each input token.
- Parameters:
target_class_idx (Optional[int]) – Target class index to compute attribution for. If None (default), uses the argmax of the model output. For binary classification (single logit output), this is a no-op because there is only one output.
**data – Input data from a dataloader batch, containing the feature keys and the label key.
- Returns:
- Dictionary keyed by feature keys, where each tensor has shape [batch, seq_len] with per-token attribution scores.
- Return type:
Dict[str, torch.Tensor]
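The target_class_idx behavior described above can be mirrored by a small helper that builds the one-hot weights seeding the backward pass. This is a hypothetical sketch (one_hot_target is not a PyHealth function); it assumes logits of shape [batch, num_classes], or [batch, 1] for single-logit binary models.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def one_hot_target(logits: torch.Tensor, target_class_idx: Optional[int] = None) -> torch.Tensor:
    """Hypothetical helper: weights selecting which output is backpropagated."""
    if logits.shape[-1] == 1:
        # Single-logit binary output: there is only one output, so the index is ignored.
        return torch.ones_like(logits)
    if target_class_idx is None:
        idx = logits.argmax(dim=-1)  # default: each sample's predicted class
    else:
        idx = torch.full((logits.shape[0],), target_class_idx, dtype=torch.long)
    return F.one_hot(idx, num_classes=logits.shape[-1]).to(logits.dtype)


logits = torch.tensor([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
print(one_hot_target(logits).tolist())     # [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
print(one_hot_target(logits, 2).tolist())  # [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
```

The scalar to backpropagate would then be (logits * one_hot_target(logits)).sum(), which isolates the chosen class for every sample in the batch.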
Helper Functions
The module also includes internal helper functions for relevance computation:
- pyhealth.interpret.methods.chefer.apply_self_attention_rules(R_ss, cam_ss)
Apply Chefer’s self-attention rules for relevance propagation.
- Parameters:
R_ss – Relevance matrix [batch, seq_len, seq_len].
cam_ss – Attention weight matrix [batch, seq_len, seq_len].
- Returns:
Updated relevance matrix after propagating through attention layer.
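For reference, the self-attention rule in Chefer et al. (2021) adds the gradient-weighted attention applied to the current relevance: R_ss ← R_ss + cam_ss · R_ss. The sketch below restates that rule with the documented shapes; it is written from the paper rather than copied from PyHealth, and the function name is hypothetical.

```python
import torch


def self_attention_rule_sketch(R_ss: torch.Tensor, cam_ss: torch.Tensor) -> torch.Tensor:
    """Return R_ss + cam_ss @ R_ss for [batch, seq_len, seq_len] tensors."""
    return R_ss + torch.bmm(cam_ss, R_ss)


R = torch.eye(3).unsqueeze(0)     # relevance starts as the identity
cam = torch.full((1, 3, 3), 0.5)  # toy gradient-weighted attention
print(self_attention_rule_sketch(R, cam))  # identity plus 0.5 in every entry
```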
- pyhealth.interpret.methods.chefer.avg_heads(cam, grad)
Average attention scores weighted by gradients across heads.
- Parameters:
cam – Attention weights [batch, heads, seq_len, seq_len] or [batch, seq_len, seq_len].
grad – Gradients w.r.t. attention weights. Same shape as cam.
- Returns:
Gradient-weighted attention [batch, seq_len, seq_len].
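A minimal sketch of gradient-weighted head averaging as described in Chefer et al. (2021): multiply each attention map by its gradient, keep only positive contributions, then average over heads. Where avg_heads applies the clamp is not stated above, so that ordering is an assumption here, and the name avg_heads_sketch is hypothetical.

```python
import torch


def avg_heads_sketch(cam: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Return the head mean of clamp(grad * cam, min=0) as [batch, seq_len, seq_len]."""
    weighted = (grad * cam).clamp(min=0)  # drop negative contributions before averaging
    if weighted.dim() == 4:               # [batch, heads, seq_len, seq_len]
        weighted = weighted.mean(dim=1)
    return weighted                       # 3-D inputs have no head dimension to average


cam = torch.ones(1, 2, 2, 2)
# Head 0 has positive gradients, head 1 negative ones.
grad = torch.stack([torch.ones(2, 2), -torch.ones(2, 2)]).unsqueeze(0)
print(avg_heads_sketch(cam, grad))  # every entry is 0.5
```

Clamping before averaging means a head with negative gradients cannot cancel a head with positive ones: here the result is 0.5, whereas averaging first and clamping afterwards would give 0.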