Source code for pyhealth.interpret.methods.deeplift

from __future__ import annotations

import contextlib
from typing import Dict, List, Optional, Tuple, Type, cast

import torch
import torch.nn.functional as F

from pyhealth.models import BaseModel
from pyhealth.interpret.api import Interpretable
from .base_interpreter import BaseInterpreter


def _iter_child_modules(module: torch.nn.Module):
    for name, child in module.named_children():
        yield module, name, child
        yield from _iter_child_modules(child)


class _HookedModule(torch.nn.Module):
    """Wrap an activation module to route through DeepLIFT hooks."""

    def __init__(self, hook_name: str, hooks: "_DeepLiftActivationHooks", forward_kwargs: Optional[Dict] = None):
        super().__init__()
        self.hook_name = hook_name
        self.hooks = hooks
        self.forward_kwargs = forward_kwargs or {}

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return self.hooks.apply(self.hook_name, tensor, **self.forward_kwargs)


class _ActivationSwapContext(contextlib.AbstractContextManager):
    """Temporarily replace activation modules with DeepLIFT-aware wrappers."""

    _TARGETS: Dict[Type[torch.nn.Module], Tuple[str, Dict]] = {
        torch.nn.ReLU: ("relu", {}),
        torch.nn.Sigmoid: ("sigmoid", {}),
        torch.nn.Tanh: ("tanh", {}),
    }

    def __init__(self, model: BaseModel):
        self.model = model
        self.hooks = _DeepLiftActivationHooks()
        self._swapped: List[Tuple[torch.nn.Module, str, torch.nn.Module]] = []

    def __enter__(self) -> "_ActivationSwapContext":
        for parent, name, child in _iter_child_modules(self.model):
            for target_cls, (hook_name, fkwargs) in self._TARGETS.items():
                if isinstance(child, target_cls):
                    wrapper = _HookedModule(hook_name, self.hooks, fkwargs)
                    setattr(parent, name, wrapper)
                    self._swapped.append((parent, name, child))
                    break
        return self

    def __exit__(self, exc_type, exc, exc_tb) -> bool:
        for parent, name, original in reversed(self._swapped):
            setattr(parent, name, original)
        self._swapped.clear()
        self.hooks.reset()
        return False


class _DeepLiftActivationHooks:
    """Capture activation pairs for baseline and actual forward passes.

    During the baseline forward pass (reference inputs) the hook stores the
    pre-activation and post-activation tensors. During the actual forward
    pass it registers a backward hook that replaces the local derivative with
    the Rescale multiplier ``delta_out / delta_in`` as prescribed by the
    original DeepLIFT paper (Algorithm 1, lines 8–11).

    Only elementwise activations are currently supported because their
    secant slope can be derived analytically. For other operations the code
    falls back to autograd gradients, which coincides with the "linear rule"
    in the paper.
    """

    _SUPPORTED = {"relu", "sigmoid", "tanh"}

    def __init__(self, eps: float = 1e-7):
        self.eps = eps
        self.mode: str = "inactive"
        self.records: Dict[str, list] = {name: [] for name in self._SUPPORTED}
        self._indices: Dict[str, int] = {name: 0 for name in self._SUPPORTED}

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------
    def reset(self) -> None:
        """Clear cached activation pairs and return to the inactive state."""
        for name in self.records:
            self.records[name].clear()
            self._indices[name] = 0
        self.mode = "inactive"

    def start_baseline(self) -> None:
        """Begin recording activations for the reference forward pass."""
        self.reset()
        self.mode = "baseline"

    def start_actual(self) -> None:
        """Switch to the actual input forward pass and prepare replaying hooks."""
        if self.mode != "baseline":
            raise RuntimeError("Baseline forward pass must run before actual pass for DeepLIFT.")
        self.mode = "actual"
        for name in self._indices:
            self._indices[name] = 0

    # ------------------------------------------------------------------
    # Activation routing
    # ------------------------------------------------------------------
    def apply(self, name: str, tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run the activation and optionally register a Rescale hook.

        Args:
            name: The ``torch`` activation name (e.g., ``"relu"``).
            tensor: Pre-activation tensor.
            **kwargs: Keyword arguments forwarded to the activation.

        Returns:
            The activation output. If ``mode`` is ``"baseline"`` the output is
            detached and cached; if ``mode`` is ``"actual"`` a backward hook is
            registered so that gradients are rescaled to secant multipliers.
        """
        fn = getattr(torch, name)
        output = fn(tensor, **kwargs)

        if name not in self.records or self.mode == "inactive":
            return output

        if self.mode == "baseline":
            self.records[name].append(
                {
                    "baseline_input": tensor.detach(),
                    "baseline_output": output.detach(),
                }
            )
        elif self.mode == "actual":
            idx = self._indices[name]
            if idx >= len(self.records[name]):
                raise RuntimeError(
                    f"DeepLIFT activation mismatch for '{name}'. Baseline and actual passes "
                    "must trigger hooks in the same order."
                )

            record = self.records[name][idx]
            record["input"] = tensor
            record["output"] = output
            self._indices[name] += 1

            if output.requires_grad:
                self._register_hook(name, record)

        return output

    # ------------------------------------------------------------------
    # Gradient replacement helpers
    # ------------------------------------------------------------------
    def _register_hook(self, name: str, record: Dict[str, torch.Tensor]) -> None:
        """Attach a backward hook implementing the Rescale multiplier.

        The multiplier ``m`` from the paper is computed as the ratio between
        the output and input differences. We apply this multiplier by scaling
        the autograd derivative so that the product equals ``m``.
        """
        input_tensor = record["input"]
        baseline_input = record["baseline_input"].to(input_tensor.device)
        output_tensor = record["output"]
        baseline_output = record["baseline_output"].to(output_tensor.device)

        delta_in = input_tensor - baseline_input
        delta_out = output_tensor - baseline_output

        derivative = self._activation_derivative(name, input_tensor, output_tensor)
        secant = self._safe_div(delta_out, delta_in, derivative)
        scale = self._safe_div(secant, derivative, torch.ones_like(secant))

        # Clamp to finite values to avoid propagating NaNs/Infs downstream
        scale = torch.where(torch.isfinite(scale), scale, torch.ones_like(scale))
        scale = scale.detach()

        def hook_fn(grad: torch.Tensor) -> torch.Tensor:
            """Scale the upstream gradient to equal the Rescale multiplier."""
            return grad * scale

        output_tensor.register_hook(hook_fn)

    def _activation_derivative(
        self, name: str, input_tensor: torch.Tensor, output_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return the analytical derivative of supported activations."""
        if name == "relu":
            return torch.where(input_tensor > 0, torch.ones_like(output_tensor), torch.zeros_like(output_tensor))
        if name == "sigmoid":
            return output_tensor * (1.0 - output_tensor)
        if name == "tanh":
            return 1.0 - output_tensor.pow(2)
        # Default derivative for unsupported activations
        return torch.ones_like(output_tensor)

    def _safe_div(
        self,
        numerator: torch.Tensor,
        denominator: torch.Tensor,
        fallback: torch.Tensor,
    ) -> torch.Tensor:
        mask = denominator.abs() > self.eps
        safe_denominator = torch.where(mask, denominator, torch.ones_like(denominator))
        quotient = numerator / safe_denominator
        return torch.where(mask, quotient, fallback)


class _DeepLiftHookContext(contextlib.AbstractContextManager):
    """Context manager that swaps activations for DeepLIFT without model hooks."""

    def __init__(self, model: BaseModel):
        self.model = model
        self._swap_ctx = _ActivationSwapContext(model)

    def __enter__(self) -> "_DeepLiftHookContext":
        self._swap_ctx.__enter__()
        return self

    def start_baseline(self) -> None:
        self._swap_ctx.hooks.start_baseline()

    def start_actual(self) -> None:
        self._swap_ctx.hooks.start_actual()

    def __exit__(self, exc_type, exc, exc_tb) -> bool:
        self._swap_ctx.__exit__(exc_type, exc, exc_tb)
        return False


[docs]class DeepLift(BaseInterpreter): """DeepLIFT attribution for PyHealth models. Paper: Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning Important Features through Propagating Activation Differences. ICML 2017. DeepLIFT propagates difference-from-baseline activations using Rescale multipliers so that feature attributions sum to the change in model output. The implementation injects secant slopes for supported activations (ReLU, Sigmoid, Tanh) via module swapping to mirror the original algorithm while falling back to autograd gradients for unsupported operations. This method is particularly useful for: - EHR feature importance: highlight influential visits, codes, or labs when auditing StageNet-style models. - Contrastive explanations: compare predictions against a clinically meaningful baseline patient trajectory. - Mixed-input attribution: handle discrete embedding channels and continuous features in a unified call. - Model debugging: diagnose activation saturation and verify the completeness axiom. Key Features: - Dual operating modes for embedding-based or continuous inputs. - Automatic activation module swapping for DeepLIFT Rescale rule. - Completeness enforcement ensuring ``sum(attribution) ~= f(x) - f(x0)``. - Batch-friendly API accepting trainer-style dictionaries with tuple-based inputs following processor schemas. - Target control via ``target_class_idx`` to explain any desired logit. - Mixed token/continuous feature support using ``is_token()`` processor introspection. Usage Notes: 1. Choose a baseline dictionary that reflects a neutral clinical state when zeros are not meaningful. 2. Move inputs, baselines, and the model to the same device before calling ``attribute``. 3. Keep ``use_embeddings=True`` for token indices; set it to ``False`` to attribute continuous tensors directly. 4. Call ``model.eval()`` so stochastic layers remain deterministic during paired forward passes. Args: model: A :class:`~pyhealth.models.BaseModel` instance exposing either :meth:`forward_from_embedding` (for discrete inputs) or the standard :meth:`forward` used by PyHealth trainers. use_embeddings: Whether to operate in embedding space. Set to ``True`` (default) for tokenized inputs or ``False`` to attribute continuous tensors directly. Examples: >>> import torch >>> from pyhealth.datasets import create_sample_dataset, get_dataloader >>> from pyhealth.interpret.methods.deeplift import DeepLift >>> from pyhealth.models import MLP >>> >>> samples = [ ... {"patient_id": "p0", "visit_id": "v0", ... "conditions": ["cond-33", "cond-86", "cond-80"], ... "procedures": [1.0, 2.0, 3.5, 4.0], "label": 1}, ... {"patient_id": "p1", "visit_id": "v1", ... "conditions": ["cond-55", "cond-12"], ... "procedures": [5.0, 2.0, 3.5, 4.0], "label": 0}, ... ] >>> dataset = create_sample_dataset( ... samples=samples, ... input_schema={"conditions": "sequence", "procedures": "tensor"}, ... output_schema={"label": "binary"}, ... ) >>> model = MLP(dataset=dataset, embedding_dim=32, hidden_dim=32) >>> model.eval() >>> test_loader = get_dataloader(dataset, batch_size=1, shuffle=False) >>> deeplift = DeepLift(model, use_embeddings=True) >>> >>> batch = next(iter(test_loader)) >>> attributions = deeplift.attribute(**batch) >>> print({k: v.shape for k, v in attributions.items()}) Algorithm Details: 1. Run a baseline forward pass while caching activations for supported nonlinearities. 2. Replay the actual inputs with Rescale hooks that substitute secant slopes for local derivatives. 3. Backpropagate the target logit so gradients equal DeepLIFT multipliers. 4. Multiply input differences by the propagated multipliers and enforce completeness. References: [1] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. Learning Important Features through Propagating Activation Differences. Proceedings of the 34th International Conference on Machine Learning (ICML), 2017. https://proceedings.mlr.press/v70/shrikumar17a.html """ def __init__(self, model: BaseModel, use_embeddings: bool = True): super().__init__(model) if not isinstance(model, Interpretable): raise ValueError("Model must implement Interpretable interface") self.model = model self.use_embeddings = use_embeddings # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
[docs] def attribute( self, baseline: Optional[Dict[str, torch.Tensor]] = None, target_class_idx: Optional[int] = None, **kwargs: torch.Tensor | tuple[torch.Tensor, ...], ) -> Dict[str, torch.Tensor]: """Compute DeepLIFT attributions for a single batch. The method follows Algorithm 2 of the DeepLIFT paper: two forward passes (baseline then actual) are executed under the hook context so that backward propagation yields multipliers equal to the Rescale rule. Args: baseline: Optional dictionary providing reference inputs per feature key. If omitted, UNK tokens are used for discrete features and near-zero values for continuous features. target_class_idx: Optional class index to explain. ``None`` defaults to the model prediction. **kwargs: Input data dictionary from a dataloader batch containing: - Feature keys (e.g., 'conditions', 'procedures'): Input tensors or tuples of tensors for each modality - 'label' (optional): Ground truth label tensor - Other metadata keys are ignored Returns: ``Dict[str, torch.Tensor]`` mapping each feature key to attribution tensors shaped like the original inputs. All outputs satisfy the completeness property ``sum_i attribution_i ≈ f(x) - f(x₀)``. """ device = next(self.model.parameters()).device # Filter kwargs to only include model feature keys and ensure they are tuples inputs = { k: (v,) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items() if k in self.model.feature_keys } # Disassemble inputs to get values and masks values: dict[str, torch.Tensor] = {} masks: dict[str, torch.Tensor] = {} for k, v in inputs.items(): schema = self.model.dataset.input_processors[k].schema() values[k] = v[schema.index("value")] if "mask" in schema: masks[k] = v[schema.index("mask")] else: val = v[schema.index("value")] processor = self.model.dataset.input_processors[k] if processor.is_token(): masks[k] = (val != 0).int() else: # For continuous features, check whether the entire # feature vector at each timestep is zero (padding) # rather than per-element, so valid 0.0 values are # not masked out. if val.dim() >= 3: masks[k] = (val.abs().sum(dim=-1) != 0).int() else: masks[k] = (val != 0).int() # Append input masks to inputs for models that expect them for k, v in inputs.items(): if "mask" not in self.model.dataset.input_processors[k].schema(): inputs[k] = (*v, masks[k]) # Determine target class from original input with torch.no_grad(): base_logits = self.model.forward(**inputs)["logit"] target_indices = self._resolve_target_indices(base_logits, target_class_idx) # Generate baselines if baseline is None: baselines = self._generate_baseline( values, use_embeddings=self.use_embeddings ) else: baselines = { k: v.to(device) for k, v in baseline.items() if k in self.model.feature_keys } # Save raw shapes before embedding for later mapping shapes = {k: v.shape for k, v in values.items()} # Split features by type using is_token(): # - Token features (discrete): embed before DeepLIFT, since # working with raw indices is meaningless. Gradients are computed # w.r.t. embeddings, then summed over the embedding dim. # - Continuous features: keep raw so each raw dimension gets its # own attribution. The model's forward() handles embedding internally. token_keys: set[str] = set() if self.use_embeddings: embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based " "DeepLIFT." ) token_keys = { k for k in values if self.model.dataset.input_processors[k].is_token() } if token_keys: # Embed token values token_values = {k: values[k] for k in token_keys} embedded_tokens = embedding_model(token_values) for k in token_keys: values[k] = embedded_tokens[k] # Embed token baselines so they live in the same space token_baselines = {k: baselines[k] for k in token_keys if k in baselines} if token_baselines: embedded_baselines = embedding_model(token_baselines) for k in token_baselines: baselines[k] = embedded_baselines[k] # Compute DeepLIFT attributions attributions = self._deeplift( inputs=inputs, xs=values, bs=baselines, target_indices=target_indices, token_keys=token_keys, ) return self._map_to_input_shapes(attributions, shapes)
# ------------------------------------------------------------------ # Core DeepLIFT computation # ------------------------------------------------------------------ def _deeplift( self, inputs: Dict[str, tuple[torch.Tensor, ...]], xs: Dict[str, torch.Tensor], bs: Dict[str, torch.Tensor], target_indices: torch.Tensor, token_keys: set[str], ) -> Dict[str, torch.Tensor]: """Core DeepLIFT computation using the Rescale rule. Performs two forward passes (baseline and actual) under the activation swap context, then backpropagates to obtain DeepLIFT multipliers. Args: inputs: Full input tuples keyed by feature name. xs: Input values (embedded if token features with use_embeddings). bs: Baseline values (embedded if token features with use_embeddings). target_indices: [batch] tensor of target class indices. token_keys: Set of feature keys that are token (already embedded). Returns: Dictionary mapping feature keys to attribution tensors. """ keys = sorted(xs.keys()) # Create delta tensors with gradients enabled delta: dict[str, torch.Tensor] = {} current: dict[str, torch.Tensor] = {} for key in keys: d = (xs[key] - bs[key]).detach() d.requires_grad_(True) d.retain_grad() delta[key] = d current[key] = bs[key].detach() + d # Build forward inputs with current (baseline + delta) values # inserted into the proper position in the input tuples def _build_forward_inputs(value_dict: dict[str, torch.Tensor]) -> dict: fwd_inputs = inputs.copy() for k in fwd_inputs.keys(): schema = self.model.dataset.input_processors[k].schema() val_idx = schema.index("value") fwd_inputs[k] = ( *fwd_inputs[k][:val_idx], value_dict[k], *fwd_inputs[k][val_idx + 1:], ) return fwd_inputs # For continuous features, embed them before forward_from_embedding def _maybe_embed_continuous(value_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: if not self.use_embeddings: return value_dict continuous_keys = {k for k in value_dict if k not in token_keys} if not continuous_keys: return value_dict embedding_model = self.model.get_embedding_model() assert embedding_model is not None continuous_to_embed = {k: value_dict[k] for k in continuous_keys} embedded_continuous = embedding_model(continuous_to_embed) return {**value_dict, **embedded_continuous} # Two forward passes under the hook context with _DeepLiftHookContext(self.model) as hook_ctx: # --- Baseline forward pass --- hook_ctx.start_baseline() baseline_values = _maybe_embed_continuous( {k: bs[k].detach() for k in keys} ) baseline_fwd = _build_forward_inputs(baseline_values) with torch.no_grad(): if self.use_embeddings: baseline_output = self.model.forward_from_embedding(**baseline_fwd) else: baseline_output = self.model.forward(**baseline_fwd) # --- Actual forward pass --- hook_ctx.start_actual() current_values = _maybe_embed_continuous(current) current_fwd = _build_forward_inputs(current_values) if self.use_embeddings: current_output = self.model.forward_from_embedding(**current_fwd) else: current_output = self.model.forward(**current_fwd) logits = current_output["logit"] # type: ignore[index] baseline_logits = baseline_output["logit"] # type: ignore[index] # Compute per-sample target outputs target_output = self._compute_target_output(logits, target_indices) baseline_target_output = self._compute_target_output( baseline_logits, target_indices ) self.model.zero_grad(set_to_none=True) target_output.sum().backward() # Collect attributions: grad * delta attributions: dict[str, torch.Tensor] = {} for key in keys: grad = delta[key].grad if grad is None: attributions[key] = torch.zeros_like(delta[key]) else: attr = grad.detach() * delta[key].detach() # For token features, sum over the embedding dimension if self.use_embeddings and key in token_keys and attr.dim() >= 3: attr = attr.sum(dim=-1) attributions[key] = attr # Enforce completeness: sum of attributions == f(x) - f(x0) attributions = self._enforce_completeness( attributions, target_output.detach(), baseline_target_output.detach(), ) return attributions # ------------------------------------------------------------------ # Target output computation # ------------------------------------------------------------------ def _compute_target_output( self, logits: torch.Tensor, target_indices: torch.Tensor, ) -> torch.Tensor: """Compute per-sample target output. Selects the target-class logit for each sample. Args: logits: Model output logits, shape [batch, num_classes]. target_indices: [batch] tensor of target class indices. Returns: Per-sample target output tensor, shape [batch]. """ return logits.gather( 1, target_indices.unsqueeze(1) ).squeeze(1) # ------------------------------------------------------------------ # Completeness enforcement # ------------------------------------------------------------------ @staticmethod def _enforce_completeness( contributions: Dict[str, torch.Tensor], target_output: torch.Tensor, baseline_output: torch.Tensor, eps: float = 1e-8, ) -> Dict[str, torch.Tensor]: """Scale attributions so their sum matches ``f(x) - f(x₀)`` (Eq. 1).""" delta_output = (target_output - baseline_output) total = None for contrib in contributions.values(): flat_sum = contrib.reshape(contrib.size(0), -1).sum(dim=1) total = flat_sum if total is None else total + flat_sum scale = torch.ones_like(delta_output) if total is not None: denom = total mask = denom.abs() > eps scale[mask] = delta_output[mask] / denom[mask] for key, contrib in contributions.items(): reshape_dims = [contrib.size(0)] + [1] * (contrib.dim() - 1) contributions[key] = contrib * scale.view(*reshape_dims) return contributions # ------------------------------------------------------------------ # Baseline generation # ------------------------------------------------------------------ def _generate_baseline( self, values: Dict[str, torch.Tensor], use_embeddings: bool = False, ) -> Dict[str, torch.Tensor]: """Generate raw baselines for DeepLIFT computation. Creates reference samples representing the "absence" of features. The strategy depends on the feature type: - Discrete (token) features: UNK token index (will be embedded later in ``attribute()`` alongside the values) - Continuous features: small near-zero neutral values Args: values: Dictionary of raw input value tensors (before embedding). use_embeddings: If True, generate baselines suitable for embedding-based DeepLIFT. Returns: Dictionary mapping feature names to baseline tensors in raw (pre-embedding) space. Embedding of token baselines is handled by the caller (``attribute()``). """ baselines: dict[str, torch.Tensor] = {} for k, v in values.items(): processor = self.model.dataset.input_processors[k] if use_embeddings and processor.is_token(): # Token features: UNK token index as baseline baseline = torch.ones_like(v) else: # Continuous features (or non-embedding mode): near-zero baseline baseline = torch.zeros_like(v) + 1e-2 baselines[k] = baseline return baselines # ------------------------------------------------------------------ # Utility helpers # ------------------------------------------------------------------ @staticmethod def _map_to_input_shapes( attr_values: Dict[str, torch.Tensor], input_shapes: dict, ) -> Dict[str, torch.Tensor]: """Map attributions back to original input tensor shapes. For embedding-based attributions, the embedding dimension has already been summed out. This method handles any remaining shape mismatches (e.g., expanding scalar attributions to match multi-dimensional inputs). Args: attr_values: Dictionary of attribution tensors. input_shapes: Dictionary of original input shapes. Returns: Dictionary of attributions reshaped to match original inputs. """ mapped: dict[str, torch.Tensor] = {} for key, values in attr_values.items(): if key not in input_shapes: mapped[key] = values continue orig_shape = input_shapes[key] if values.shape == orig_shape: mapped[key] = values continue reshaped = values while len(reshaped.shape) < len(orig_shape): reshaped = reshaped.unsqueeze(-1) if reshaped.shape != orig_shape: reshaped = reshaped.expand(orig_shape) mapped[key] = reshaped return mapped