pyhealth.interpret.methods.chefer#

Overview#

The Chefer interpretability method provides token-level relevance scores for Transformer models in PyHealth. It propagates gradient-weighted attention through the model’s layers to identify which input tokens (e.g., diagnosis codes, procedure codes, medications) most influenced the prediction for a given patient sample.

This method is particularly useful for:

  • Clinical decision support: Understanding which medical codes drove a particular prediction

  • Model debugging: Identifying if the model is focusing on clinically meaningful features

  • Feature importance: Ranking tokens by their contribution to the prediction

  • Trust and transparency: Providing interpretable explanations for model predictions

The implementation follows the paper by Chefer et al. (2021): “Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers” (https://arxiv.org/abs/2103.15679).

Key Features#

  • Multi-modal support: Works with multiple feature types (conditions, procedures, drugs, labs, etc.)

  • Gradient-based: Uses attention gradients to compute relevance scores

  • Layer-wise propagation: Aggregates relevance across transformer layers

  • Non-negative scores: Returns clamped scores where higher values indicate greater relevance

Usage Notes#

  1. Batch size: Use batch_size=1 when you want to inspect a single sample; attribute() returns scores with a leading batch dimension

  2. Gradients required: Do not call the interpreter inside a torch.no_grad() context, because the backward pass needs gradients

  3. Model compatibility: Requires a model that implements the CheferInterpretable interface, such as PyHealth’s Transformer model

  4. Class specification: You can specify a target class or use the predicted class
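The gradient requirement in note 2 can be checked with plain PyTorch. The snippet below is a minimal sketch that uses a small torch.nn.Linear layer as a stand-in for a trained model (the layer and input are illustrative, not part of PyHealth). Outputs computed inside torch.no_grad() are detached from the autograd graph, so no backward pass can reach the attention maps.

```python
import torch

# Illustrative stand-in for a trained model (not a PyHealth class)
layer = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

# Inside no_grad(): the output is detached, so backward() cannot be used
with torch.no_grad():
    y_no_grad = layer(x)

# Outside no_grad(): the output tracks gradients through the layer's weights
y_with_grad = layer(x)

print(y_no_grad.requires_grad)    # False
print(y_with_grad.requires_grad)  # True
```

If an evaluation loop wraps inference in torch.no_grad(), the interpreter call likely needs to sit outside that block.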

Quick Start#

from pyhealth.models import Transformer
from pyhealth.interpret.methods import CheferRelevance
from pyhealth.datasets import get_dataloader

# Assume you have a trained transformer model and dataset
model = Transformer(dataset=sample_dataset, ...)
# ... train the model ...

# Create interpretability object
relevance = CheferRelevance(model)

# Get a test sample (batch_size=1)
test_loader = get_dataloader(test_dataset, batch_size=1, shuffle=False)
batch = next(iter(test_loader))

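# Compute relevance scores (get_relevance_matrix() is a deprecated alias for attribute())
scores = relevance.attribute(**batch)

# Analyze results
for feature_key, relevance_tensor in scores.items():
    print(f"{feature_key}: {relevance_tensor.shape}")
    # Cap k at the sequence length so topk() does not fail on short sequences
    k = min(5, relevance_tensor.shape[-1])
    top_tokens = relevance_tensor[0].topk(k).indices
    print(f"  Top {k} most relevant token positions: {top_tokens}")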

API Reference#

class pyhealth.interpret.methods.CheferRelevance(model)[source]#

Bases: BaseInterpreter

Chefer’s gradient-weighted attention method for transformer interpretability.

This interpreter works with any model that implements the CheferInterpretable interface, such as PyHealth’s Transformer.

The algorithm:

  1. Enable attention hooks via model.set_attention_hooks(True).

  2. Forward pass → capture attention maps and register gradient hooks.

  3. Backward pass from a one-hot target class.

  4. Retrieve (attn_map, attn_grad) pairs via model.get_attention_layers().

  5. Propagate relevance: R += clamp(attn * grad, min=0) @ R.

  6. Reduce R to per-token vectors via model.get_relevance_tensor().

Steps 1, 4 and 6 are delegated to the model through the CheferInterpretable interface, making this class fully model-agnostic.
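Steps 2, 3 and 5 can be illustrated with plain PyTorch. The following is a minimal sketch on a toy single-layer self-attention model, not the PyHealth implementation: the weights w_qk and w_out, the pooling, and the final first-token reduction are all assumptions made for illustration. It uses retain_grad() to capture the attention gradient in place of the model's hooks.

```python
import torch

torch.manual_seed(0)
batch, heads, seq_len, dim, num_classes = 1, 2, 4, 8, 3

# Toy inputs and weights (hypothetical stand-ins for a real model)
x = torch.randn(batch, seq_len, dim)
w_qk = (torch.randn(heads, dim, dim) / dim ** 0.5).requires_grad_()
w_out = torch.randn(dim, num_classes)

# Step 2: forward pass, keeping the attention map and its gradient
attn_logits = torch.einsum("bid,hde,bje->bhij", x, w_qk, x)
attn = attn_logits.softmax(dim=-1)           # [batch, heads, seq_len, seq_len]
attn.retain_grad()                           # non-leaf tensor: keep its .grad
context = attn @ x.unsqueeze(1)              # [batch, heads, seq_len, dim]
logits = context.mean(dim=(1, 2)) @ w_out    # [batch, num_classes]

# Step 3: backward pass from a one-hot vector on the predicted class
target = logits.argmax(dim=-1)
one_hot = torch.nn.functional.one_hot(target, num_classes).to(logits.dtype)
(logits * one_hot).sum().backward()

# Step 5: R += clamp(attn * grad, min=0) @ R, averaged over heads
cam = (attn.grad * attn.detach()).clamp(min=0).mean(dim=1)
R = torch.eye(seq_len).expand(batch, -1, -1).clone()
R = R + cam @ R

# Assumed reduction for illustration: read the row of the first token
token_scores = R[:, 0]
print(token_scores.shape)
```

With several attention layers, the update in step 5 would be repeated once per (attn_map, attn_grad) pair, starting from the same identity matrix.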

Parameters:

model (BaseModel) – A trained PyHealth model that implements CheferInterpretable.

Example

>>> from pyhealth.datasets import create_sample_dataset, get_dataloader
>>> from pyhealth.models import Transformer
>>> from pyhealth.interpret.methods import CheferRelevance
>>>
>>> samples = [
...     {
...         "patient_id": "p0",
...         "visit_id": "v0",
...         "conditions": ["A05B", "A05C", "A06A"],
...         "procedures": ["P01", "P02"],
...         "label": 1,
...     },
...     {
...         "patient_id": "p0",
...         "visit_id": "v1",
...         "conditions": ["A05B"],
...         "procedures": ["P01"],
...         "label": 0,
...     },
... ]
>>> dataset = create_sample_dataset(
...     samples=samples,
...     input_schema={"conditions": "sequence", "procedures": "sequence"},
...     output_schema={"label": "binary"},
...     dataset_name="ehr_example",
... )
>>> model = Transformer(dataset=dataset)
>>> # ... train the model ...
>>>
>>> interpreter = CheferRelevance(model)
>>> batch = next(iter(get_dataloader(dataset, batch_size=2)))
>>>
>>> # Default: attribute to predicted class
>>> attributions = interpreter.attribute(**batch)
>>> # Returns dict: {"conditions": tensor, "procedures": tensor}
>>> print(attributions["conditions"].shape)  # [batch, num_tokens]
>>>
>>> # Optional: attribute to a specific class (e.g., class 1)
>>> attributions = interpreter.attribute(target_class_idx=1, **batch)

attribute(target_class_idx=None, **data)[source]#

Compute relevance scores for each input token.

Parameters:
  • target_class_idx (Optional[int]) – Target class index to compute attribution for. If None (default), the argmax of the model output is used. For binary classification (single logit output), this argument has no effect because there is only one output.

  • **data – Input data from dataloader batch containing feature keys and label key.

Returns:

Dictionary keyed by feature key, where each tensor has shape [batch, seq_len] and holds per-token attribution scores.

Return type:

Dict[str, torch.Tensor]
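Because the scores are indexed by position, reading them usually means pairing each position with the code it came from. The helper below is a hypothetical sketch (rank_tokens is not part of PyHealth) that works on plain Python lists, for example one row converted with .tolist(). It assumes the score positions line up one-to-one with the original token list, which may not hold if the model pads or adds special tokens. The score values are made up.

```python
def rank_tokens(tokens, scores, k=5):
    """Return up to k (token, score) pairs sorted by descending score."""
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    pairs = sorted(zip(tokens, scores), key=lambda pair: pair[1], reverse=True)
    return pairs[:k]


# Made-up scores standing in for attributions["conditions"][0].tolist()
conditions = ["A05B", "A05C", "A06A"]
row_scores = [0.12, 0.0, 0.47]

print(rank_tokens(conditions, row_scores, k=2))
# [('A06A', 0.47), ('A05B', 0.12)]
```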

get_relevance_matrix(**data)[source]#

Deprecated alias for attribute(); use attribute() in new code.

Helper Functions#

The module also includes internal helper functions for relevance computation:

pyhealth.interpret.methods.chefer.apply_self_attention_rules(R_ss, cam_ss)[source]#

Apply Chefer’s self-attention rules for relevance propagation.

Parameters:
  • R_ss – Relevance matrix [batch, seq_len, seq_len].

  • cam_ss – Attention weight matrix [batch, seq_len, seq_len].

Returns:

Updated relevance matrix after propagating through attention layer.

pyhealth.interpret.methods.chefer.avg_heads(cam, grad)[source]#

Average attention scores weighted by gradients across heads.

Parameters:
  • cam – Attention weights [batch, heads, seq_len, seq_len] or [batch, seq_len, seq_len].

  • grad – Gradients w.r.t. attention weights. Same shape as cam.

Returns:

Gradient-weighted attention [batch, seq_len, seq_len].
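The two helpers can be approximated in a few lines of PyTorch. The functions below are a sketch written from the descriptions above, not copies of the library source. The names carry a _sketch suffix to make that explicit, and whether the library's apply_self_attention_rules returns the full update or only the additive term is an assumption here (the sketch returns the full update).

```python
import torch


def avg_heads_sketch(cam, grad):
    """Gradient-weighted attention, clamped at zero and averaged over heads."""
    weighted = (grad * cam).clamp(min=0)
    if weighted.dim() == 4:          # [batch, heads, seq_len, seq_len]
        weighted = weighted.mean(dim=1)
    return weighted                  # [batch, seq_len, seq_len]


def self_attention_update_sketch(R_ss, cam_ss):
    """One propagation step: R + cam @ R (assumed to include the residual term)."""
    return R_ss + torch.bmm(cam_ss, R_ss)


# Random tensors standing in for one layer's attention map and gradient
cam = torch.rand(2, 3, 4, 4)
grad = torch.randn(2, 3, 4, 4)

layer_cam = avg_heads_sketch(cam, grad)
R = torch.eye(4).expand(2, -1, -1).clone()
R = self_attention_update_sketch(R, layer_cam)
print(layer_cam.shape, R.shape)
```

Clamping before averaging discards negative contributions per head, which is what keeps the final scores non-negative.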