pyhealth.interpret.methods.deeplift

Overview

pyhealth.interpret.methods.deeplift.DeepLift provides difference-from-baseline attributions for PyHealth models. Consult the class docstring for detailed guidance, usage notes, and examples. A full workflow is demonstrated in examples/deeplift_stagenet_mimic4.py.

API Reference

class pyhealth.interpret.methods.deeplift.DeepLift(model, use_embeddings=True)

Bases: BaseInterpreter

DeepLIFT attribution for PyHealth models.

Paper: Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning Important Features through Propagating Activation Differences. ICML 2017.

DeepLIFT propagates activation differences relative to a baseline using Rescale-rule multipliers, so that the feature attributions sum to the change in model output between the baseline and the actual input. The implementation swaps supported activation modules (ReLU, Sigmoid, Tanh) for versions that inject secant slopes, mirroring the original algorithm, and falls back to ordinary autograd gradients for unsupported operations.
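The Rescale rule is easiest to see on a single unit. The plain-Python sketch below (not PyHealth code) computes the secant multiplier for a sigmoid and shows that multiplying it by the input difference recovers the full change in activation. The `eps` fallback to the local derivative when the input barely moves is an assumption added for numerical safety; it is separate from the autograd fallback for unsupported operations.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z: float) -> float:
    """Local derivative of the sigmoid at z (what plain autograd would use)."""
    s = sigmoid(z)
    return s * (1.0 - s)


def rescale_multiplier(z: float, z0: float, eps: float = 1e-7) -> float:
    """Rescale-rule multiplier: the secant slope (f(z) - f(z0)) / (z - z0).

    Assumption: when |z - z0| < eps the secant is numerically unstable, so
    this sketch falls back to the local derivative at z.
    """
    dz = z - z0
    if abs(dz) < eps:
        return sigmoid_grad(z)
    return (sigmoid(z) - sigmoid(z0)) / dz


z, z0 = 2.0, 0.0                               # actual and baseline pre-activations
m = rescale_multiplier(z, z0)
attribution = m * (z - z0)                     # DeepLIFT contribution of the input change
grad_attribution = sigmoid_grad(z) * (z - z0)  # gradient x difference, for contrast
delta_out = sigmoid(z) - sigmoid(z0)           # true change in the activation

print(f"rescale: {attribution:.4f}  gradient: {grad_attribution:.4f}  delta: {delta_out:.4f}")
```

The Rescale contribution matches the activation change exactly (about 0.38), while gradient times difference gives only about 0.21 because the sigmoid is already flattening at z = 2; this gap is the saturation effect the secant slope avoids.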

This method is particularly useful for:
  • EHR feature importance: highlight influential visits, codes, or labs when auditing StageNet-style models.

  • Contrastive explanations: compare predictions against a clinically meaningful baseline patient trajectory.

  • Mixed-input attribution: handle discrete embedding channels and continuous features in a unified call.

  • Model debugging: diagnose activation saturation and verify the completeness axiom.

Key Features:
  • Dual operating modes for embedding-based or continuous inputs.

  • Automatic activation module swapping for DeepLIFT Rescale rule.

  • Completeness enforcement ensuring sum(attributions) ≈ f(x) - f(x₀).

  • Batch-friendly API accepting trainer-style dictionaries with tuple-based inputs following processor schemas.

  • Target control via target_class_idx to explain any desired logit.

  • Mixed token/continuous feature support using is_token() processor introspection.

Usage Notes:
  1. Choose a baseline dictionary that reflects a neutral clinical state when zeros are not meaningful.

  2. Move inputs, baselines, and the model to the same device before calling attribute.

  3. Keep use_embeddings=True for token indices; set it to False to attribute continuous tensors directly.

  4. Call model.eval() so that stochastic layers such as dropout behave deterministically across the paired forward passes.
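Usage note 1 matters because DeepLIFT explains the difference between the input and the baseline. For a purely linear model the multipliers reduce to the weights, which makes the effect of the baseline easy to see. The toy model, the feature name `labs`, and the "neutral" reference values below are hypothetical and chosen only for illustration.

```python
from typing import Dict, List

Features = Dict[str, List[float]]


def linear_output(weights: Features, x: Features) -> float:
    """Toy linear model: f(x) = sum over keys and positions of w_i * x_i."""
    return sum(w * xi for key in weights for w, xi in zip(weights[key], x[key]))


def linear_deeplift(weights: Features, x: Features, x0: Features) -> Features:
    """DeepLIFT attributions for the toy linear model.

    With no nonlinearities every multiplier equals its weight, so each
    attribution is simply w_i * (x_i - x0_i).
    """
    return {
        key: [w * (xi - x0i) for w, xi, x0i in zip(weights[key], x[key], x0[key])]
        for key in weights
    }


weights = {"labs": [0.8, 0.5]}
patient = {"labs": [2.0, 3.0]}
zero_baseline = {"labs": [0.0, 0.0]}
# Hypothetical "neutral" reference, e.g. typical values for a stable patient.
neutral_baseline = {"labs": [1.0, 3.0]}

for name, ref in [("zero", zero_baseline), ("neutral", neutral_baseline)]:
    attrs = linear_deeplift(weights, patient, ref)
    delta = linear_output(weights, patient) - linear_output(weights, ref)
    print(name, attrs, "sum =", round(sum(attrs["labs"]), 6), "delta =", round(delta, 6))
```

Against the zero baseline the second lab receives credit 1.5 merely for being nonzero; against the neutral reference it receives 0.0 because it matches the typical value, and only the deviating first lab is credited. In both cases the attributions sum to f(x) - f(x₀).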

Parameters:
  • model (BaseModel) – A BaseModel instance exposing either forward_from_embedding() (for discrete inputs) or the standard forward() used by PyHealth trainers.

  • use_embeddings (bool) – Whether to operate in embedding space. Set to True (default) for tokenized inputs or False to attribute continuous tensors directly.
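With use_embeddings=True, the differences are taken in embedding space, while the returned attributions are shaped like the original token inputs, so each token's embedding-level contributions have to be collapsed to a single number somehow. The plain-Python sketch below assumes, for illustration only, that the reduction is a sum over the embedding dimension, which keeps completeness intact because summation is linear. The helper `embedding_attributions`, the toy scorer, and the UNK embedding values are hypothetical and not part of the PyHealth API.

```python
from typing import List

Matrix = List[List[float]]  # [num_tokens][embedding_dim]


def embedding_attributions(emb: Matrix, emb_baseline: Matrix, multipliers: Matrix) -> List[float]:
    """Collapse embedding-space DeepLIFT terms to one score per token.

    Each embedding coordinate contributes multiplier * (e - e0).
    Assumption: the per-token score is the sum over the embedding dimension.
    """
    return [
        sum(m * (e - e0) for e, e0, m in zip(e_tok, e0_tok, m_tok))
        for e_tok, e0_tok, m_tok in zip(emb, emb_baseline, multipliers)
    ]


# Toy linear scorer over two token embeddings: f(E) = sum_t w . E[t].
w = [0.3, -0.2, 0.4]
emb = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
# Hypothetical baseline: both tokens replaced by the same UNK embedding.
emb_baseline = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1]]
# For a linear scorer every multiplier is just the weight vector.
multipliers = [w, w]


def score(E: Matrix) -> float:
    return sum(wd * ed for row in E for wd, ed in zip(w, row))


token_scores = embedding_attributions(emb, emb_baseline, multipliers)
print(token_scores, sum(token_scores), score(emb) - score(emb_baseline))
```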

Examples

>>> import torch
>>> from pyhealth.datasets import create_sample_dataset, get_dataloader
>>> from pyhealth.interpret.methods.deeplift import DeepLift
>>> from pyhealth.models import MLP
>>>
>>> samples = [
...     {"patient_id": "p0", "visit_id": "v0",
...      "conditions": ["cond-33", "cond-86", "cond-80"],
...      "procedures": [1.0, 2.0, 3.5, 4.0], "label": 1},
...     {"patient_id": "p1", "visit_id": "v1",
...      "conditions": ["cond-55", "cond-12"],
...      "procedures": [5.0, 2.0, 3.5, 4.0], "label": 0},
... ]
>>> dataset = create_sample_dataset(
...     samples=samples,
...     input_schema={"conditions": "sequence", "procedures": "tensor"},
...     output_schema={"label": "binary"},
... )
>>> model = MLP(dataset=dataset, embedding_dim=32, hidden_dim=32)
>>> _ = model.eval()  # inference mode, so dropout and similar layers are deterministic
>>> test_loader = get_dataloader(dataset, batch_size=1, shuffle=False)
>>> deeplift = DeepLift(model, use_embeddings=True)
>>>
>>> batch = next(iter(test_loader))
>>> attributions = deeplift.attribute(**batch)
>>> print({k: v.shape for k, v in attributions.items()})

Algorithm Details:
  1. Run a baseline forward pass while caching activations for supported nonlinearities.

  2. Replay the actual inputs with Rescale hooks that substitute secant slopes for local derivatives.

  3. Backpropagate the target logit so gradients equal DeepLIFT multipliers.

  4. Multiply input differences by the propagated multipliers and enforce completeness.
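The four steps can be traced by hand on a tiny network. The plain-Python sketch below is a hypothetical toy (out = v · relu(Wx + b) + c, with made-up weights), not PyHealth's hook-based implementation: explicit arithmetic stands in for module swapping and autograd. For a single hidden layer the Rescale multipliers already satisfy completeness exactly, so step 4 needs no correction here.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def forward(x: Vector, W: Matrix, b: Vector, v: Vector, c: float) -> Tuple[Vector, Vector, float]:
    """Toy network out = v . relu(W x + b) + c; returns (z, h, out)."""
    z = [sum(w_ji * x_i for w_ji, x_i in zip(row, x)) + b_j for row, b_j in zip(W, b)]
    h = [max(z_j, 0.0) for z_j in z]
    out = sum(v_j * h_j for v_j, h_j in zip(v, h)) + c
    return z, h, out


def deeplift_rescale(x: Vector, x0: Vector, W: Matrix, b: Vector, v: Vector,
                     c: float, eps: float = 1e-7) -> Tuple[Vector, float]:
    # Step 1: baseline forward pass, caching ReLU inputs and outputs.
    z0, h0, f0 = forward(x0, W, b, v, c)
    # Step 2: actual forward pass; each ReLU's derivative is replaced by the
    # secant slope (h - h0) / (z - z0), falling back to the ordinary
    # derivative when the pre-activation barely changes (an assumption).
    z, h, f = forward(x, W, b, v, c)
    m_relu = [
        (h_j - h0_j) / (z_j - z0_j) if abs(z_j - z0_j) >= eps else float(z_j > 0)
        for z_j, z0_j, h_j, h0_j in zip(z, z0, h, h0)
    ]
    # Step 3: chain rule from the output back to the inputs, using the
    # secant slopes in place of ReLU derivatives.
    m_x = [
        sum(v[j] * m_relu[j] * W[j][i] for j in range(len(v)))
        for i in range(len(x))
    ]
    # Step 4: multipliers times input differences give the attributions.
    attributions = [m_i * (x_i - x0_i) for m_i, x_i, x0_i in zip(m_x, x, x0)]
    return attributions, f - f0


W = [[1.0, -2.0], [0.5, 1.0]]
b = [0.0, -1.0]
v = [1.0, 2.0]
c = 0.5
x, x0 = [3.0, 1.0], [0.0, 0.0]

attributions, delta = deeplift_rescale(x, x0, W, b, v, c)
print([round(a, 6) for a in attributions], round(sum(attributions), 6), round(delta, 6))
```

Here the attributions are [4.8, -0.8], summing to the output change of 4.0. Plain gradient times input difference would give [6.0, 0.0], summing to 6.0, because it ignores that the second hidden unit is switched off at the baseline (its secant slope is 0.6 rather than 1).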

References

[1] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning Important Features through Propagating Activation Differences. Proceedings of the 34th International Conference on Machine Learning (ICML), 2017. https://proceedings.mlr.press/v70/shrikumar17a.html

attribute(baseline=None, target_class_idx=None, **kwargs)

Compute DeepLIFT attributions for a single batch.

The method follows Algorithm 2 of the DeepLIFT paper: two forward passes (baseline, then actual) are executed under the hook context, so that backpropagation yields the multipliers prescribed by the Rescale rule.

Parameters:
  • baseline (Optional[Dict[str, Tensor]]) – Optional dictionary providing reference inputs per feature key. If omitted, UNK tokens are used for discrete features and near-zero values for continuous features.

  • target_class_idx (Optional[int]) – Optional class index to explain. If None, the model's predicted class is explained.

  • **kwargs (Tensor | tuple[Tensor, ...]) – Input data dictionary from a dataloader batch, containing:

      • Feature keys (e.g., 'conditions', 'procedures'): input tensors or tuples of tensors for each modality.

      • 'label' (optional): ground-truth label tensor.

      • Any other metadata keys, which are ignored.

Return type:

Dict[str, Tensor]

Returns:

Dict[str, torch.Tensor] mapping each feature key to an attribution tensor shaped like the corresponding input. Summed over all feature keys, the attributions satisfy the completeness property sum_i attribution_i ≈ f(x) - f(x₀), where f is the explained target logit.
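To check completeness on the returned dictionary, sum the attributions over every feature key and compare the total with f(x) - f(x₀) for the explained logit. The sketch below uses nested Python lists as a stand-in for tensors, and the attribution values and logits are made up for illustration; with real outputs, calling `.sum().item()` on each tensor would replace the hypothetical `total` helper.

```python
from typing import Dict, Sequence, Union

Nested = Union[float, Sequence["Nested"]]


def total(value: Nested) -> float:
    """Sum every scalar in an arbitrarily nested list (stand-in for Tensor.sum())."""
    if isinstance(value, (int, float)):
        return float(value)
    return sum(total(item) for item in value)


def completeness_gap(attributions: Dict[str, Nested], f_x: float, f_x0: float) -> float:
    """Absolute difference between summed attributions and f(x) - f(x0).

    Completeness holds jointly over all feature keys, not per key.
    """
    summed = sum(total(v) for v in attributions.values())
    return abs(summed - (f_x - f_x0))


# Illustrative values only: one batch element, two feature keys.
attributions = {
    "conditions": [[0.4, -0.1, 0.2]],
    "procedures": [[0.3, 0.0, 0.1, -0.2]],
}
print(completeness_gap(attributions, f_x=1.2, f_x0=0.5))
```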