Source code for pyhealth.interpret.utils

"""Visualization utilities for interpretability methods.

This module provides visualization functions for interpretability in PyHealth,
particularly useful for medical imaging applications. It includes utilities for:

- **Overlay visualizations**: Show attribution/saliency maps on top of images
- **Attribution normalization**: Prepare raw attributions for visualization
- **Interpolation**: Resize patch-level attributions (e.g., from ViT) to image size

Example Usage
-------------

Basic attribution overlay:

>>> from pyhealth.interpret.utils import show_cam_on_image, normalize_attribution
>>> # Assume we have an image and attribution from an interpreter
>>> attr_normalized = normalize_attribution(attribution)
>>> overlay = show_cam_on_image(image, attr_normalized)

Image attribution visualization:

>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.interpret.utils import visualize_image_attr
>>> interpreter = CheferRelevance(model)
>>> attribution = interpreter.get_vit_attribution_map(**batch)
>>> image, attr_map, overlay = visualize_image_attr(
...     image=batch["image"][0],
...     attribution=attribution[0, 0],
...     interpolate=True,  # Resize attribution to match image
... )

See Also
--------
pyhealth.interpret.methods : Attribution methods (DeepLift, IntegratedGradients, etc.)
"""

import numpy as np
from typing import Tuple, Union, TYPE_CHECKING

if TYPE_CHECKING:
    import torch

try:
    import cv2
    HAS_CV2 = True
except ImportError:
    HAS_CV2 = False


[docs]def show_cam_on_image( img: np.ndarray, mask: np.ndarray, use_rgb: bool = True, colormap: int = None, image_weight: float = 0.5, ) -> np.ndarray: """Overlay a Class Activation Map (CAM) or attribution map on an image. This function creates a visualization by blending an attribution/saliency map with the original image using a colormap (typically 'jet' for heatmap visualization). Args: img: Input image as numpy array with shape (H, W, 3) for RGB or (H, W) for grayscale. Values should be in range [0, 1]. mask: Attribution/saliency map with shape (H, W). Values should be in range [0, 1] where higher values indicate more importance. use_rgb: If True, return RGB format. If False, return BGR format. Default is True. colormap: OpenCV colormap constant. If None, uses cv2.COLORMAP_JET. Common options: cv2.COLORMAP_JET, cv2.COLORMAP_HOT, cv2.COLORMAP_VIRIDIS image_weight: Weight of the original image in the blend (0 to 1). Default is 0.5 for equal blend. Returns: Blended visualization as uint8 numpy array with shape (H, W, 3) in range [0, 255]. Raises: ValueError: If inputs are invalid or cv2 is not available. Examples: >>> import numpy as np >>> from pyhealth.interpret.utils import show_cam_on_image >>> >>> # Create sample image and attribution >>> image = np.random.rand(224, 224, 3) # RGB image >>> attribution = np.random.rand(224, 224) # Saliency map >>> >>> # Create visualization >>> overlay = show_cam_on_image(image, attribution) >>> overlay.shape (224, 224, 3) """ if not HAS_CV2: # Fallback implementation without cv2 return _show_cam_fallback(img, mask, image_weight) if colormap is None: colormap = cv2.COLORMAP_JET # Ensure image is RGB format with 3 channels if img.ndim == 2: img = np.stack([img] * 3, axis=-1) elif img.shape[-1] == 1: img = np.concatenate([img] * 3, axis=-1) # Validate inputs if img.max() > 1.0 + 1e-6: raise ValueError( f"Image values should be in [0, 1], got max={img.max():.4f}. " "Normalize with: img = (img - img.min()) / (img.max() - img.min())" ) # Normalize mask to [0, 1] mask = mask.astype(np.float32) if mask.max() > mask.min(): mask = (mask - mask.min()) / (mask.max() - mask.min()) # Apply colormap to mask heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) if use_rgb: heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) heatmap = np.float32(heatmap) / 255 # Blend image and heatmap cam = (1 - image_weight) * heatmap + image_weight * img cam = cam / cam.max() # Normalize return np.uint8(255 * cam)
def _show_cam_fallback( img: np.ndarray, mask: np.ndarray, image_weight: float = 0.5, ) -> np.ndarray: """Fallback implementation of show_cam_on_image without OpenCV. Uses matplotlib colormaps instead of cv2.applyColorMap. """ try: from matplotlib import cm except ImportError: raise ImportError( "Either cv2 (opencv-python) or matplotlib is required for " "visualization. Install with: pip install opencv-python matplotlib" ) # Ensure image is RGB format with 3 channels if img.ndim == 2: img = np.stack([img] * 3, axis=-1) elif img.shape[-1] == 1: img = np.concatenate([img] * 3, axis=-1) # Normalize mask to [0, 1] mask = mask.astype(np.float32) if mask.max() > mask.min(): mask = (mask - mask.min()) / (mask.max() - mask.min()) # Apply jet colormap heatmap = cm.jet(mask)[:, :, :3] # Remove alpha channel # Blend image and heatmap cam = (1 - image_weight) * heatmap + image_weight * img cam = cam / cam.max() # Normalize return np.uint8(255 * cam)
[docs]def interpolate_attribution_map( attribution: np.ndarray, target_size: Tuple[int, int], mode: str = "bilinear", ) -> np.ndarray: """Interpolate attribution map to target size. This is useful for models where the attribution is computed at a lower resolution (e.g., 14x14 patch grid for ViT-B/16) and needs to be upsampled to the original image resolution (e.g., 224x224). Args: attribution: Attribution map as numpy array or torch tensor. Shape can be (H, W) or (B, H, W) or (1, 1, H, W). target_size: Target (height, width) for interpolation. mode: Interpolation mode. Options: "bilinear", "nearest". Default is "bilinear" for smooth gradients. Returns: Interpolated attribution map with shape (target_h, target_w). Examples: >>> # For ViT-B/16 with 14x14 patch grid >>> attr_patches = np.random.rand(14, 14) >>> attr_full = interpolate_attribution_map(attr_patches, (224, 224)) >>> attr_full.shape (224, 224) """ import torch import torch.nn.functional as F # Convert to tensor if needed is_numpy = isinstance(attribution, np.ndarray) if is_numpy: attribution = torch.from_numpy(attribution).float() # Ensure 4D tensor: (B, C, H, W) while attribution.dim() < 4: attribution = attribution.unsqueeze(0) # Interpolate interpolated = F.interpolate( attribution, size=target_size, mode=mode, align_corners=False if mode == "bilinear" else None, ) # Remove batch and channel dims, convert back to numpy result = interpolated.squeeze() if is_numpy: result = result.numpy() return result
[docs]def normalize_attribution( attribution: Union[np.ndarray, "torch.Tensor"], method: str = "minmax", ) -> np.ndarray: """Normalize attribution values for visualization. Args: attribution: Raw attribution values. method: Normalization method. Options: - "minmax": Scale to [0, 1] using min-max normalization - "abs_max": Scale by absolute maximum, keeping sign - "percentile": Clip to [5, 95] percentile then normalize Returns: Normalized attribution as numpy array in [0, 1]. """ import torch if isinstance(attribution, torch.Tensor): attribution = attribution.detach().cpu().numpy() attr = attribution.astype(np.float32) if method == "minmax": if attr.max() > attr.min(): return (attr - attr.min()) / (attr.max() - attr.min()) return np.zeros_like(attr) elif method == "abs_max": abs_max = np.abs(attr).max() if abs_max > 0: return (attr / abs_max + 1) / 2 # Map [-1, 1] to [0, 1] return np.zeros_like(attr) + 0.5 elif method == "percentile": p5, p95 = np.percentile(attr, [5, 95]) attr = np.clip(attr, p5, p95) if p95 > p5: return (attr - p5) / (p95 - p5) return np.zeros_like(attr) else: raise ValueError(f"Unknown normalization method: {method}")
[docs]def visualize_image_attr( image: Union[np.ndarray, "torch.Tensor"], attribution: Union[np.ndarray, "torch.Tensor"], normalize: bool = True, interpolate: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Generate visualization components from an image and attribution map. This is a convenience function that prepares image and attribution for visualization, handling common format conversions, interpolation, and creating an overlay. Works with any image-based model (CNN, ViT, etc.). Args: image: Input image as numpy array or torch tensor. Accepted shapes: [H, W], [H, W, C], [C, H, W]. Values can be in any range (will be normalized to [0, 1]). attribution: Attribution map as numpy array or torch tensor. Shape should be [H, W]. If different from image size, will be interpolated to match when interpolate=True. normalize: If True, normalize attribution to [0, 1] range. Default is True. interpolate: If True, interpolate attribution map to match image dimensions if they differ. Default is True. Returns: Tuple of (image, attribution_map, overlay) where: - image: Normalized image as numpy array [H, W] or [H, W, C] in [0, 1] - attribution_map: Attribution as numpy array [H, W] in [0, 1] - overlay: Attribution overlay on image as numpy array [H, W, 3] in [0, 255] Examples: >>> from pyhealth.interpret.methods import CheferRelevance >>> from pyhealth.interpret.utils import visualize_image_attr >>> >>> # Compute attribution with interpreter >>> interpreter = CheferRelevance(model) >>> attr_map = interpreter.get_vit_attribution_map(**batch) >>> >>> # Generate visualization (auto-interpolates to image size) >>> image, attr_display, overlay = visualize_image_attr( ... image=batch["image"][0], ... attribution=attr_map[0, 0], ... interpolate=True, ... ) >>> >>> # Display >>> import matplotlib.pyplot as plt >>> plt.imshow(overlay) >>> plt.savefig("attribution.png") """ import torch # Convert image to numpy if isinstance(image, torch.Tensor): image = image.detach().cpu().numpy() # Handle channel dimension - convert [C, H, W] to [H, W, C] if image.ndim == 3 and image.shape[0] in [1, 3]: image = np.transpose(image, (1, 2, 0)) # Handle single-channel images if image.ndim == 3 and image.shape[-1] == 1: image = image.squeeze(-1) # Normalize image to [0, 1] image = image.astype(np.float32) image = (image - image.min()) / (image.max() - image.min() + 1e-8) # Get image spatial dimensions img_h, img_w = image.shape[:2] # Convert attribution to numpy if isinstance(attribution, torch.Tensor): attribution = attribution.detach().cpu().numpy() # Ensure attribution is 2D attribution = np.squeeze(attribution) # Interpolate attribution to match image size if needed if interpolate and attribution.shape != (img_h, img_w): attribution = interpolate_attribution_map(attribution, (img_h, img_w)) # Normalize attribution if requested if normalize: attribution = normalize_attribution(attribution) # Create overlay if image.ndim == 2: image_rgb = np.stack([image] * 3, axis=-1) else: image_rgb = image overlay = show_cam_on_image(image_rgb, attribution) return image, attribution, overlay