from __future__ import annotations
from typing import Dict, Optional
import torch
import torch.nn.functional as F
from pyhealth.models import BaseModel
from pyhealth.interpret.api import Interpretable
from .base_interpreter import BaseInterpreter
[docs]class IntegratedGradients(BaseInterpreter):
"""Integrated Gradients attribution method for PyHealth models.
This class implements the Integrated Gradients method for computing
feature attributions in neural networks. The method computes the
integral of gradients along a straight path from a baseline input
to the actual input.
The method is based on the paper:
Axiomatic Attribution for Deep Networks
Mukund Sundararajan, Ankur Taly, Qiqi Yan
ICML 2017
https://arxiv.org/abs/1703.01365
Integrated Gradients satisfies two fundamental axioms:
1. Sensitivity: If an input and a baseline differ in one feature
but have different predictions, then the differing feature
should be given non-zero attribution.
2. Implementation Invariance: The attributions are identical for
functionally equivalent networks.
Args:
model (BaseModel): A trained PyHealth model to interpret. Can be
any model that inherits from BaseModel (e.g., MLP, StageNet,
Transformer, RNN).
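use_embeddings (bool): If True, compute gradients with respect to
embeddings rather than discrete input tokens. This is crucial
for models with discrete inputs (like ICD codes) where direct
interpolation of token indices is not meaningful. The model
must implement ``forward_from_embedding()`` and expose its
embedding model via ``get_embedding_model()``. Default is True.
steps (int): Default number of interpolation steps used for the
Riemann approximation of the path integral. Can be overridden
per call in ``attribute()``. Default is 50.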
Note:
**Why use_embeddings=True is recommended:**
When working with discrete features (e.g., ICD diagnosis codes,
procedure codes), Integrated Gradients needs to interpolate between
a baseline and the actual input. However, interpolating discrete
token indices directly creates invalid intermediate values:
- Input code index: 245 (e.g., "Diabetes Type 2")
- Baseline index: 0 (padding token)
- Interpolation creates: 0 -> 61.25 -> 122.5 -> 183.75 -> 245
Fractional indices like 61.25 cannot be looked up in an embedding
table and cause "index out of bounds" errors.
With ``use_embeddings=True``, the method:
1. Embeds both baseline and input tokens into continuous vectors
2. Interpolates in the embedding space (which is valid)
3. Computes gradients with respect to these embeddings
4. Maps attributions back to the original input tokens
This makes IG compatible with models like StageNet, Transformer,
RNN, and MLP that process discrete medical codes.
Set ``use_embeddings=False`` only when all inputs are continuous
(e.g., vital signs, lab values) and no embedding layers are used.
Examples:
>>> import torch
>>> from pyhealth.datasets import (
... SampleDataset, split_by_patient, get_dataloader
... )
>>> from pyhealth.models import MLP
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> from pyhealth.trainer import Trainer
>>>
>>> # Define sample data
>>> samples = [
... {
... "patient_id": "patient-0",
... "visit_id": "visit-0",
... "conditions": ["cond-33", "cond-86", "cond-80"],
... "procedures": [1.0, 2.0, 3.5, 4.0],
... "label": 1,
... },
... {
... "patient_id": "patient-1",
... "visit_id": "visit-1",
... "conditions": ["cond-55", "cond-12"],
... "procedures": [5.0, 2.0, 3.5, 4.0],
... "label": 0,
... },
... # ... more samples
... ]
>>>
>>> # Create dataset with schema
>>> input_schema = {
... "conditions": "sequence",
... "procedures": "tensor"
... }
>>> output_schema = {"label": "binary"}
>>>
>>> dataset = SampleDataset(
... samples=samples,
... input_schema=input_schema,
... output_schema=output_schema,
... dataset_name="example"
... )
>>>
>>> # Initialize MLP model
>>> model = MLP(
... dataset=dataset,
... embedding_dim=128,
... hidden_dim=128,
... dropout=0.3
... )
>>>
>>> # Split data and create dataloaders
>>> train_data, val_data, test_data = split_by_patient(
... dataset, [0.7, 0.15, 0.15]
... )
>>> train_loader = get_dataloader(
... train_data, batch_size=32, shuffle=True
... )
>>> val_loader = get_dataloader(
... val_data, batch_size=32, shuffle=False
... )
>>> test_loader = get_dataloader(
... test_data, batch_size=1, shuffle=False
... )
>>>
>>> # Train model
>>> trainer = Trainer(model=model, device="cuda:0")
>>> trainer.train(
... train_dataloader=train_loader,
... val_dataloader=val_loader,
... epochs=10,
... monitor="roc_auc"
... )
>>>
>>> # Compute attributions for test samples
>>> ig = IntegratedGradients(model)
>>> data_batch = next(iter(test_loader))
>>>
>>> # Option 1: Use zero baseline (default)
>>> attributions = ig.attribute(**data_batch, steps=5)
>>> print(attributions)
{'conditions': tensor([[0.1234, 0.5678, 0.9012]], device='cuda:0'),
'procedures': tensor([[0.2345, 0.6789, 0.0123, 0.4567]])}
>>>
>>> # Option 2: Specify target class explicitly
>>> data_batch['target_class_idx'] = 1
>>> attributions = ig.attribute(**data_batch, steps=5)
>>>
>>> # Option 3: Use custom baseline
>>> custom_baseline = {
... 'conditions': torch.zeros_like(data_batch['conditions']),
... 'procedures': torch.ones_like(data_batch['procedures']) * 0.5
... }
>>> attributions = ig.attribute(
... **data_batch, baseline=custom_baseline, steps=5
... )
"""
def __init__(self, model: BaseModel, use_embeddings: bool = True, steps: int = 50):
    """Initialize the IntegratedGradients interpreter.

    Args:
        model: A trained PyHealth model to interpret. It must implement
            the ``Interpretable`` interface.
        use_embeddings: If True, compute gradients with respect to
            embeddings rather than discrete input tokens. Default True.
            This is required for models with discrete inputs like ICD
            codes. Set to False only for fully continuous input models.
            When True, the model must implement forward_from_embedding()
            and have an embedding model accessible via
            get_embedding_model().
        steps: Default number of interpolation steps for the Riemann
            approximation of the path integral. Default is 50.
            Can be overridden in attribute() calls. More steps lead to
            better approximation but slower computation.

    Raises:
        ValueError: If ``model`` does not implement the ``Interpretable``
            interface, or if ``steps`` is smaller than 1.
    """
    super().__init__(model)
    if not isinstance(model, Interpretable):
        raise ValueError("Model must implement Interpretable interface")
    # The interpolation coefficient is computed as step / steps, so a
    # non-positive default would only fail later with ZeroDivisionError.
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")
    self.model = model
    self.use_embeddings = use_embeddings
    self.steps = steps
def attribute(
    self,
    baseline: Optional[Dict[str, torch.Tensor]] = None,
    steps: Optional[int] = None,
    target_class_idx: Optional[int] = None,
    **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
) -> Dict[str, torch.Tensor]:
    """Compute Integrated Gradients attributions for input features.

    This method computes the path integral of gradients from a
    baseline input to the actual input. The integral is approximated
    using a Riemann sum with the specified number of steps.

    Args:
        baseline: Baseline input for integration. Can be:
            - None: Uses UNK-token baseline for discrete features or
              small near-zero baseline for continuous features (default)
            - Dict[str, torch.Tensor]: Custom baseline for each feature.
              Entries whose key is not a model feature key are ignored;
              features without an entry fall back to the default
              baseline for that feature.
        steps: Number of steps to use in the Riemann approximation of
            the integral. If None, uses self.steps (set during
            initialization). More steps lead to better approximation but
            slower computation. Must be at least 1.
        target_class_idx: Target class index for attribution.
            For binary classification (single logit output), this is
            a no-op because there is only one output. For multi-class
            or multi-label, specifies which class to explain. If None,
            uses the argmax of model output.
        **kwargs: Input data dictionary from a dataloader batch
            containing:
            - Feature keys (e.g., 'conditions', 'procedures'):
              Input tensors or tuples of tensors for each modality
            - 'label' (optional): Ground truth label tensor
            - Other metadata keys are ignored

    Returns:
        Dict[str, torch.Tensor]: Dictionary mapping each feature key
            to its attribution tensor. Each tensor has the same shape
            as the input tensor, with values indicating the
            contribution of each input element to the model's
            prediction. Positive values indicate features that
            increase the prediction score, while negative values
            indicate features that decrease it.

    Raises:
        ValueError: If the effective number of steps is smaller than 1.

    Note:
        - This method requires gradients, so the model should not be
          in torch.no_grad() context.
        - For better interpretability, use batch_size=1 or analyze
          samples individually.
        - The sum of attributions across all features approximates
          the difference between the model's prediction for the input
          and the baseline (completeness axiom).

    Examples:
        >>> from pyhealth.interpret.methods import IntegratedGradients
        >>>
        >>> # Assuming you have a trained model and test data
        >>> ig = IntegratedGradients(trained_model)
        >>> test_batch = next(iter(test_loader))
        >>>
        >>> # Compute attributions with default settings
        >>> attributions = ig.attribute(**test_batch)
        >>> print(f"Feature attributions: {attributions.keys()}")
        >>> print(f"Conditions: {attributions['conditions'].shape}")
        >>>
        >>> # Use more steps for better approximation
        >>> attributions = ig.attribute(**test_batch, steps=100)
        >>>
        >>> # Compute attributions for specific class
        >>> attributions = ig.attribute(
        ...     **test_batch, target_class_idx=0, steps=50
        ... )
        >>>
        >>> # Use custom baseline
        >>> custom_baseline = {
        ...     'conditions': torch.zeros_like(test_batch['conditions']),
        ...     'procedures': torch.zeros_like(test_batch['procedures'])
        ... }
        >>> attributions = ig.attribute(
        ...     **test_batch, baseline=custom_baseline, steps=50
        ... )
        >>>
        >>> # Analyze which features are most important
        >>> condition_attr = attributions['conditions'][0]
        >>> top_k = torch.topk(torch.abs(condition_attr), k=5)
        >>> print(f"Most important features: {top_k.indices}")
    """
    # Use instance default if steps not specified
    if steps is None:
        steps = self.steps
    # The interpolation coefficient is step / steps, so reject values
    # that would otherwise fail deep inside the loop with a
    # ZeroDivisionError (or silently run an empty loop).
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    device = next(self.model.parameters()).device
    processors = self.model.dataset.input_processors

    # Filter kwargs to only include model feature keys and ensure they
    # are tuples
    inputs = {
        k: (v,) if isinstance(v, torch.Tensor) else v
        for k, v in kwargs.items()
        if k in self.model.feature_keys
    }

    # Disassemble inputs to get values and masks
    values: dict[str, torch.Tensor] = {}
    masks: dict[str, torch.Tensor] = {}
    for k, v in inputs.items():
        processor = processors[k]
        schema = processor.schema()
        val = v[schema.index("value")]
        values[k] = val
        if "mask" in schema:
            masks[k] = v[schema.index("mask")]
        elif processor.is_token() or val.dim() < 3:
            masks[k] = (val != 0).int()
        else:
            # For continuous features with a trailing feature axis,
            # check whether the entire feature vector at each timestep
            # is zero (padding) rather than per-element, so valid 0.0
            # values are not masked out.
            masks[k] = (val.abs().sum(dim=-1) != 0).int()

    # Append derived masks to inputs for models that expect them
    for k in list(inputs):
        if "mask" not in processors[k].schema():
            inputs[k] = (*inputs[k], masks[k])

    # Determine target class from original input
    with torch.no_grad():
        base_logits = self.model.forward(**inputs)["logit"]
    target_indices = self._resolve_target_indices(base_logits, target_class_idx)

    # Generate baselines
    if baseline is None:
        baselines = self._generate_baseline(
            values, use_embeddings=self.use_embeddings
        )
    else:
        baselines = {
            k: v.to(device)
            for k, v in baseline.items()
            if k in self.model.feature_keys
        }
        # A partial custom baseline would raise KeyError during
        # interpolation; fill the gaps with the default baseline.
        missing = {k: v for k, v in values.items() if k not in baselines}
        if missing:
            baselines.update(
                self._generate_baseline(
                    missing, use_embeddings=self.use_embeddings
                )
            )

    # Save raw shapes before embedding for later mapping
    shapes = {k: v.shape for k, v in values.items()}

    # Split features by type using is_token():
    # - Token features (discrete): embed before interpolation, since
    #   interpolating raw indices is meaningless. Gradients are computed
    #   w.r.t. embeddings, then summed over the embedding dim.
    # - Continuous features: keep raw for interpolation so each raw
    #   dimension gets its own attribution. They are embedded inside
    #   the per-step forward pass instead.
    if self.use_embeddings:
        embedding_model = self.model.get_embedding_model()
        assert embedding_model is not None, (
            "Model must have an embedding model for embedding-based "
            "Integrated Gradients."
        )
        token_keys = {k for k in values if processors[k].is_token()}
        if token_keys:
            # Embed token values
            embedded_tokens = embedding_model(
                {k: values[k] for k in token_keys}
            )
            for k in token_keys:
                values[k] = embedded_tokens[k]
            # Embed token baselines so they live in the same space
            # (every token key has a baseline at this point).
            embedded_baselines = embedding_model(
                {k: baselines[k] for k in token_keys}
            )
            for k in token_keys:
                baselines[k] = embedded_baselines[k]

    # Compute integrated gradients
    attributions = self._integrated_gradients(
        inputs=inputs,
        xs=values,
        bs=baselines,
        steps=steps,
        target_indices=target_indices,
    )
    return self._map_to_input_shapes(attributions, shapes)
# ------------------------------------------------------------------
# Core IG computation
# ------------------------------------------------------------------
def _integrated_gradients(
    self,
    inputs: Dict[str, tuple[torch.Tensor, ...]],
    xs: Dict[str, torch.Tensor],
    bs: Dict[str, torch.Tensor],
    steps: int,
    target_indices: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Compute integrated gradients via Riemann sum approximation.

    The path from baseline to input is sampled at ``steps + 1`` evenly
    spaced points ``alpha = 0, 1/steps, ..., 1`` (both endpoints
    included, all equally weighted). For each point:

    1. Creates interpolated values: baseline + alpha * (input - baseline)
    2. Inserts them into the input tuples via the processor schema
    3. Runs forward pass (forward_from_embedding or forward)
    4. Computes gradients w.r.t. the interpolated values only
    5. Accumulates gradients using a running sum (memory efficient)

    After all steps, computes the final attribution as:
        (input - baseline) * average_gradient

    Gradients are taken with ``torch.autograd.grad`` on the interpolated
    tensors, so the ``.grad`` fields of the model parameters are neither
    cleared nor populated by this method.

    Args:
        inputs: Full input tuples keyed by feature name.
        xs: Input values (token features embedded if use_embeddings=True).
        bs: Baseline values (token features embedded if
            use_embeddings=True).
        steps: Number of interpolation steps. Must be at least 1.
        target_indices: [batch] tensor of target class indices.

    Returns:
        Dictionary mapping feature keys to attribution tensors. The
        tensors are detached from the autograd graph.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    keys = sorted(xs.keys())
    if not keys:
        return {}

    processors = self.model.dataset.input_processors

    # Determine which keys are token (already embedded) vs continuous
    # (raw). If not using embeddings, both sets stay empty and every
    # feature is treated as raw.
    token_keys = set()
    continuous_keys = set()
    if self.use_embeddings:
        for k in keys:
            if processors[k].is_token():
                token_keys.add(k)
            else:
                continuous_keys.add(k)

    # Loop-invariant lookups, hoisted out of the step loop.
    embedding_model = None
    if continuous_keys:
        embedding_model = self.model.get_embedding_model()
        assert embedding_model is not None, (
            "Model must have an embedding model for embedding-based "
            "Integrated Gradients."
        )
    value_index = {k: processors[k].schema().index("value") for k in inputs}

    # Detach so the returned attributions do not keep the embedding
    # graph alive through (input - baseline).
    deltas = {k: (xs[k] - bs[k]).detach() for k in keys}

    # Running sum instead of storing all gradients (memory efficient).
    # The accumulator uses the floating dtype the interpolated tensor
    # will have, so integer-valued raw inputs do not break the in-place
    # accumulation.
    grad_sums = {
        k: torch.zeros_like(
            deltas[k], dtype=torch.result_type(deltas[k], 1.0)
        )
        for k in keys
    }

    for step_idx in range(steps + 1):
        alpha = step_idx / steps

        # Fresh leaf tensors for this step; gradients are requested
        # with respect to exactly these.
        interpolated = {
            k: (bs[k] + alpha * deltas[k]).detach().requires_grad_(True)
            for k in keys
        }

        # Insert interpolated values back into input tuples
        forward_inputs = dict(inputs)
        for k, idx in value_index.items():
            parts = forward_inputs[k]
            forward_inputs[k] = (*parts[:idx], interpolated[k], *parts[idx + 1:])

        if self.use_embeddings:
            # Token features are already embedded; continuous features
            # are interpolated raw and embedded here so that
            # forward_from_embedding receives embeddings for every key
            # while gradients still flow back to the raw values.
            if continuous_keys:
                embedded_continuous = embedding_model(
                    {k: interpolated[k] for k in continuous_keys}
                )
                for k in continuous_keys:
                    idx = value_index[k]
                    parts = forward_inputs[k]
                    forward_inputs[k] = (
                        *parts[:idx],
                        embedded_continuous[k],
                        *parts[idx + 1:],
                    )
            output = self.model.forward_from_embedding(**forward_inputs)
        else:
            output = self.model.forward(**forward_inputs)

        target_output = self._compute_target_output(
            output["logit"], target_indices
        )

        # Differentiate w.r.t. the interpolated inputs only; features
        # that do not influence the output come back as None.
        grads = torch.autograd.grad(
            target_output,
            [interpolated[k] for k in keys],
            allow_unused=True,
        )
        for k, grad in zip(keys, grads):
            if grad is not None:
                grad_sums[k] += grad

    # Final attributions: (input - baseline) * average gradient
    attributions: dict[str, torch.Tensor] = {}
    for k in keys:
        attr = deltas[k] * (grad_sums[k] / (steps + 1))
        # Token features were embedded before interpolation: sum over
        # the embedding dimension to collapse (batch, ..., emb_dim)
        # to (batch, ...).
        if k in token_keys and attr.dim() >= 3:
            attr = attr.sum(dim=-1)
        attributions[k] = attr
    return attributions
# ------------------------------------------------------------------
# Target output computation
# ------------------------------------------------------------------
def _compute_target_output(
    self,
    logits: torch.Tensor,
    target_indices: torch.Tensor,
) -> torch.Tensor:
    """Reduce the logits to a single differentiable scalar.

    Picks, for every sample, the logit of its target class and adds
    the picked values up over the batch.

    Args:
        logits: Model output logits, shape [batch, num_classes].
        target_indices: [batch] tensor of target class indices.

    Returns:
        Scalar tensor suitable for backpropagation.
    """
    # gather needs an index with the same rank as the logits.
    column_index = target_indices.unsqueeze(1)
    picked = torch.gather(logits, 1, column_index)
    per_sample = picked.squeeze(1)
    return per_sample.sum()
# ------------------------------------------------------------------
# Baseline generation
# ------------------------------------------------------------------
def _generate_baseline(
    self,
    values: Dict[str, torch.Tensor],
    use_embeddings: bool = False,
) -> Dict[str, torch.Tensor]:
    """Build raw reference ("feature absent") tensors for IG.

    One baseline is produced per entry of ``values``, chosen by the
    feature's processor:

    - Token features, when ``use_embeddings`` is True: a tensor of
      ones shaped like the input (index 1 is used as the UNK token).
      Embedding is left to the caller.
    - Everything else (continuous features, or any feature when
      ``use_embeddings`` is False): zeros shifted by 1e-2.

    Args:
        values: Dictionary of raw input value tensors (before embedding).
        use_embeddings: If True, token features get the UNK-index
            baseline instead of the near-zero one.

    Returns:
        Dictionary mapping feature names to baseline tensors in raw
        (pre-embedding) space, in the iteration order of ``values``.
    """

    def _reference_for(name: str, tensor: torch.Tensor) -> torch.Tensor:
        processor = self.model.dataset.input_processors[name]
        is_discrete = use_embeddings and processor.is_token()
        if is_discrete:
            # UNK token index everywhere
            return torch.ones_like(tensor)
        # Near-zero neutral value
        return torch.zeros_like(tensor) + 1e-2

    return {
        name: _reference_for(name, tensor)
        for name, tensor in values.items()
    }
# ------------------------------------------------------------------
# Utility helpers
# ------------------------------------------------------------------
@staticmethod
def _map_to_input_shapes(
    attr_values: Dict[str, torch.Tensor],
    input_shapes: dict,
) -> Dict[str, torch.Tensor]:
    """Map attributions back to original input tensor shapes.

    For embedding-based attributions the embedding dimension has
    normally been summed out already. This method reconciles whatever
    mismatch is left:

    - Attribution has more dimensions than the input (e.g. an
      embedding axis that was not collapsed): surplus trailing
      dimensions are summed out.
    - Attribution has fewer dimensions than the input: trailing
      singleton dimensions are appended and the result is broadcast
      (``expand``) to the input shape.

    Args:
        attr_values: Dictionary of attribution tensors.
        input_shapes: Dictionary of original input shapes. Keys absent
            from this mapping are passed through unchanged.

    Returns:
        Dictionary of attributions reshaped to match original inputs.

    Raises:
        RuntimeError: Propagated from ``Tensor.expand`` when the
            attribution cannot be broadcast to the original shape.
    """
    mapped: dict[str, torch.Tensor] = {}
    for key, attr in attr_values.items():
        if key not in input_shapes:
            mapped[key] = attr
            continue

        # Compare as plain tuples so list-typed shapes also match.
        target_shape = tuple(input_shapes[key])
        if tuple(attr.shape) == target_shape:
            mapped[key] = attr
            continue

        reshaped = attr
        # Collapse surplus trailing axes by summation, mirroring how
        # the embedding dimension is reduced to a per-element score.
        while reshaped.dim() > len(target_shape):
            reshaped = reshaped.sum(dim=-1)
        # Pad missing trailing axes so the tensor can be broadcast.
        while reshaped.dim() < len(target_shape):
            reshaped = reshaped.unsqueeze(-1)
        if tuple(reshaped.shape) != target_shape:
            reshaped = reshaped.expand(target_shape)
        mapped[key] = reshaped
    return mapped