pyhealth.interpret.methods.integrated_gradients

Overview

The Integrated Gradients method computes feature attributions for PyHealth models by integrating gradients along a path from a baseline to the actual input. This helps identify which features (e.g., diagnosis codes, lab values) most influenced a model’s prediction.

For a complete working example, see: examples/integrated_gradients_mortality_mimic4_stagenet.py

API Reference

class pyhealth.interpret.methods.IntegratedGradients(model, use_embeddings=True, steps=50)

Bases: BaseInterpreter

Integrated Gradients attribution method for PyHealth models.

This class implements the Integrated Gradients method for computing feature attributions in neural networks. The method computes the integral of gradients along a straight path from a baseline input to the actual input.
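In the notation of the paper cited below, for a model output F, an input x, and a baseline x', the attribution for feature i and its m-step Riemann-sum approximation can be written as follows. The symbols follow the paper and are not taken from the PyHealth source.

```latex
\mathrm{IG}_i(x) = (x_i - x'_i) \int_{0}^{1} \frac{\partial F\bigl(x' + \alpha (x - x')\bigr)}{\partial x_i} \, d\alpha
\approx (x_i - x'_i) \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\bigl(x' + \tfrac{k}{m} (x - x')\bigr)}{\partial x_i}
```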

The method is based on the paper:

Axiomatic Attribution for Deep Networks. Mukund Sundararajan, Ankur Taly, Qiqi Yan. ICML 2017. https://arxiv.org/abs/1703.01365

Integrated Gradients satisfies two fundamental axioms:
  1. Sensitivity: If an input and a baseline differ in one feature but have different predictions, then the differing feature should be given non-zero attribution.

  2. Implementation Invariance: The attributions are identical for functionally equivalent networks.
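The Sensitivity axiom is what separates Integrated Gradients from plain gradients. The sketch below uses the one-variable function f(x) = 1 - ReLU(1 - x) discussed in the cited paper. At x = 2 the function is flat, so the plain gradient is zero even though the output differs from the baseline x = 0. The helper names and the right-endpoint Riemann rule are illustrative choices, not PyHealth code.

```python
def f(x: float) -> float:
    """f(x) = 1 - ReLU(1 - x): rises linearly to 1, then saturates."""
    return 1.0 - max(0.0, 1.0 - x)


def grad_f(x: float) -> float:
    """Derivative of f: 1 before saturation (x < 1), 0 afterwards."""
    return 1.0 if x < 1.0 else 0.0


def integrated_gradient(x: float, baseline: float = 0.0, steps: int = 100) -> float:
    """Right-endpoint Riemann approximation of IG for a scalar input."""
    avg_grad = sum(
        grad_f(baseline + (k / steps) * (x - baseline)) for k in range(1, steps + 1)
    ) / steps
    return (x - baseline) * avg_grad


plain_gradient = grad_f(2.0)         # zero: the function is saturated at x = 2
ig_value = integrated_gradient(2.0)  # close to f(2) - f(0) = 1
```

The plain gradient assigns the feature no importance, while the path-integrated value recovers almost all of the output difference.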

Parameters:
  • model (BaseModel) – A trained PyHealth model to interpret. Can be any model that inherits from BaseModel (e.g., MLP, StageNet, Transformer, RNN).

  • use_embeddings (bool) – If True, compute gradients with respect to embeddings rather than discrete input tokens. This is crucial for models with discrete inputs (like ICD codes) where direct interpolation of token indices is not meaningful. The model must support returning embeddings via an ‘embed’ parameter. Default is True.

  • steps (int) – Default number of steps used in the Riemann approximation of the integral when attribute() is called without an explicit steps value. Default is 50.

Note

Why use_embeddings=True is recommended:

When working with discrete features (e.g., ICD diagnosis codes, procedure codes), Integrated Gradients needs to interpolate between a baseline and the actual input. However, interpolating discrete token indices directly creates invalid intermediate values:

  • Input code index: 245 (e.g., “Diabetes Type 2”)

  • Baseline index: 0 (padding token)

  • Interpolation creates: 0 -> 61.25 -> 122.5 -> 183.75 -> 245

Fractional indices like 61.25 are not valid keys for an embedding table (embedding lookups require integer indices), so the lookup fails.

With use_embeddings=True, the method:
  1. Embeds both baseline and input tokens into continuous vectors

  2. Interpolates in the embedding space (which is valid)

  3. Computes gradients with respect to these embeddings

  4. Maps attributions back to the original input tokens
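The four steps can be sketched without any framework. The example below is a toy under stated assumptions: the embedding table, the weights, and the helper names are hypothetical, and summing over the embedding dimension in step 4 is one common way to map attributions back to tokens. It is not necessarily how PyHealth does it.

```python
import math

# Hypothetical 3-token vocabulary with 2-dimensional embeddings (made-up values).
EMBEDDINGS = {
    0: [0.0, 0.0],   # baseline token (e.g., padding)
    1: [0.5, -1.0],
    2: [1.5, 2.0],
}
WEIGHTS = [0.8, -0.3]  # hypothetical linear head applied to the summed embeddings


def predict(embs):
    """Toy model: sigmoid of a linear score on the sum of token embeddings."""
    logit = sum(w * sum(e[d] for e in embs) for d, w in enumerate(WEIGHTS))
    return 1.0 / (1.0 + math.exp(-logit))


def grad_wrt_embeddings(embs):
    """Analytic gradient of predict() with respect to every embedding entry."""
    p = predict(embs)
    return [[p * (1.0 - p) * w for w in WEIGHTS] for _ in embs]


def token_attributions(tokens, baseline_token=0, steps=50):
    dim = len(WEIGHTS)
    # 1. Embed both the input tokens and the baseline token.
    inputs = [EMBEDDINGS[t] for t in tokens]
    base = [EMBEDDINGS[baseline_token] for _ in tokens]
    avg_grad = [[0.0] * dim for _ in tokens]
    for k in range(1, steps + 1):
        alpha = k / steps
        # 2. Interpolate in embedding space, where every point is a valid input.
        point = [[b[d] + alpha * (x[d] - b[d]) for d in range(dim)]
                 for x, b in zip(inputs, base)]
        # 3. Accumulate gradients with respect to the interpolated embeddings.
        grads = grad_wrt_embeddings(point)
        for t in range(len(tokens)):
            for d in range(dim):
                avg_grad[t][d] += grads[t][d] / steps
    # 4. Collapse the embedding dimension to get one score per input token.
    return [
        sum((x[d] - b[d]) * avg_grad[t][d] for d in range(dim))
        for t, (x, b) in enumerate(zip(inputs, base))
    ]


attributions = token_attributions([1, 2], steps=200)
```

Each input token receives a single scalar score, even though the gradients were taken with respect to its full embedding vector.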

This makes IG compatible with models like StageNet, Transformer, RNN, and MLP that process discrete medical codes.

Set use_embeddings=False only when all inputs are continuous (e.g., vital signs, lab values) and no embedding layers are used.

Examples

>>> import torch
>>> from pyhealth.datasets import (
...     SampleDataset, split_by_patient, get_dataloader
... )
>>> from pyhealth.models import MLP
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> from pyhealth.trainer import Trainer
>>>
>>> # Define sample data
>>> samples = [
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-0",
...         "conditions": ["cond-33", "cond-86", "cond-80"],
...         "procedures": [1.0, 2.0, 3.5, 4.0],
...         "label": 1,
...     },
...     {
...         "patient_id": "patient-1",
...         "visit_id": "visit-1",
...         "conditions": ["cond-55", "cond-12"],
...         "procedures": [5.0, 2.0, 3.5, 4.0],
...         "label": 0,
...     },
...     # ... more samples
... ]
>>>
>>> # Create dataset with schema
>>> input_schema = {
...     "conditions": "sequence",
...     "procedures": "tensor"
... }
>>> output_schema = {"label": "binary"}
>>>
>>> dataset = SampleDataset(
...     samples=samples,
...     input_schema=input_schema,
...     output_schema=output_schema,
...     dataset_name="example"
... )
>>>
>>> # Initialize MLP model
>>> model = MLP(
...     dataset=dataset,
...     embedding_dim=128,
...     hidden_dim=128,
...     dropout=0.3
... )
>>>
>>> # Split data and create dataloaders
>>> train_data, val_data, test_data = split_by_patient(
...     dataset, [0.7, 0.15, 0.15]
... )
>>> train_loader = get_dataloader(
...     train_data, batch_size=32, shuffle=True
... )
>>> val_loader = get_dataloader(
...     val_data, batch_size=32, shuffle=False
... )
>>> test_loader = get_dataloader(
...     test_data, batch_size=1, shuffle=False
... )
>>>
>>> # Train model
>>> trainer = Trainer(model=model, device="cuda:0")
>>> trainer.train(
...     train_dataloader=train_loader,
...     val_dataloader=val_loader,
...     epochs=10,
...     monitor="roc_auc"
... )
>>>
>>> # Compute attributions for test samples
>>> ig = IntegratedGradients(model)
>>> data_batch = next(iter(test_loader))
>>>
>>> # Option 1: Use the default baseline (baseline=None)
>>> attributions = ig.attribute(**data_batch, steps=5)
>>> print(attributions)
{'conditions': tensor([[0.1234, 0.5678, 0.9012]], device='cuda:0'),
 'procedures': tensor([[0.2345, 0.6789, 0.0123, 0.4567]])}
>>>
>>> # Option 2: Specify target class explicitly
>>> data_batch['target_class_idx'] = 1
>>> attributions = ig.attribute(**data_batch, steps=5)
>>>
>>> # Option 3: Use custom baseline
>>> custom_baseline = {
...     'conditions': torch.zeros_like(data_batch['conditions']),
...     'procedures': torch.ones_like(data_batch['procedures']) * 0.5
... }
>>> attributions = ig.attribute(
...     **data_batch, baseline=custom_baseline, steps=5
... )

attribute(baseline=None, steps=None, target_class_idx=None, **kwargs)

Compute Integrated Gradients attributions for input features.

This method computes the path integral of gradients from a baseline input to the actual input. The integral is approximated using a Riemann sum with the specified number of steps.
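To make the role of steps concrete, here is a minimal, framework-free sketch of the Riemann approximation on a hypothetical two-feature function with a known gradient. It assumes a right-endpoint rule and a zero baseline. The rule PyHealth uses internally may differ.

```python
def model(x):
    """Hypothetical differentiable model: F(x) = x0^2 + 3 * x0 * x1."""
    return x[0] ** 2 + 3.0 * x[0] * x[1]


def model_grad(x):
    """Analytic gradient of F with respect to (x0, x1)."""
    return [2.0 * x[0] + 3.0 * x[1], 3.0 * x[0]]


def integrated_gradients(x, baseline, steps):
    """Right-endpoint Riemann approximation of the IG path integral."""
    total = [0.0] * len(x)
    for k in range(1, steps + 1):
        alpha = k / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(model_grad(point)):
            total[i] += g
    # Average the gradients, then scale by the input-baseline difference.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


x, baseline = [1.0, 2.0], [0.0, 0.0]
exact_delta = model(x) - model(baseline)  # F(x) - F(baseline) = 7.0

coarse = integrated_gradients(x, baseline, steps=5)
fine = integrated_gradients(x, baseline, steps=100)

coarse_error = abs(sum(coarse) - exact_delta)
fine_error = abs(sum(fine) - exact_delta)
print(f"steps=5 error: {coarse_error:.3f}, steps=100 error: {fine_error:.3f}")
```

In this toy case the gap between the summed attributions and the true output difference shrinks as steps grows, at the cost of one gradient evaluation per step.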

Parameters:
  • baseline (Optional[Dict[str, Tensor]]) –

    Baseline input for integration. Can be:

    • None: Uses a UNK-token baseline for discrete features or a small near-zero baseline for continuous features (default)

    • Dict[str, torch.Tensor]: Custom baseline for each feature

  • steps (Optional[int]) – Number of steps to use in the Riemann approximation of the integral. If None, uses self.steps (set during initialization). More steps lead to better approximation but slower computation.

  • target_class_idx (Optional[int]) – Target class index for attribution. For binary classification (single logit output), this is a no-op because there is only one output. For multi-class or multi-label, specifies which class to explain. If None, uses the argmax of model output.

  • **kwargs (Tensor | tuple[Tensor, ...]) –

    Input data dictionary from a dataloader batch containing:

    • Feature keys (e.g., ‘conditions’, ‘procedures’): Input tensors or tuples of tensors for each modality

    • ‘label’ (optional): Ground truth label tensor

    • Other metadata keys are ignored

Returns:

Dictionary mapping each feature key to its attribution tensor. Each tensor has the same shape as the input tensor, with values indicating the contribution of each input element to the model’s prediction. Positive values indicate features that increase the prediction score, while negative values indicate features that decrease it.

Return type:

Dict[str, torch.Tensor]

Note

  • This method requires gradients, so it must not be called inside a torch.no_grad() context.

  • For better interpretability, use batch_size=1 or analyze samples individually.

  • The sum of attributions across all features approximates the difference between the model’s prediction for the input and the baseline (completeness axiom).
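Written out, with F the model output, x the input, and x' the baseline (notation from the cited paper, not from the PyHealth source), the completeness property is:

```latex
\sum_{i} \mathrm{IG}_i(x) = F(x) - F(x')
```

With a finite number of steps the two sides agree only approximately, so a large gap between them suggests that steps may be too small.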

Examples

>>> import torch
>>> from pyhealth.interpret.methods import IntegratedGradients
>>>
>>> # Assuming you have a trained model and test data
>>> ig = IntegratedGradients(trained_model)
>>> test_batch = next(iter(test_loader))
>>>
>>> # Compute attributions with default settings
>>> attributions = ig.attribute(**test_batch)
>>> print(f"Feature attributions: {attributions.keys()}")
>>> print(f"Conditions: {attributions['conditions'].shape}")
>>>
>>> # Use more steps for better approximation
>>> attributions = ig.attribute(**test_batch, steps=100)
>>>
>>> # Compute attributions for specific class
>>> attributions = ig.attribute(
...     **test_batch, target_class_idx=0, steps=50
... )
>>>
>>> # Use custom baseline
>>> custom_baseline = {
...     'conditions': torch.zeros_like(test_batch['conditions']),
...     'procedures': torch.zeros_like(test_batch['procedures'])
... }
>>> attributions = ig.attribute(
...     **test_batch, baseline=custom_baseline, steps=50
... )
>>>
>>> # Analyze which features are most important
>>> condition_attr = attributions['conditions'][0]
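>>> # Cap k so topk does not fail on inputs with fewer than five elements
>>> k = min(5, condition_attr.shape[-1])
>>> top_k = torch.topk(torch.abs(condition_attr), k=k)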
>>> print(f"Most important features: {top_k.indices}")