pyhealth.interpret.utils#

Visualization utilities for interpretability methods.

This module provides visualization functions for interpretability in PyHealth, particularly useful for medical imaging applications. It includes utilities for:

  • Overlay visualizations: Show attribution/saliency maps on top of images

  • Attribution normalization: Prepare raw attributions for visualization

  • Interpolation: Resize patch-level attributions (e.g., from ViT) to image size

Example Usage#

Basic attribution overlay:

>>> from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution
>>> # Assume we have an image and attribution from an interpreter
>>> attr_normalized = normalize_attribution(attribution)
>>> overlay = show_cam_on_image(image, attr_normalized)

Image attribution visualization:

>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.interpret.utils import visualize_image_attr
>>> interpreter = CheferRelevance(model)
>>> attribution = interpreter.get_vit_attribution_map(**batch)
>>> image, attr_map, overlay = visualize_image_attr(
...     image=batch["image"][0],
...     attribution=attribution[0, 0],
...     interpolate=True,  # Resize attribution to match image
... )

See also

  • pyhealth.interpret.methods - Attribution methods (DeepLift, IntegratedGradients, etc.)

pyhealth.interpret.utils.show_cam_on_image(img, mask, use_rgb=True, colormap=None, image_weight=0.5)

Overlay a Class Activation Map (CAM) or attribution map on an image.

This function creates a visualization by blending an attribution/saliency map with the original image using a colormap (typically ‘jet’ for heatmap visualization).

Parameters:
  • img (ndarray) – Input image as numpy array with shape (H, W, 3) for RGB or (H, W) for grayscale. Values should be in range [0, 1].

  • mask (ndarray) – Attribution/saliency map with shape (H, W). Values should be in range [0, 1] where higher values indicate more importance.

  • use_rgb (bool) – If True, return RGB format. If False, return BGR format. Default is True.

  • colormap (int) – OpenCV colormap constant. If None, uses cv2.COLORMAP_JET. Common options: cv2.COLORMAP_JET, cv2.COLORMAP_HOT, cv2.COLORMAP_VIRIDIS

  • image_weight (float) – Weight of the original image in the blend (0 to 1). Default is 0.5 for equal blend.

Return type:

ndarray

Returns:

Blended visualization as uint8 numpy array with shape (H, W, 3) in range [0, 255].

Raises:

ValueError – If inputs are invalid or cv2 is not available.

Examples

>>> import numpy as np
>>> from pyhealth.interpret.utils import show_cam_on_image
>>>
>>> # Create sample image and attribution
>>> image = np.random.rand(224, 224, 3)  # RGB image
>>> attribution = np.random.rand(224, 224)  # Saliency map
>>>
>>> # Create visualization
>>> overlay = show_cam_on_image(image, attribution)
>>> overlay.shape
(224, 224, 3)
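
The blending step described above can be sketched in plain NumPy. This is an illustration of the idea only, not the library's implementation: `blend_heatmap` is a hypothetical helper, and the two-color ramp stands in for an OpenCV colormap such as cv2.COLORMAP_JET.

```python
import numpy as np


def blend_heatmap(img, mask, image_weight=0.5):
    """Blend a [0, 1] image with a simple heatmap built from a [0, 1] mask.

    Hypothetical helper for illustration; not part of pyhealth.
    """
    if img.ndim == 2:
        # Promote grayscale to three channels so it can be mixed with color.
        img = np.stack([img] * 3, axis=-1)
    if mask.shape != img.shape[:2]:
        raise ValueError("mask must have the same height and width as img")
    if not 0.0 <= image_weight <= 1.0:
        raise ValueError("image_weight must be in [0, 1]")

    # Assumed stand-in for a real colormap: low values blue, high values red.
    heatmap = np.stack([mask, np.zeros_like(mask), 1.0 - mask], axis=-1)

    # Weighted sum of image and heatmap, then conversion to uint8 in [0, 255].
    blended = image_weight * img + (1.0 - image_weight) * heatmap
    return np.uint8(255 * np.clip(blended, 0.0, 1.0))


rng = np.random.default_rng(0)
overlay = blend_heatmap(rng.random((224, 224, 3)), rng.random((224, 224)))
print(overlay.shape, overlay.dtype)
```

With image_weight=0.5 the image and the heatmap contribute equally; values closer to 1 keep more of the original image visible.
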

pyhealth.interpret.utils.interpolate_attribution_map(attribution, target_size, mode='bilinear')

Interpolate attribution map to target size.

This is useful for models where the attribution is computed at a lower resolution (e.g., 14x14 patch grid for ViT-B/16) and needs to be upsampled to the original image resolution (e.g., 224x224).

Parameters:
  • attribution (ndarray) – Attribution map as numpy array or torch tensor. Shape can be (H, W) or (B, H, W) or (1, 1, H, W).

  • target_size (Tuple[int, int]) – Target (height, width) for interpolation.

  • mode (str) – Interpolation mode. Options: “bilinear”, “nearest”. Default is “bilinear” for smooth gradients.

Return type:

ndarray

Returns:

Interpolated attribution map with shape (target_h, target_w).

Examples

>>> import numpy as np
>>> from pyhealth.interpret.utils import interpolate_attribution_map
>>>
>>> # For ViT-B/16 with 14x14 patch grid
>>> attr_patches = np.random.rand(14, 14)
>>> attr_full = interpolate_attribution_map(attr_patches, (224, 224))
>>> attr_full.shape
(224, 224)
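
To make the upsampling concrete, here is a NumPy-only sketch of what nearest-neighbor resizing does to a patch grid. `upsample_nearest` is a hypothetical helper that assumes an integer scale factor; it is not how interpolate_attribution_map is implemented, and it does not cover the "bilinear" mode.

```python
import numpy as np


def upsample_nearest(attr, target_size):
    """Nearest-neighbor upsampling of a 2D map by an integer factor.

    Hypothetical helper for illustration; not part of pyhealth.
    """
    h, w = attr.shape
    target_h, target_w = target_size
    if target_h % h or target_w % w:
        raise ValueError("target size must be an integer multiple of the input size")
    # Repeat rows, then columns, so each patch value fills a block of pixels.
    rows = np.repeat(attr, target_h // h, axis=0)
    return np.repeat(rows, target_w // w, axis=1)


# 14x14 patch grid (as for ViT-B/16 at 224x224): each patch becomes 16x16 pixels.
attr_patches = np.arange(14 * 14, dtype=float).reshape(14, 14)
attr_full = upsample_nearest(attr_patches, (224, 224))
print(attr_full.shape)
```

Nearest-neighbor output keeps hard patch boundaries, which is why a smoother mode such as "bilinear" is usually preferred for display.
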

pyhealth.interpret.utils.normalize_attribution(attribution, method='minmax')

Normalize attribution values for visualization.

Parameters:
  • attribution (Union[ndarray, Tensor]) – Raw attribution values.

  • method (str) – Normalization method. Options:

      - “minmax”: Scale to [0, 1] using min-max normalization

      - “abs_max”: Scale by absolute maximum, keeping sign

      - “percentile”: Clip to the [5, 95] percentile range, then normalize

Return type:

ndarray

Returns:

Normalized attribution as numpy array in [0, 1].
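
The three options can be sketched in NumPy as follows. `normalize_sketch` and its `eps` guard are assumptions made for illustration, not the library's code. As written here, "abs_max" yields values in [-1, 1]; how the library maps signed values into the documented [0, 1] output is not stated in this reference.

```python
import numpy as np


def normalize_sketch(attr, method="minmax", eps=1e-8):
    """Illustrative normalization of raw attribution values.

    Hypothetical helper; `eps` avoids division by zero on constant inputs.
    """
    attr = np.asarray(attr, dtype=float)
    if method == "minmax":
        # Shift so the minimum is 0, then scale so the maximum is about 1.
        lo, hi = attr.min(), attr.max()
        return (attr - lo) / (hi - lo + eps)
    if method == "abs_max":
        # Divide by the largest magnitude; the sign of each value is preserved.
        return attr / (np.abs(attr).max() + eps)
    if method == "percentile":
        # Clip outliers to the 5th and 95th percentiles, then min-max scale.
        lo, hi = np.percentile(attr, [5, 95])
        clipped = np.clip(attr, lo, hi)
        return (clipped - lo) / (hi - lo + eps)
    raise ValueError(f"unknown method: {method!r}")


raw = np.array([-2.0, 0.0, 2.0])
print(normalize_sketch(raw, "minmax"))
print(normalize_sketch(raw, "abs_max"))
```

Percentile clipping is useful when a few extreme attribution values would otherwise compress the rest of the map into a narrow color range.
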

pyhealth.interpret.utils.visualize_image_attr(image, attribution, normalize=True, interpolate=True)

Generate visualization components from an image and attribution map.

This is a convenience function that prepares image and attribution for visualization, handling common format conversions, interpolation, and creating an overlay. Works with any image-based model (CNN, ViT, etc.).

Parameters:
  • image (Union[ndarray, Tensor]) – Input image as numpy array or torch tensor. Accepted shapes: [H, W], [H, W, C], [C, H, W]. Values can be in any range (will be normalized to [0, 1]).

  • attribution (Union[ndarray, Tensor]) – Attribution map as numpy array or torch tensor. Shape should be [H, W]. If different from image size, will be interpolated to match when interpolate=True.

  • normalize (bool) – If True, normalize attribution to [0, 1] range. Default is True.

  • interpolate (bool) – If True, interpolate attribution map to match image dimensions if they differ. Default is True.

Return type:

Tuple of (image, attribution_map, overlay)

Returns:

  • image: Normalized image as numpy array [H, W] or [H, W, C] in [0, 1]

  • attribution_map: Attribution as numpy array [H, W] in [0, 1]

  • overlay: Attribution overlay on image as numpy array [H, W, 3] in [0, 255]

Examples

>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.interpret.utils import visualize_image_attr
>>>
>>> # Compute attribution with interpreter
>>> interpreter = CheferRelevance(model)
>>> attr_map = interpreter.get_vit_attribution_map(**batch)
>>>
>>> # Generate visualization (auto-interpolates to image size)
>>> image, attr_display, overlay = visualize_image_attr(
...     image=batch["image"][0],
...     attribution=attr_map[0, 0],
...     interpolate=True,
... )
>>>
>>> # Display
>>> import matplotlib.pyplot as plt
>>> plt.imshow(overlay)
>>> plt.savefig("attribution.png")
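
The accepted image shapes listed in the parameters imply a layout and range conversion before the overlay is built. A minimal NumPy sketch of that preparation step is shown below. `to_hwc_unit_range` is a hypothetical helper, and the channels-first detection rule is an assumed heuristic; the library's actual handling may differ.

```python
import numpy as np


def to_hwc_unit_range(image, eps=1e-8):
    """Convert an image to [H, W] or [H, W, C] layout with values in [0, 1].

    Hypothetical helper for illustration; handles NumPy input only.
    """
    image = np.asarray(image, dtype=float)
    # Assumed heuristic: a leading axis of size 1 or 3, with a trailing axis
    # that is not 1 or 3, is treated as channels-first [C, H, W].
    if image.ndim == 3 and image.shape[0] in (1, 3) and image.shape[-1] not in (1, 3):
        image = np.transpose(image, (1, 2, 0))
    # Min-max scale so values from any input range land in [0, 1].
    lo, hi = image.min(), image.max()
    return (image - lo) / (hi - lo + eps)


chw = np.arange(3 * 4 * 5, dtype=float).reshape(3, 4, 5)
hwc = to_hwc_unit_range(chw)
print(hwc.shape)
```
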

Overview#

The pyhealth.interpret.utils module provides visualization utilities for interpretability methods in PyHealth. These functions help create visual explanations of model predictions, particularly useful for medical imaging.

Core Functions#

Overlay Visualization

  • show_cam_on_image() - Overlay a CAM/attribution map on an image

  • visualize_attribution_on_image() - Generate complete attribution visualization

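Normalization & Processing

  • normalize_attribution() - Normalize attribution values for visualization

  • interpolate_attribution_map() - Interpolate an attribution map to a target size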

Figure Generation

  • create_attribution_figure() - Create publication-ready figure with overlays

ViT-Specific Functions#

These functions are specifically designed for Vision Transformer (ViT) models using attention-based interpretability methods like CheferRelevance.

  • generate_vit_visualization() - Generate visualization components for ViT attribution

  • create_vit_attribution_figure() - Create complete ViT attribution figure

  • reshape_vit_attribution() - Reshape flat patch attribution to 2D spatial map
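
Reshaping a flat patch attribution into a 2D map can be sketched in NumPy. Since reshape_vit_attribution() is listed here without a signature, `reshape_patches` and its `has_cls_token` flag are hypothetical names, and the square-grid and row-major ordering assumptions are mine rather than the library's.

```python
import math

import numpy as np


def reshape_patches(flat_attr, has_cls_token=False):
    """Reshape a flat vector of per-patch scores into a square 2D grid.

    Hypothetical helper for illustration; not part of pyhealth.
    """
    flat_attr = np.asarray(flat_attr, dtype=float).ravel()
    if has_cls_token:
        # Assumption: the class token comes first and has no spatial position.
        flat_attr = flat_attr[1:]
    side = math.isqrt(flat_attr.size)
    if side * side != flat_attr.size:
        raise ValueError("number of patches is not a perfect square")
    # Row-major reshape: consecutive patches run left to right, then top to bottom.
    return flat_attr.reshape(side, side)


# ViT-B/16 at 224x224 has 14 * 14 = 196 patches, plus one class token.
grid = reshape_patches(np.arange(197), has_cls_token=True)
print(grid.shape)
```

The resulting grid can then be upsampled to the image resolution before it is overlaid.
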

Example: Basic Attribution Visualization#

import numpy as np
from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution

# Assume we have image and attribution from an interpreter
image = np.random.rand(224, 224, 3)  # RGB image in [0, 1]
attribution = np.random.rand(224, 224)  # Raw attribution values

# Normalize and overlay
attr_normalized = normalize_attribution(attribution)
overlay = show_cam_on_image(image, attr_normalized)

Example: ViT Attribution with CheferRelevance#

from pyhealth.models import TorchvisionModel
from pyhealth.interpret.methods import CheferRelevance
from pyhealth.interpret.utils import (
    generate_vit_visualization,
    create_vit_attribution_figure,
)
import matplotlib.pyplot as plt

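# Initialize ViT model and interpreter
# (assumes `dataset` is an existing PyHealth image dataset)
model = TorchvisionModel(dataset, "vit_b_16", {"weights": "DEFAULT"})
# ... train model ...

interpreter = CheferRelevance(model)

# Generate visualization components
# (assumes `test_batch` is one batch drawn from a test dataloader)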
image, attr_map, overlay = generate_vit_visualization(
    interpreter=interpreter,
    **test_batch
)

# Or create a complete figure
fig = create_vit_attribution_figure(
    interpreter=interpreter,
    class_names={0: "Normal", 1: "COVID", 2: "Pneumonia"},
    save_path="vit_attribution.png",
    **test_batch
)

See Also#

  • pyhealth.interpret.methods - Attribution methods (DeepLift, IntegratedGradients, CheferRelevance, etc.)

  • pyhealth.interpret.methods.CheferRelevance - Attention-based interpretability for Transformers

  • pyhealth.models.TorchvisionModel - ViT and other vision models