pyhealth.interpret.utils
Visualization utilities for interpretability methods.
This module provides visualization functions for interpretability in PyHealth, particularly useful for medical imaging applications. It includes utilities for:
- Overlay visualizations: show attribution/saliency maps on top of images
- Attribution normalization: prepare raw attributions for visualization
- Interpolation: resize patch-level attributions (e.g., from ViT) to the image size
Example Usage
Basic attribution overlay:
>>> from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution
>>> # Assume we have an image and attribution from an interpreter
>>> attr_normalized = normalize_attribution(attribution)
>>> overlay = show_cam_on_image(image, attr_normalized)
Image attribution visualization:
>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.interpret.utils import visualize_image_attr
>>> interpreter = CheferRelevance(model)
>>> attribution = interpreter.get_vit_attribution_map(**batch)
>>> image, attr_map, overlay = visualize_image_attr(
...     image=batch["image"][0],
...     attribution=attribution[0, 0],
...     interpolate=True,  # Resize attribution to match image
... )
See also
pyhealth.interpret.methods – Attribution methods (DeepLift, IntegratedGradients, etc.)
- pyhealth.interpret.utils.show_cam_on_image(img, mask, use_rgb=True, colormap=None, image_weight=0.5)
Overlay a Class Activation Map (CAM) or attribution map on an image.
This function renders the attribution/saliency map as a heatmap with a colormap (by default 'jet') and blends that heatmap with the original image.
- Parameters:
img (ndarray) – Input image as a numpy array with shape (H, W, 3) for RGB or (H, W) for grayscale. Values should be in the range [0, 1].
mask (ndarray) – Attribution/saliency map with shape (H, W). Values should be in the range [0, 1], where higher values indicate greater importance.
use_rgb (bool) – If True, return RGB format. If False, return BGR format. Default is True.
colormap (int) – OpenCV colormap constant. If None, uses cv2.COLORMAP_JET. Common options: cv2.COLORMAP_JET, cv2.COLORMAP_HOT, cv2.COLORMAP_VIRIDIS.
image_weight (float) – Weight of the original image in the blend (0 to 1). Default is 0.5 for an equal blend.
- Return type:
ndarray
- Returns:
Blended visualization as a uint8 numpy array with shape (H, W, 3) in the range [0, 255].
- Raises:
ValueError – If inputs are invalid or cv2 is not available.
Examples
>>> import numpy as np
>>> from pyhealth.interpret.utils import show_cam_on_image
>>>
>>> # Create sample image and attribution
>>> image = np.random.rand(224, 224, 3)  # RGB image
>>> attribution = np.random.rand(224, 224)  # Saliency map
>>>
>>> # Create visualization
>>> overlay = show_cam_on_image(image, attribution)
>>> overlay.shape
(224, 224, 3)
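The blending step can be sketched in plain numpy. This sketch is an illustration and is not PyHealth's implementation. It makes two assumptions. First, the overlay is taken to be the convex combination (1 - image_weight) * heatmap + image_weight * img, rescaled by its maximum; this is the convention used by the show_cam_on_image function in the pytorch-grad-cam package. Second, a hand-written jet-like ramp stands in for cv2.COLORMAP_JET so that OpenCV is not required. The names jet_like_colormap and blend_cam are hypothetical.

```python
import numpy as np


def jet_like_colormap(mask: np.ndarray) -> np.ndarray:
    """Map values in [0, 1] to RGB in [0, 1] with a piecewise-linear,
    jet-like ramp (dark blue -> cyan -> yellow -> dark red).
    Approximates, but does not reproduce, cv2.COLORMAP_JET."""
    m = np.clip(mask, 0.0, 1.0)
    r = np.clip(1.5 - np.abs(4.0 * m - 3.0), 0.0, 1.0)
    g = np.clip(1.5 - np.abs(4.0 * m - 2.0), 0.0, 1.0)
    b = np.clip(1.5 - np.abs(4.0 * m - 1.0), 0.0, 1.0)
    return np.stack([r, g, b], axis=-1)


def blend_cam(img: np.ndarray, mask: np.ndarray, image_weight: float = 0.5) -> np.ndarray:
    """Hypothetical stand-in for the blend step; returns uint8 RGB (H, W, 3)."""
    if not 0.0 <= image_weight <= 1.0:
        raise ValueError("image_weight must be in [0, 1]")
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=-1)  # grayscale -> 3 identical channels
    if img.shape[:2] != mask.shape:
        raise ValueError("mask must have the same (H, W) as img")
    heatmap = jet_like_colormap(mask)
    # Convex combination: image_weight of the image, the remainder from the heatmap.
    cam = (1.0 - image_weight) * heatmap + image_weight * img
    cam = cam / max(float(cam.max()), 1e-8)  # rescale so the brightest value is 1.0
    return (255 * cam).astype(np.uint8)


image = np.random.rand(224, 224, 3)     # RGB image in [0, 1]
attribution = np.random.rand(224, 224)  # saliency map in [0, 1]
overlay = blend_cam(image, attribution)
print(overlay.shape, overlay.dtype)     # (224, 224, 3) uint8
```

The formula shows the trade-off directly. Raising image_weight toward 1 keeps more of the original image visible, and lowering it toward 0 lets the heatmap dominate.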
- pyhealth.interpret.utils.interpolate_attribution_map(attribution, target_size, mode='bilinear')
Interpolate an attribution map to a target size.
This is useful for models whose attribution is computed at a lower resolution (e.g., the 14x14 patch grid of ViT-B/16) and must be upsampled to the original image resolution (e.g., 224x224).
- Parameters:
attribution (ndarray) – Attribution map as a numpy array or torch tensor. Shape can be (H, W), (B, H, W), or (1, 1, H, W).
target_size (Tuple[int, int]) – Target (height, width) for interpolation.
mode (str) – Interpolation mode. Options: “bilinear”, “nearest”. Default is “bilinear” for smooth gradients.
- Return type:
ndarray
- Returns:
Interpolated attribution map with shape (target_h, target_w).
Examples
>>> import numpy as np
>>> from pyhealth.interpret.utils import interpolate_attribution_map
>>>
>>> # For ViT-B/16 with a 14x14 patch grid
>>> attr_patches = np.random.rand(14, 14)
>>> attr_full = interpolate_attribution_map(attr_patches, (224, 224))
>>> attr_full.shape
(224, 224)
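To make the resizing concrete, here is a numpy-only sketch of the two documented modes for a 2-D map. It is an illustration under stated assumptions and is not PyHealth's code, and both helper names are hypothetical. The bilinear variant uses half-pixel centre alignment, which is comparable to PyTorch's align_corners=False. That choice may differ from what interpolate_attribution_map uses internally. For ViT-B/16 at 224x224 there are 224 / 16 = 14 patches per side, so nearest mode simply expands each patch score into a 16x16 block.

```python
from typing import Tuple

import numpy as np


def upsample_bilinear(attr: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """Separable bilinear resize of a 2-D map. Output pixel centres are mapped
    back with half-pixel offsets (comparable to align_corners=False in PyTorch);
    coordinates past the border are clamped to edge values by np.interp."""
    in_h, in_w = attr.shape
    out_h, out_w = target_size
    ys = (np.arange(out_h) + 0.5) * in_h / out_h - 0.5
    xs = (np.arange(out_w) + 0.5) * in_w / out_w - 0.5
    # Pass 1: interpolate every input row to the new width -> (in_h, out_w).
    rows = np.stack([np.interp(xs, np.arange(in_w), attr[i]) for i in range(in_h)])
    # Pass 2: interpolate every column to the new height -> (out_h, out_w).
    cols = [np.interp(ys, np.arange(in_h), rows[:, j]) for j in range(out_w)]
    return np.stack(cols, axis=1)


def upsample_nearest(attr: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """Nearest-neighbour resize: each output pixel copies the input cell it falls in."""
    in_h, in_w = attr.shape
    out_h, out_w = target_size
    rows = (np.arange(out_h) * in_h) // out_h
    cols = (np.arange(out_w) * in_w) // out_w
    return attr[np.ix_(rows, cols)]


attr_patches = np.random.rand(14, 14)  # e.g. a ViT-B/16 patch grid
smooth = upsample_bilinear(attr_patches, (224, 224))
blocky = upsample_nearest(attr_patches, (224, 224))
print(smooth.shape, blocky.shape)  # (224, 224) (224, 224)
```

Bilinear output is a weighted average of neighbouring patch scores, so it never leaves the input's value range. This is why it gives the smooth gradients mentioned in the mode parameter description.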
- pyhealth.interpret.utils.normalize_attribution(attribution, method='minmax')
Normalize attribution values for visualization.
- Parameters:
attribution (Union[ndarray, Tensor]) – Raw attribution values.
method (str) – Normalization method. Options:
  - “minmax”: scale to [0, 1] using min-max normalization
  - “abs_max”: scale by the absolute maximum, keeping the sign
  - “percentile”: clip to the [5, 95] percentile range, then normalize
- Return type:
ndarray
- Returns:
Normalized attribution as a numpy array in [0, 1].
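The three normalization options can be sketched as follows. This illustration rests on assumptions and is not PyHealth's implementation. The epsilon guard against constant inputs was added for the sketch. The abs_max branch returns the signed result in [-1, 1], because the documentation does not say whether that result is later mapped into [0, 1]. normalize_sketch is a hypothetical name, and it accepts numpy arrays only; a torch tensor would first need .detach().cpu().numpy().

```python
import numpy as np


def normalize_sketch(attr, method: str = "minmax", eps: float = 1e-8) -> np.ndarray:
    """Hypothetical re-creation of the three documented normalization options."""
    a = np.asarray(attr, dtype=np.float64)
    if method == "minmax":
        lo, hi = a.min(), a.max()
        # Linear map of [min, max] onto [0, 1]; constant input maps to all zeros.
        return (a - lo) / (hi - lo) if hi - lo > eps else np.zeros_like(a)
    if method == "abs_max":
        scale = np.abs(a).max()
        # Divide by the largest magnitude: sign is kept, result lies in [-1, 1].
        return a / scale if scale > eps else np.zeros_like(a)
    if method == "percentile":
        lo, hi = np.percentile(a, [5, 95])
        # Clip outliers to the 5th/95th percentiles, then min-max into [0, 1].
        a = np.clip(a, lo, hi)
        return (a - lo) / (hi - lo) if hi - lo > eps else np.zeros_like(a)
    raise ValueError(f"unknown method: {method!r}")


raw = np.array([-2.0, 0.0, 1.0, 2.0])
print(normalize_sketch(raw))             # values 0.0, 0.5, 0.75, 1.0
print(normalize_sketch(raw, "abs_max"))  # values -1.0, 0.0, 0.5, 1.0
```

The percentile option is the robust choice when a handful of extreme attribution values would otherwise compress everything else toward zero.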
- pyhealth.interpret.utils.visualize_image_attr(image, attribution, normalize=True, interpolate=True)
Generate visualization components from an image and attribution map.
This is a convenience function that prepares image and attribution for visualization, handling common format conversions, interpolation, and creating an overlay. Works with any image-based model (CNN, ViT, etc.).
- Parameters:
image (Union[ndarray, Tensor]) – Input image as a numpy array or torch tensor. Accepted shapes: [H, W], [H, W, C], [C, H, W]. Values can be in any range (they will be normalized to [0, 1]).
attribution (Union[ndarray, Tensor]) – Attribution map as a numpy array or torch tensor, with shape [H, W]. If its size differs from the image's, it is interpolated to match when interpolate=True.
normalize (bool) – If True, normalize the attribution to the [0, 1] range. Default is True.
interpolate (bool) – If True, interpolate the attribution map to the image dimensions when they differ. Default is True.
- Returns:
Tuple of (image, attribution_map, overlay), where
image: normalized image as a numpy array [H, W] or [H, W, C] in [0, 1]
attribution_map: attribution as a numpy array [H, W] in [0, 1]
overlay: attribution overlay on the image as a numpy array [H, W, 3] in [0, 255]
Examples
>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.interpret.utils import visualize_image_attr
>>>
>>> # Compute attribution with interpreter
>>> interpreter = CheferRelevance(model)
>>> attr_map = interpreter.get_vit_attribution_map(**batch)
>>>
>>> # Generate visualization (auto-interpolates to image size)
>>> image, attr_display, overlay = visualize_image_attr(
...     image=batch["image"][0],
...     attribution=attr_map[0, 0],
...     interpolate=True,
... )
>>>
>>> # Display
>>> import matplotlib.pyplot as plt
>>> plt.imshow(overlay)
>>> plt.savefig("attribution.png")
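The shape handling that visualize_image_attr describes for its image argument can be approximated as follows. This is a hypothetical helper (to_display_image) and is not the library's code. It assumes a channels-first layout when the first axis has 1 or 3 entries and the last axis does not. That heuristic is ambiguous for tiny inputs such as a 3x3x3 array. The helper also takes numpy input only.

```python
import numpy as np


def to_display_image(image) -> np.ndarray:
    """Hypothetical helper: bring [H, W], [H, W, C] or [C, H, W] into
    [H, W] or [H, W, C] with values min-max rescaled to [0, 1]."""
    img = np.asarray(image, dtype=np.float64)
    # Heuristic: a leading axis of size 1 or 3 (and a trailing one that is not)
    # is taken to be the channel axis, so move it to the end.
    if img.ndim == 3 and img.shape[0] in (1, 3) and img.shape[-1] not in (1, 3):
        img = np.transpose(img, (1, 2, 0))
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]  # drop a singleton channel -> [H, W]
    lo, hi = img.min(), img.max()
    return (img - lo) / (hi - lo) if hi - lo > 1e-8 else np.zeros_like(img)


chw = np.random.rand(3, 32, 48) * 255  # channels-first image in [0, 255]
hwc = to_display_image(chw)
print(hwc.shape)  # (32, 48, 3)
```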
Overview
The pyhealth.interpret.utils module provides visualization utilities for
interpretability methods in PyHealth. These functions help create visual
explanations of model predictions, particularly useful for medical imaging.
Core Functions
Overlay Visualization
- show_cam_on_image() – Overlay a CAM/attribution map on an image
- visualize_attribution_on_image() – Generate a complete attribution visualization
Normalization & Processing
- normalize_attribution() – Normalize attribution values for visualization
- interpolate_attribution_map() – Resize an attribution map to match the image dimensions
Figure Generation
- create_attribution_figure() – Create a publication-ready figure with overlays
ViT-Specific Functions
These functions are specifically designed for Vision Transformer (ViT) models
using attention-based interpretability methods like CheferRelevance.
- generate_vit_visualization() – Generate visualization components for ViT attribution
- create_vit_attribution_figure() – Create a complete ViT attribution figure
- reshape_vit_attribution() – Reshape a flat patch attribution into a 2D spatial map
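The reshape performed by reshape_vit_attribution() can be pictured with the following sketch. Its signature is not given in this document, so this is a hypothetical stand-in (reshape_patch_scores) with two assumptions: the patch grid is square, and there is optionally a leading [CLS] token. ViT-B/16 on a 224x224 input has (224 / 16)^2 = 196 patch tokens plus one [CLS] token, for 197 tokens in total. The 196 patch scores fold into a 14x14 grid.

```python
import math

import numpy as np


def reshape_patch_scores(flat, has_cls_token: bool = True) -> np.ndarray:
    """Hypothetical helper: turn per-token ViT scores into a square 2-D patch grid."""
    scores = np.asarray(flat)
    if has_cls_token:
        scores = scores[1:]  # drop the [CLS] token's own score
    side = math.isqrt(scores.size)
    if side * side != scores.size:
        raise ValueError(f"{scores.size} patch scores do not form a square grid")
    # Patches are assumed to be in row-major (left-to-right, top-to-bottom) order.
    return scores.reshape(side, side)


tokens = np.random.rand(197)  # 1 [CLS] + 196 patch scores for ViT-B/16 at 224x224
grid = reshape_patch_scores(tokens)
print(grid.shape)  # (14, 14)
```

The resulting 14x14 grid is exactly the kind of low-resolution map that interpolate_attribution_map() is described as upsampling to the image size.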
Example: Basic Attribution Visualization
import numpy as np
from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution
# Assume we have an image and an attribution map from an interpreter
image = np.random.rand(224, 224, 3) # RGB image in [0, 1]
attribution = np.random.rand(224, 224) # Raw attribution values
# Normalize and overlay
attr_normalized = normalize_attribution(attribution)
overlay = show_cam_on_image(image, attr_normalized)
Example: ViT Attribution with CheferRelevance
from pyhealth.models import TorchvisionModel
from pyhealth.interpret.methods import CheferRelevance
from pyhealth.interpret.utils import (
generate_vit_visualization,
create_vit_attribution_figure,
)
import matplotlib.pyplot as plt
# Initialize ViT model and interpreter
model = TorchvisionModel(dataset, "vit_b_16", {"weights": "DEFAULT"})
# ... train model ...
interpreter = CheferRelevance(model)
# Generate visualization components
image, attr_map, overlay = generate_vit_visualization(
interpreter=interpreter,
**test_batch
)
# Or create a complete figure
fig = create_vit_attribution_figure(
interpreter=interpreter,
class_names={0: "Normal", 1: "COVID", 2: "Pneumonia"},
save_path="vit_attribution.png",
**test_batch
)
See Also
- pyhealth.interpret.methods – Attribution methods (DeepLift, IntegratedGradients, CheferRelevance, etc.)
- pyhealth.interpret.methods.CheferRelevance – Attention-based interpretability for Transformers
- pyhealth.models.TorchvisionModel – ViT and other vision models