pyhealth.interpret.methods.gim
Overview
The Gradient Interaction Modifications (GIM) interpreter adapts the attribution method described by Edin et al. (2025) to StageNet-style models. It recomputes softmax gradients at a higher temperature so that token-level interactions remain visible when cumulative softmax layers are present.
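The effect of recomputing the softmax gradient at a higher temperature can be seen in a small numerical sketch. This is plain Python, not PyHealth code: `softmax` and `softmax_backward` are hypothetical helpers, and the sketch assumes the backward pass uses the usual softmax Jacobian evaluated at the chosen temperature, without an extra 1/T scaling factor.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of `logits / temperature`."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_backward(logits, grad_out, temperature=1.0):
    """Vector-Jacobian product of softmax, Jacobian evaluated at `temperature`.

    Assumption: the 1/T factor of the exact derivative is omitted.
    """
    probs = softmax(logits, temperature)
    dot = sum(g * p for g, p in zip(grad_out, probs))
    return [p * (g - dot) for p, g in zip(probs, grad_out)]


# A peaked softmax: the first logit dominates, so its gradient saturates.
logits = [4.0, 1.0, 0.0]
grad_out = [1.0, 0.0, 0.0]

standard = softmax_backward(logits, grad_out, temperature=1.0)
adjusted = softmax_backward(logits, grad_out, temperature=2.0)
```

With these inputs the gradient reaching the dominant logit is larger at temperature 2.0 than at 1.0, which is the behaviour the interpreter relies on to keep saturated softmax inputs visible.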
Use this interpreter with StageNet-style models that expose
forward_from_embedding and embedding_model.
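A quick way to confirm that a model exposes both hooks before constructing the interpreter is sketched below. `supports_gim` and `DummyModel` are hypothetical names used only for illustration, and the check assumes that `forward_from_embedding` is a callable and `embedding_model` is an attribute.

```python
def supports_gim(model):
    """Return True if `model` exposes the two hooks named above (assumed contract)."""
    has_forward = callable(getattr(model, "forward_from_embedding", None))
    has_embedding = hasattr(model, "embedding_model")
    return has_forward and has_embedding


class DummyModel:
    """Stand-in for a StageNet-style model; not a PyHealth class."""

    embedding_model = object()

    def forward_from_embedding(self, **kwargs):
        return kwargs


compatible = supports_gim(DummyModel())
incompatible = supports_gim(object())
```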
For a complete working example, see:
examples/gim_stagenet_mimic4.py
API Reference
- class pyhealth.interpret.methods.GIM(model, temperature=2.0)
Bases: BaseInterpreter
Gradient Interaction Modifications for StageNet-style and Transformer models.
This interpreter adapts the Gradient Interaction Modifications (GIM) technique (Edin et al., 2025) to PyHealth. It supports both recurrent models such as StageNet (where cumulative softmax can exhibit self-repair) and Transformer / attention-based architectures (where LayerNorm and Q·K^T interactions require special treatment).
The implementation follows three rules from the paper:
- Temperature-adjusted softmax gradients (TSG): All nn.Softmax modules are temporarily replaced so the backward Jacobian is recomputed at a higher temperature, exposing interactions hidden by softmax redistribution (Sec. 4.1).
- LayerNorm freeze: nn.LayerNorm modules are replaced with a variant that treats the normalization statistics (mean and variance) as frozen constants during backpropagation. For models without LayerNorm (e.g. StageNet) this is a no-op (Sec. 4.2).
- Gradient normalization (uniform division): torch.matmul calls inside attention layers (the Q·K^T product) are wrapped so that gradients flowing through the binary product are divided by 2. Thanks to composition across the two matmuls in attention, Q and K effectively receive /4 and V receives /2, matching the reference implementation. For models without multi-head attention (e.g. StageNet) this is a no-op (Sec. 4.2).
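The /4 and /2 factors in the last rule follow from chaining two halved products, which a scalar sketch makes concrete. This is an illustration under simplifying assumptions, not the PyHealth implementation: scalars stand in for the Q, K and V matrices, the softmax between the two products is omitted, and `product_backward` and `attention_backward` are hypothetical helpers.

```python
def product_backward(a, b, grad_out, divisor=1.0):
    """Backward of out = a * b, with both input gradients divided by `divisor`."""
    return b * grad_out / divisor, a * grad_out / divisor


def attention_backward(q, k, v, divisor):
    """Gradients of out = (q * k) * v with respect to q, k and v."""
    weights = q * k  # stand-in for softmax(Q K^T); softmax omitted
    # Second product (weights * v): V is reached after one division.
    grad_weights, grad_v = product_backward(weights, v, 1.0, divisor)
    # First product (q * k): Q and K are reached after two divisions.
    grad_q, grad_k = product_backward(q, k, grad_weights, divisor)
    return grad_q, grad_k, grad_v


plain = attention_backward(3.0, 5.0, 7.0, divisor=1.0)
modified = attention_backward(3.0, 5.0, 7.0, divisor=2.0)

# Ratio of unmodified to modified gradient for (q, k, v).
ratios = [p / m for p, m in zip(plain, modified)]
```

The gradient for V passes through only the second product and is halved once, while the gradients for Q and K pass through both products and are halved twice.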
Note
The paper also mentions a third multiplicative interaction (MLP gate-projection) that is relevant for gated FFNs (e.g. SwiGLU). PyHealth’s PositionwiseFeedForward uses a standard two-layer FFN with GELU (no element-wise gate), so this normalisation is not needed and is intentionally omitted.
- Parameters:
Examples
>>> import torch
>>> from pyhealth.datasets import get_dataloader
>>> from pyhealth.interpret.methods.gim import GIM
>>> from pyhealth.models import StageNet
>>>
>>> # Assume ``sample_dataset`` and trained StageNet weights are available.
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model = StageNet(dataset=sample_dataset)
>>> model = model.to(device).eval()
>>> test_loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False)
>>> gim = GIM(model, temperature=2.0)
>>>
>>> batch = next(iter(test_loader))
>>> attributions = gim.attribute(**batch)
>>> print({k: v.shape for k, v in attributions.items()})
- attribute(target_class_idx=None, **kwargs)
Compute GIM attributions for a batch.
- Parameters:
target_class_idx (Optional[int]) – Target class index for attribution. For binary classification (single logit output), this is a no-op. If None, uses the argmax of model output.
**kwargs (Tensor | tuple[Tensor, ...]) – Input data dictionary from a dataloader batch containing feature tensors or tuples of tensors for each modality, plus optional label tensors.
- Return type:
- Returns:
Dictionary mapping feature keys to attribution tensors with the same shape as the raw input values.
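The target-selection behaviour described for target_class_idx can be summarised in a short sketch. `resolve_target` is a hypothetical helper operating on a plain list of logits for one sample. It is not part of the GIM class, and returning index 0 in the single-logit case is an assumption made for illustration.

```python
def resolve_target(logits, target_class_idx=None):
    """Pick the output index to attribute, following the documented rules."""
    if len(logits) == 1:
        # Binary classification with a single logit: the argument has no
        # effect (assumption: index 0 is always used).
        return 0
    if target_class_idx is None:
        # Default: attribute the class with the highest model output.
        return max(range(len(logits)), key=logits.__getitem__)
    return target_class_idx


default_choice = resolve_target([0.1, 2.0, -1.0])
explicit_choice = resolve_target([0.1, 2.0, -1.0], target_class_idx=2)
binary_choice = resolve_target([0.3], target_class_idx=5)
```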