# Source code for pyhealth.interpret.methods.gim

from __future__ import annotations

import math
import contextlib
from typing import Dict, List, Optional, Tuple, Type

import torch
import torch.nn.functional as F

from pyhealth.models import BaseModel
from pyhealth.interpret.api import Interpretable
from .base_interpreter import BaseInterpreter


def _iter_child_modules(module: torch.nn.Module):
    for name, child in module.named_children():
        yield module, name, child
        yield from _iter_child_modules(child)


class _FrozenLayerNorm(torch.nn.Module):
    """LayerNorm replacement that treats normalization statistics as constants.

    Implements the LayerNorm freeze rule from GIM Sec. 4.2: in the forward
    pass the output is identical to ``nn.LayerNorm``, but the backward pass
    treats the mean and variance as fixed constants (i.e. their Jacobian
    contributions are detached).
    """

    def __init__(self, original: torch.nn.LayerNorm):
        super().__init__()
        self.original = original

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401
        out = _FrozenLayerNormFn.apply(
            x,
            self.original.normalized_shape,
            self.original.weight,
            self.original.bias,
            self.original.eps,
        )
        assert isinstance(out, torch.Tensor)
        return out


class _FrozenLayerNormFn(torch.autograd.Function):
    """Custom autograd: forward == LayerNorm, backward freezes statistics."""

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        normalized_shape: Tuple[int, ...],
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        eps: float,
    ) -> torch.Tensor:
        # Compute the standard LayerNorm output.
        dims = tuple(range(-len(normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + eps)
        out = x_hat
        if weight is not None:
            out = out * weight
        if bias is not None:
            out = out + bias
        # Save what we need for backward – mean and std are treated as
        # constants, so we only need x_hat, weight, and 1/std.
        inv_std = 1.0 / torch.sqrt(var + eps)
        ctx.save_for_backward(x_hat, weight, inv_std)
        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        x_hat, weight, inv_std = ctx.saved_tensors
        # Treat mean/var as frozen constants → ∂out/∂x = weight / std
        # (no correction terms from differentiating through mean/var).
        if weight is not None:
            grad_x = grad_output * weight * inv_std
        else:
            grad_x = grad_output * inv_std

        # Gradients for affine parameters (weight, bias) are standard.
        grad_weight = None
        if weight is not None:
            grad_weight = (grad_output * x_hat).flatten(end_dim=-len(weight.shape) - 1).sum(0)
        grad_bias = None
        if ctx.saved_tensors[1] is not None:  # bias was provided
            grad_bias = grad_output.flatten(end_dim=-len(weight.shape) - 1).sum(0)

        return grad_x, None, grad_weight, grad_bias, None


class _MatMulNorm(torch.autograd.Function):
    """matmul whose backward divides grad by fan-in (=2 for a binary product).

    Implements the uniform division rule from GIM Sec. 4.2 for a single
    matrix multiplication.  Used inside :class:`_AttentionGIM` to normalise
    gradients flowing through Q·K^T and attn·V products.

    Because the /2 is applied *per matmul*, the effective normalisation
    compounds across the two sequential multiplications in attention:

    * **V** participates in one matmul (attn·V)  → effective /2.
    * **Q** passes through two matmuls (Q·K^T → softmax → attn·V) →
      effective /4  (the /2 from attn·V propagates through softmax's
      linear Jacobian, then a second /2 comes from Q·K^T).
    * **K** — same as Q → effective /4.

    This matches the reference implementation (JoakimEdin/gim,
    ``_grad_normalize``: key÷4, query÷4, value÷2).
    """

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        a, b = ctx.saved_tensors
        grad_a = (grad_output @ b.transpose(-2, -1)) / 2.0
        grad_b = (a.transpose(-2, -1) @ grad_output) / 2.0
        return grad_a, grad_b


class _AttentionGIM(torch.nn.Module):
    """Drop-in replacement for ``Attention`` that applies GIM rules 1 & 3.

    1. **TSG** – the softmax in the forward pass is computed normally, but
       the backward Jacobian uses a higher temperature (Sec. 4.1).
    3. **Gradient normalisation** – both ``matmul(Q, K^T)`` and
       ``matmul(attn, V)`` use the uniform division rule so that each
       factor receives half of the incoming gradient (Sec. 4.2).

    This module mirrors the signature of
    ``pyhealth.models.transformer.Attention`` so it can be swapped in and
    out by :class:`_GIMSwapContext` without touching any global state.
    """

    def __init__(self, temperature: float):
        super().__init__()
        self.temperature = max(float(temperature), 1.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        dropout: Optional[torch.nn.Module] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # --- scores = Q · K^T / sqrt(d_k) with gradient normalisation ---
        qk = _MatMulNorm.apply(query, key.transpose(-2, -1))
        assert isinstance(qk, torch.Tensor)
        scores = qk / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # --- softmax with TSG ---
        p_attn: torch.Tensor = _TemperatureSoftmaxFn.apply(scores, -1, self.temperature)  # type: ignore[assignment]

        if mask is not None:
            p_attn = p_attn.masked_fill(mask == 0, 0)
        if dropout is not None:
            p_attn = dropout(p_attn)

        # --- attn · V with gradient normalisation ---
        out = _MatMulNorm.apply(p_attn, value)
        assert isinstance(out, torch.Tensor)
        return out, p_attn


class _TemperatureSoftmax(torch.nn.Module):
    """Stand-in for a standalone ``nn.Softmax`` that applies the TSG rule.

    The forward output is the ordinary softmax along ``dim``; the backward
    pass is delegated to :class:`_TemperatureSoftmaxFn`, which evaluates the
    softmax Jacobian at the given temperature (GIM Sec. 4.1).
    """

    def __init__(self, dim: int, temperature: float):
        super().__init__()
        self.dim = dim
        # Temperatures below 1 are clamped to 1.
        self.temperature = max(float(temperature), 1.0)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        out = _TemperatureSoftmaxFn.apply(input_tensor, self.dim, self.temperature)
        assert isinstance(out, torch.Tensor)
        return out


class _GIMSwapContext(contextlib.AbstractContextManager):
    """Temporarily replace Attention, Softmax and LayerNorm modules with GIM-aware versions.

    On entry every matching sub-module of ``model`` is replaced in place:

    * attention helpers (see :meth:`_is_attention_module`) → :class:`_AttentionGIM`
    * ``nn.Softmax`` with an explicit ``dim`` → :class:`_TemperatureSoftmax`
    * ``nn.LayerNorm`` → :class:`_FrozenLayerNorm`

    On exit (normal or exceptional) the original modules are put back.
    ``nn.Softmax`` modules created without an explicit ``dim`` are left
    untouched because the softmax axis cannot be read from the module.

    Args:
        model: Model whose sub-modules are swapped.
        temperature: Temperature handed to the softmax-related replacements.
    """

    def __init__(self, model: BaseModel, temperature: float):
        self.model = model
        self.temperature = temperature
        # (parent, attribute name, original child) for every swap performed.
        self._swapped: List[Tuple[torch.nn.Module, str, torch.nn.Module]] = []

    def __enter__(self) -> "_GIMSwapContext":
        # ids of originals that have been swapped out; their own children are
        # unreachable from the model while the context is active.
        detached = set()
        try:
            # Snapshot the tree before mutating it so that wrappers inserted
            # during the loop are never visited (and re-wrapped) themselves.
            for parent, name, child in list(_iter_child_modules(self.model)):
                if id(parent) in detached:
                    continue
                # A parent registered in several places is visited more than
                # once; skip slots that no longer hold the original child.
                if getattr(parent, name, None) is not child:
                    continue
                wrapper = self._build_wrapper(child)
                if wrapper is None:
                    continue
                setattr(parent, name, wrapper)
                self._swapped.append((parent, name, child))
                detached.add(id(child))
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo any
            # partial swap here before propagating.
            self._restore()
            raise
        return self

    def __exit__(self, exc_type, exc, exc_tb) -> bool:
        self._restore()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def _restore(self) -> None:
        """Put every swapped-out module back, most recent swap first."""
        for parent, name, original in reversed(self._swapped):
            setattr(parent, name, original)
        self._swapped.clear()

    def _build_wrapper(self, module: torch.nn.Module) -> Optional[torch.nn.Module]:
        """Return the GIM replacement for ``module``, or ``None`` to keep it."""
        # Attention helper: the replacement covers both its softmax (TSG)
        # and its matmuls (gradient normalisation).
        if self._is_attention_module(module):
            return _AttentionGIM(temperature=self.temperature)
        # Standalone softmax modules (TSG rule).
        if isinstance(module, torch.nn.Softmax) and module.dim is not None:
            return _TemperatureSoftmax(dim=module.dim, temperature=self.temperature)
        # LayerNorm freeze rule.
        if isinstance(module, torch.nn.LayerNorm):
            return _FrozenLayerNorm(module)
        return None

    @staticmethod
    def _is_attention_module(module: torch.nn.Module) -> bool:
        """Return True for PyHealth's ``Attention`` (scaled dot-product helper).

        Detection is structural: the class must be named ``Attention`` and
        own an ``nn.Softmax`` under the attribute ``softmax``.
        """
        cls = type(module)
        return (
            cls.__name__ == "Attention"
            and hasattr(module, "softmax")
            and isinstance(getattr(module, "softmax"), torch.nn.Softmax)
        )


class _TemperatureSoftmaxFn(torch.autograd.Function):
    """Custom autograd op implementing temperature-adjusted softmax gradients.

    Implements the Temperature-Scaled Gradients (TSG) rule from GIM Sec. 4.1 by
    recomputing the backward Jacobian with a higher temperature while leaving
    the forward softmax unchanged.
    """

    @staticmethod
    def forward(
        ctx,
        input_tensor: torch.Tensor,
        dim: int,
        temperature: float,
    ) -> torch.Tensor:
        ctx.dim = dim
        ctx.temperature = float(temperature)
        ctx.save_for_backward(input_tensor)
        return torch.softmax(input_tensor, dim=dim)

    @staticmethod
    def backward( # type: ignore[return]
        ctx,
        grad_output: torch.Tensor,
    ) -> Tuple[torch.Tensor, None, None]: 
        (input_tensor,) = ctx.saved_tensors
        dim = ctx.dim
        temperature = max(ctx.temperature, 1.0)

        if temperature == 1.0:
            probs = torch.softmax(input_tensor, dim=dim)
            dot = (grad_output * probs).sum(dim=dim, keepdim=True)
            grad_input = probs * (grad_output - dot)
            return grad_input, None, None

        # TSG: recompute softmax at higher temperature, then use the
        # *standard* softmax Jacobian formula evaluated at the
        # temperature-adjusted distribution.  Crucially, we do NOT
        # multiply by 1/T (the chain-rule factor for x/T) — TSG is
        # defined as "change the point at which the Jacobian is
        # evaluated", not "compute the full derivative of softmax(x/T)".
        # This matches the reference implementation (softmax_tsg in
        # JoakimEdin/gim, utils.py).
        adjusted = torch.softmax(input_tensor / temperature, dim=dim)
        dot = (grad_output * adjusted).sum(dim=dim, keepdim=True)
        grad_input = adjusted * (grad_output - dot)
        return grad_input, None, None


class _GIMHookContext(contextlib.AbstractContextManager):
    """Single entry point that switches the GIM backward-pass rules on and off.

    All of the work is delegated to a :class:`_GIMSwapContext` built for
    ``model``: entering this context enters the swap context (GIM-aware
    modules are installed), and leaving it leaves the swap context (the
    original modules are put back).  The temperature handed to the swap
    context is clamped to be at least 1.0.
    """

    def __init__(self, model: BaseModel, temperature: float):
        clamped = max(float(temperature), 1.0)
        self.model = model
        # The caller's value is stored as given; only the swap context sees
        # the clamped temperature.
        self.temperature = temperature
        self._swap_ctx = _GIMSwapContext(model, temperature=clamped)

    def __enter__(self) -> "_GIMHookContext":
        self._swap_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc, exc_tb) -> bool:
        self._swap_ctx.__exit__(exc_type, exc, exc_tb)
        # Exceptions from the ``with`` body always propagate.
        return False


class GIM(BaseInterpreter):
    """Gradient Interaction Modifications for StageNet-style and Transformer models.

    This interpreter adapts the Gradient Interaction Modifications (GIM)
    technique (Edin et al., 2025) to PyHealth. It supports both recurrent
    models such as StageNet (where cumulative softmax can exhibit
    self-repair) and Transformer / attention-based architectures (where
    LayerNorm and Q·K^T interactions require special treatment).

    The implementation follows three rules from the paper, all installed
    for the duration of the attribution pass by ``_GIMHookContext``:

    1. **Temperature-adjusted softmax gradients (TSG):** the softmax
       operations covered by the hook context keep their forward output,
       but their backward Jacobian is recomputed at a higher temperature,
       exposing interactions hidden by softmax redistribution (Sec. 4.1).
    2. **LayerNorm freeze:** ``nn.LayerNorm`` modules are replaced with a
       variant that treats the normalisation mean and variance as frozen
       constants during backpropagation. For models without LayerNorm
       (e.g. StageNet) this is a no-op (Sec. 4.2).
    3. **Gradient normalization (uniform division):** the matrix products
       inside attention layers (Q·K^T and attn·V) are wrapped so that
       gradients flowing through each binary product are divided by 2.
       Thanks to composition across the two matmuls in attention, Q and K
       effectively receive /4 and V receives /2, matching the reference
       implementation. For models without multi-head attention (e.g.
       StageNet) this is a no-op (Sec. 4.2).

    .. note::
        The paper also mentions a third multiplicative interaction (MLP
        gate-projection) that is relevant for gated FFNs (e.g. SwiGLU).
        PyHealth's ``PositionwiseFeedForward`` uses a standard two-layer
        FFN with GELU (no element-wise gate), so this normalisation is not
        needed and is intentionally omitted.

    Args:
        model: Trained PyHealth model supporting ``forward_from_embedding``
            and ``get_embedding_model()``. Currently tested with StageNet,
            StageNetMHA, and Transformer.
        temperature: Softmax temperature used exclusively for the backward
            pass. Values below ``1.0`` are clamped to ``1.0``. A value of
            ``2.0`` matches the paper's best setting.

    Raises:
        ValueError: If ``model`` does not implement ``Interpretable``.

    Examples:
        >>> import torch
        >>> from pyhealth.datasets import get_dataloader
        >>> from pyhealth.interpret.methods.gim import GIM
        >>> from pyhealth.models import StageNet
        >>>
        >>> # Assume ``sample_dataset`` and trained StageNet weights are available.
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> model = StageNet(dataset=sample_dataset)
        >>> model = model.to(device).eval()
        >>> test_loader = get_dataloader(sample_dataset, batch_size=1, shuffle=False)
        >>> gim = GIM(model, temperature=2.0)
        >>>
        >>> batch = next(iter(test_loader))
        >>> attributions = gim.attribute(**batch)
        >>> print({k: v.shape for k, v in attributions.items()})
    """

    def __init__(
        self,
        model: BaseModel,
        temperature: float = 2.0,
    ):
        super().__init__(model)
        if not isinstance(model, Interpretable):
            raise ValueError("Model must implement Interpretable interface")
        self.model = model
        self.temperature = max(float(temperature), 1.0)

    def attribute(
        self,
        target_class_idx: Optional[int] = None,
        **kwargs: torch.Tensor | tuple[torch.Tensor, ...],
    ) -> Dict[str, torch.Tensor]:
        """Compute GIM attributions for a batch.

        Args:
            target_class_idx: Target class index for attribution. For binary
                classification (single logit output), this is a no-op. If
                None, uses the argmax of model output.
            **kwargs: Input data dictionary from a dataloader batch
                containing feature tensors or tuples of tensors for each
                modality, plus optional label tensors.

        Returns:
            Dictionary mapping feature keys to attribution tensors with the
            same shape as the raw input values.
        """
        device = next(self.model.parameters()).device
        processors = self.model.dataset.input_processors

        # Filter kwargs to only include model feature keys and ensure tuples
        inputs = {
            k: (v,) if isinstance(v, torch.Tensor) else v
            for k, v in kwargs.items()
            if k in self.model.feature_keys
        }

        # Disassemble inputs to get values and masks via processor schema
        values: dict[str, torch.Tensor] = {}
        masks: dict[str, torch.Tensor] = {}
        for k, v in inputs.items():
            schema = processors[k].schema()
            values[k] = v[schema.index("value")]
            if "mask" in schema:
                masks[k] = v[schema.index("mask")]
            else:
                val = v[schema.index("value")]
                processor = processors[k]
                if processor.is_token():
                    masks[k] = (val != 0).int()
                else:
                    # For continuous features, check whether the entire
                    # feature vector at each timestep is zero (padding)
                    # rather than per-element, so valid 0.0 values are
                    # not masked out.
                    if val.dim() >= 3:
                        masks[k] = (val.abs().sum(dim=-1) != 0).int()
                    else:
                        masks[k] = (val != 0).int()

        # Append input masks to inputs for models that expect them
        # (only existing keys are reassigned, so iterating is safe).
        for k, v in inputs.items():
            if "mask" not in processors[k].schema():
                inputs[k] = (*v, masks[k])

        # Save raw shapes before embedding for later mapping
        shapes = {k: v.shape for k, v in values.items()}

        # Determine target class from original input
        with torch.no_grad():
            base_logits = self.model.forward(**inputs)["logit"]
        target_indices = self._resolve_target_indices(base_logits, target_class_idx)

        # Embed values and detach for gradient attribution.
        # Split features by type using is_token():
        # - Token features (discrete): embed before gradient computation,
        #   since raw indices are not differentiable. Gradients are computed
        #   w.r.t. embeddings, then summed over the embedding dim.
        # - Continuous features: keep raw so each raw dimension gets its own
        #   gradient-based attribution. The embedding happens inside the
        #   forward pass via the embedding model.
        embedding_model = self.model.get_embedding_model()
        assert embedding_model is not None

        token_keys = {k for k in values if processors[k].is_token()}
        continuous_keys = set(values.keys()) - token_keys

        # Embed token features. The result is detached below and becomes a
        # fresh leaf, so there is no need to record an autograd graph here.
        if token_keys:
            with torch.no_grad():
                token_embedded = embedding_model({k: values[k] for k in token_keys})
        else:
            token_embedded = {}

        # Prepare gradient targets: embeddings for tokens, raw values for continuous
        embeddings: dict[str, torch.Tensor] = {}
        for key in sorted(values.keys()):
            if key in token_keys:
                emb = token_embedded[key]
            else:
                emb = values[key].to(device).float()
            emb = emb.detach().requires_grad_(True)
            emb.retain_grad()
            embeddings[key] = emb

        # Insert gradient targets back into input tuples.
        # For continuous features, we also need to embed them for
        # forward_from_embedding, but we keep the raw tensor as the
        # gradient target so attributions have per-raw-feature granularity.
        forward_inputs = inputs.copy()
        for k in forward_inputs.keys():
            schema = processors[k].schema()
            val_idx = schema.index("value")
            if k in continuous_keys:
                # Embed the raw tensor through the embedding model;
                # autograd will track gradients back to the raw tensor.
                embedded_val = embedding_model({k: embeddings[k]})[k]
                forward_inputs[k] = (
                    *forward_inputs[k][:val_idx],
                    embedded_val,
                    *forward_inputs[k][val_idx + 1:],
                )
            else:
                forward_inputs[k] = (
                    *forward_inputs[k][:val_idx],
                    embeddings[k],
                    *forward_inputs[k][val_idx + 1:],
                )

        # Clear stale gradients before the attribution pass.
        self.model.zero_grad(set_to_none=True)

        # All GIM rules are applied via _GIMHookContext:
        #   Step 1 (TSG): softmax → temperature-adjusted backward.
        #   Step 2 (LayerNorm freeze): nn.LayerNorm → frozen statistics.
        #   Step 3 (Gradient normalization): uniform division for the
        #           matrix products in attention layers.
        # Only the modules actually present in the model are affected, so
        # rules that do not apply to a given architecture are no-ops.
        with _GIMHookContext(self.model, self.temperature):
            output = self.model.forward_from_embedding(**forward_inputs)
            logits = output["logit"]  # type: ignore[assignment]
            target_output = self._compute_target_output(logits, target_indices)

            # Clear stale gradients, then backpropagate through the
            # GIM-modified computational graph.
            self.model.zero_grad(set_to_none=True)
            for emb in embeddings.values():
                if emb.grad is not None:
                    emb.grad.zero_()
            target_output.backward()

        attributions: dict[str, torch.Tensor] = {}
        for key, emb in embeddings.items():
            grad = emb.grad
            if grad is None:
                grad = torch.zeros_like(emb)
            # Sum embedding dimension to get per-token attribution
            # (only for token features that were embedded before gradient computation)
            attr = grad.detach()
            if key in token_keys and attr.dim() >= 3:
                attr = attr.sum(dim=-1)
            attributions[key] = attr

        return self._map_to_input_shapes(attributions, shapes)

    # ------------------------------------------------------------------
    # Target output computation
    # ------------------------------------------------------------------
    def _compute_target_output(
        self,
        logits: torch.Tensor,
        target_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Compute scalar target output for backpropagation.

        Selects the target-class logit for each sample and sums over the
        batch to produce a single differentiable scalar.

        Args:
            logits: Model output logits, shape [batch, num_classes].
            target_indices: [batch] tensor of target class indices.

        Returns:
            Scalar tensor for backpropagation.
        """
        return logits.gather(
            1, target_indices.unsqueeze(1)
        ).squeeze(1).sum()

    # ------------------------------------------------------------------
    # Utility helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _map_to_input_shapes(
        attr_values: Dict[str, torch.Tensor],
        input_shapes: dict,
    ) -> Dict[str, torch.Tensor]:
        """Map attributions back to original input tensor shapes.

        For embedding-based attributions, the embedding dimension has
        already been summed out. This method handles any remaining shape
        mismatches (e.g., expanding scalar attributions to match
        multi-dimensional inputs).

        Args:
            attr_values: Dictionary of attribution tensors.
            input_shapes: Dictionary of original input shapes.

        Returns:
            Dictionary of attributions reshaped to match original inputs.
            Keys without a recorded input shape are passed through as-is.
        """
        mapped: dict[str, torch.Tensor] = {}
        for key, values in attr_values.items():
            if key not in input_shapes:
                mapped[key] = values
                continue

            orig_shape = input_shapes[key]
            if values.shape == orig_shape:
                mapped[key] = values
                continue

            # Pad with trailing singleton dims until the rank matches, then
            # broadcast to the original shape.
            reshaped = values
            while len(reshaped.shape) < len(orig_shape):
                reshaped = reshaped.unsqueeze(-1)
            if reshaped.shape != orig_shape:
                reshaped = reshaped.expand(orig_shape)
            mapped[key] = reshaped
        return mapped