pyhealth.interpret.methods.deeplift
Overview
pyhealth.interpret.methods.deeplift.DeepLift provides difference-from-baseline
attributions for PyHealth models. Consult the class docstring for detailed guidance,
usage notes, and examples. A full workflow is demonstrated in
examples/deeplift_stagenet_mimic4.py.
API Reference
- class pyhealth.interpret.methods.deeplift.DeepLift(model, use_embeddings=True)
Bases: BaseInterpreter
DeepLIFT attribution for PyHealth models.
Paper: Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning Important Features through Propagating Activation Differences. ICML 2017.
DeepLIFT propagates difference-from-baseline activations using Rescale multipliers so that feature attributions sum to the change in model output. The implementation injects secant slopes for supported activations (ReLU, Sigmoid, Tanh) via module swapping to mirror the original algorithm while falling back to autograd gradients for unsupported operations.
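The difference between a secant slope and a local derivative can be made concrete with a minimal sketch in plain Python. It is independent of PyHealth: the helper name `secant_slope` and the `eps` fallback threshold are assumptions chosen for illustration, not part of the library's API.

```python
import math


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def secant_slope(fn, z, z0, local_grad, eps=1e-7):
    """Rescale-style multiplier: (fn(z) - fn(z0)) / (z - z0).

    Falls back to the local derivative when z and z0 nearly coincide,
    where the secant quotient is numerically unstable.
    """
    dz = z - z0
    if abs(dz) < eps:
        return local_grad
    return (fn(z) - fn(z0)) / dz


# A ReLU unit that is inactive at the input but active at the baseline.
z, z0 = -1.0, 2.0
relu_grad = 1.0 if z > 0 else 0.0              # plain gradient at z is zero
m_relu = secant_slope(relu, z, z0, relu_grad)  # (0 - 2) / (-1 - 2) = 2/3

# A saturated sigmoid: the local gradient is tiny, the secant slope is not.
s, s0 = 6.0, 0.0
sig_grad = sigmoid(s) * (1.0 - sigmoid(s))
m_sig = secant_slope(sigmoid, s, s0, sig_grad)

# Multiplier times input difference recovers the activation difference.
assert abs(m_relu * (z - z0) - (relu(z) - relu(z0))) < 1e-12
assert abs(m_sig * (s - s0) - (sigmoid(s) - sigmoid(s0))) < 1e-12
```

In both cases the plain gradient would assign little or no credit to the unit, while the secant slope still accounts for the full change in activation relative to the baseline.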
- This method is particularly useful for:
EHR feature importance: highlight influential visits, codes, or labs when auditing StageNet-style models.
Contrastive explanations: compare predictions against a clinically meaningful baseline patient trajectory.
Mixed-input attribution: handle discrete embedding channels and continuous features in a unified call.
Model debugging: diagnose activation saturation and verify the completeness axiom.
- Key Features:
Dual operating modes for embedding-based or continuous inputs.
Automatic activation module swapping for DeepLIFT Rescale rule.
Completeness enforcement ensuring sum(attribution) ~= f(x) - f(x0).
Batch-friendly API accepting trainer-style dictionaries with tuple-based inputs following processor schemas.
Target control via target_class_idx to explain any desired logit.
Mixed token/continuous feature support using is_token() processor introspection.
- Usage Notes:
Choose a baseline dictionary that reflects a neutral clinical state when zeros are not meaningful.
Move inputs, baselines, and the model to the same device before calling attribute.
Keep use_embeddings=True for token indices; set it to False to attribute continuous tensors directly.
Call model.eval() so stochastic layers remain deterministic during paired forward passes.
- Parameters:
model (BaseModel) – A BaseModel instance exposing either forward_from_embedding() (for discrete inputs) or the standard forward() used by PyHealth trainers.
use_embeddings (bool) – Whether to operate in embedding space. Set to True (default) for tokenized inputs or False to attribute continuous tensors directly.
Examples
>>> import torch
>>> from pyhealth.datasets import create_sample_dataset, get_dataloader
>>> from pyhealth.interpret.methods.deeplift import DeepLift
>>> from pyhealth.models import MLP
>>>
>>> samples = [
...     {"patient_id": "p0", "visit_id": "v0",
...      "conditions": ["cond-33", "cond-86", "cond-80"],
...      "procedures": [1.0, 2.0, 3.5, 4.0], "label": 1},
...     {"patient_id": "p1", "visit_id": "v1",
...      "conditions": ["cond-55", "cond-12"],
...      "procedures": [5.0, 2.0, 3.5, 4.0], "label": 0},
... ]
>>> dataset = create_sample_dataset(
...     samples=samples,
...     input_schema={"conditions": "sequence", "procedures": "tensor"},
...     output_schema={"label": "binary"},
... )
>>> model = MLP(dataset=dataset, embedding_dim=32, hidden_dim=32)
>>> model.eval()
>>> test_loader = get_dataloader(dataset, batch_size=1, shuffle=False)
>>> deeplift = DeepLift(model, use_embeddings=True)
>>>
>>> batch = next(iter(test_loader))
>>> attributions = deeplift.attribute(**batch)
>>> print({k: v.shape for k, v in attributions.items()})
- Algorithm Details:
Run a baseline forward pass while caching activations for supported nonlinearities.
Replay the actual inputs with Rescale hooks that substitute secant slopes for local derivatives.
Backpropagate the target logit so gradients equal DeepLIFT multipliers.
Multiply input differences by the propagated multipliers and enforce completeness.
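The four steps above can be traced by hand on a toy network. The sketch below is plain Python and does not use PyHealth or its hook machinery: the two-input, two-hidden-unit ReLU network, its weights, and the function name `deeplift_rescale` are all hypothetical, chosen only so the arithmetic is easy to follow.

```python
def relu(z):
    return max(0.0, z)


# Hypothetical toy network: y = sum_j v[j] * relu(sum_i W[j][i] * x[i] + b[j])
W = [[1.0, -2.0], [0.5, 1.0]]
b = [0.0, -1.0]
v = [2.0, -1.0]


def forward(x):
    """Return hidden pre-activations and the scalar output."""
    pre = [sum(w * xi for w, xi in zip(row, x)) + bj for row, bj in zip(W, b)]
    return pre, sum(vj * relu(p) for vj, p in zip(v, pre))


def deeplift_rescale(x, x0, eps=1e-7):
    # Step 1: baseline forward pass, caching pre-activations.
    pre0, y0 = forward(x0)
    # Step 2: forward pass on the actual input; compute one secant slope
    # per hidden unit (local ReLU derivative if the difference is tiny).
    pre, y = forward(x)
    slopes = []
    for p, p0 in zip(pre, pre0):
        if abs(p - p0) < eps:
            slopes.append(1.0 if p > 0 else 0.0)
        else:
            slopes.append((relu(p) - relu(p0)) / (p - p0))
    # Step 3: chain rule with secant slopes in place of derivatives
    # yields one multiplier per input feature.
    mult = [
        sum(v[j] * slopes[j] * W[j][i] for j in range(len(v)))
        for i in range(len(x))
    ]
    # Step 4: multiply input differences by the multipliers.
    attr = [m * (xi - x0i) for m, xi, x0i in zip(mult, x, x0)]
    return attr, y - y0


x, x0 = [3.0, 1.0], [0.0, 0.0]
attr, delta = deeplift_rescale(x, x0)

# Completeness: attributions sum to the change in output.
assert abs(sum(attr) - delta) < 1e-9
```

In this toy setting completeness holds by construction, because the bias terms cancel in the pre-activation differences; the sketch therefore needs no separate enforcement step.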
References
- [1] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning Important Features through Propagating Activation Differences. Proceedings of the 34th International Conference on Machine Learning (ICML), 2017. https://proceedings.mlr.press/v70/shrikumar17a.html
- attribute(baseline=None, target_class_idx=None, **kwargs)
Compute DeepLIFT attributions for a single batch.
The method follows Algorithm 2 of the DeepLIFT paper: two forward passes (baseline then actual) are executed under the hook context so that backward propagation yields multipliers equal to the Rescale rule.
- Parameters:
baseline (Optional[Dict[str, Tensor]]) – Optional dictionary providing reference inputs per feature key. If omitted, UNK tokens are used for discrete features and near-zero values for continuous features.
target_class_idx (Optional[int]) – Optional class index to explain. None defaults to the model prediction.
**kwargs (Tensor | tuple[Tensor, ...]) – Input data dictionary from a dataloader batch containing:
- Feature keys (e.g., 'conditions', 'procedures'): input tensors or tuples of tensors for each modality
- 'label' (optional): ground truth label tensor
- Other metadata keys are ignored
- Return type:
Dict[str, torch.Tensor]
- Returns:
A dictionary mapping each feature key to an attribution tensor shaped like the original input. All outputs satisfy the completeness property sum_i attribution_i ≈ f(x) - f(x₀).