Source code for pyhealth.interpret.methods.basic_gradient

import torch
import numpy as np
from typing import Dict
from pyhealth.interpret.methods.base_interpreter import BaseInterpreter

class BasicGradientSaliencyMaps(BaseInterpreter):
    """Compute gradient-based saliency maps for image classification models.

    This class generates saliency maps by computing gradients of model
    predictions with respect to input pixels, highlighting regions that most
    influenced the prediction. The saliency is computed by taking the maximum
    absolute gradient across color channels.

    The method is particularly useful for:

    - **Clinical interpretability**: Understanding which image regions drove
      a diagnosis
    - **Model debugging**: Verifying the model focuses on clinically relevant
      features
    - **Trust and transparency**: Providing visual explanations for
      predictions
    - **Error analysis**: Comparing saliency maps for correct vs. incorrect
      predictions

    Algorithm:
        1. Forward pass: Compute model predictions for input batch
        2. Target selection: Use predicted class (argmax of probabilities)
        3. Backward pass: Compute gradients with respect to input pixels
        4. Saliency map: Take absolute value and max across color channels

    Mathematical formula::

        saliency(x, y) = max_c |∂score_predicted / ∂pixel_{x,y,c}|

    where ``c`` iterates over color channels (RGB or grayscale).

    Examples:
        Basic usage with a batch::

            from pyhealth.interpret.methods.basic_gradient import (
                BasicGradientSaliencyMaps,
            )
            import matplotlib.pyplot as plt

            # Create batch
            batch = {
                'image': torch.randn(2, 3, 224, 224),
                'disease': torch.tensor([0, 1])
            }

            # Compute saliency maps
            saliency = BasicGradientSaliencyMaps(model, input_batch=batch)

            # Visualize
            saliency.visualize_saliency_map(
                plt,
                image_index=0,
                title="Saliency Map",
                id2label={0: "Normal", 1: "COVID"}
            )

        Using the attribute() interface::

            # Initialize without batch
            saliency = BasicGradientSaliencyMaps(model)

            # Compute attributions for new data
            attributions = saliency.attribute(**batch)
            # Returns: {'image': tensor with saliency maps}

            # Save to batch history
            attributions = saliency.attribute(save_to_batch=True, **batch)

    Note:
        - Do not use within ``torch.no_grad()`` context as gradients are
          required
        - Works with any PyHealth image classification model
        - For best results, normalize input images consistently with training

    See Also:
        - ``examples/ChestXrayClassificationWithSaliency.ipynb``: Complete
          tutorial
        - :class:`~pyhealth.interpret.methods.IntegratedGradients`:
          Alternative attribution method
    """

    def __init__(self, model, input_batch=None, image_key='image', label_key='disease'):
        """Initialize the saliency map generator.

        Args:
            model: PyHealth model with forward method expecting ``image`` and
                ``disease`` kwargs and returning a dict with a ``'y_prob'``
                entry.
            input_batch: Optional batch of data as dictionary, list
                (``[images]`` or ``[images, labels]``), or image tensor.
                If None, use attribute() method to compute saliency maps.
            image_key: Key for accessing images in samples (default: 'image')
            label_key: Key for accessing labels in samples (default: 'disease')

        Raises:
            ValueError: If ``input_batch`` is not a dict, list, tensor or None.
        """
        # Validate that input_batch is either a dictionary, list, tensor, or None
        if input_batch is not None and not isinstance(input_batch, (dict, list, torch.Tensor)):
            raise ValueError("input_batch must be a dictionary, list, tensor, or None")

        # Call parent constructor
        super().__init__(model)

        # Capitalized attribute names are part of the public surface and are
        # kept for backward compatibility.
        self.Model = model
        self.Input_batch = input_batch
        self.Image_key = image_key
        self.Label_key = label_key
        self.Batch_saliency_maps = []

        # Compute saliency maps eagerly if input_batch was provided
        if input_batch is not None:
            self._compute_saliency_maps()

    def _batch_as_dict(self, batch):
        """Normalize a supported batch container into a dictionary.

        Args:
            batch: A dict (returned unchanged), an image tensor, or a list
                whose first element is the images and whose optional second
                element is the labels.

        Returns:
            dict: Mapping with ``self.Image_key`` and, for list/tensor input,
            ``self.Label_key`` (``None`` when no labels were supplied).
        """
        if isinstance(batch, dict):
            return batch
        if isinstance(batch, torch.Tensor):
            return {self.Image_key: batch, self.Label_key: None}
        # List input: labels are optional, so a one-element list is valid.
        return {
            self.Image_key: batch[0],
            self.Label_key: batch[1] if len(batch) > 1 else None,
        }

    def attribute(self, save_to_batch=False, **data) -> Dict[str, torch.Tensor]:
        """Compute attribution scores for input features.

        This method implements the BaseInterpreter interface by computing
        gradient-based saliency maps for the input images.

        Args:
            save_to_batch: If True, also append the result (saliency, image
                and label) to ``Batch_saliency_maps`` (default: False)
            **data: Input data keyed by ``image_key`` (required) and
                optionally ``label_key``.

        Returns:
            Dict[str, torch.Tensor]: Dictionary mapping ``image_key`` to the
            saliency tensor (absolute gradient, maximum taken over dim 1).

        Raises:
            KeyError: If ``image_key`` is missing from ``data``.
        """
        # ``data`` is collected from keyword arguments, so it is always a dict.
        images = data[self.Image_key]
        # Detach so the copy is a leaf tensor; only leaves get ``.grad`` filled.
        batch_images = images.clone().detach().requires_grad_()
        batch_labels = data.get(self.Label_key, None)

        # NOTE(review): the model is always called with the literal keyword
        # names ``image``/``disease`` regardless of image_key/label_key --
        # confirm whether custom keys should be forwarded instead.
        output = self.model(image=batch_images, disease=batch_labels)
        y_prob = output['y_prob']
        target_class = y_prob.argmax(dim=1)
        scores = y_prob.gather(1, target_class.unsqueeze(1)).squeeze()

        # Compute gradients of the predicted-class scores w.r.t. the pixels
        self.model.zero_grad()
        scores.sum().backward()

        # Process gradients into saliency map: max |grad| across channels
        sal, _ = torch.max(batch_images.grad.abs(), dim=1)

        if save_to_batch:
            self.Batch_saliency_maps.append({
                'saliency': sal,
                'image': batch_images,
                'label': batch_labels,
            })

        return {self.Image_key: sal}

    def get_gradient_saliency_maps(self):
        """Retrieve gradient saliency maps.

        Returns:
            list: Batch saliency map results; each entry is a dict with
            ``'saliency'``, ``'image'`` and ``'label'`` keys.
        """
        return self.Batch_saliency_maps

    def _compute_saliency_maps(self):
        """Compute gradient saliency maps for the batch given at construction."""
        if self.Input_batch is None:
            return  # Nothing to compute

        self.Model.eval()
        self._process_batch(self._batch_as_dict(self.Input_batch))

    def _process_batch(self, batch):
        """Process a batch of inputs to generate saliency maps.

        This method wraps the attribute() method to maintain backward
        compatibility with the original batch processing API.

        Args:
            batch: Dictionary containing the image tensor under
                ``self.Image_key`` and, optionally, labels under
                ``self.Label_key``.
        """
        # Use attribute method to compute saliency
        attributions = self.attribute(**batch)
        sal = attributions[self.Image_key]

        # Store a fresh copy of the inputs alongside the saliency map
        batch_images = batch[self.Image_key].clone().detach().requires_grad_()
        # Labels are optional, matching attribute()
        batch_labels = batch.get(self.Label_key, None)

        # Store results in the original format for backward compatibility
        self.Batch_saliency_maps.append({
            'saliency': sal,
            'image': batch_images,
            'label': batch_labels,
        })

    def visualize_saliency_map(self, plt, *, image_index, title=None, id2label=None, alpha=0.3):
        """Display an image with its saliency map overlay.

        The saliency map is recomputed for the single selected image from the
        batch that was supplied at construction time.

        Args:
            plt: matplotlib.pyplot instance (imported lazily when None)
            image_index: Index of image within batch
            title: Optional title for the plot
            id2label: Optional dictionary mapping class indices to labels
            alpha: Transparency of saliency overlay (default: 0.3)

        Raises:
            ValueError: If no input batch was provided during initialization,
                or the selected image is not 3D/4D.
        """
        if plt is None:
            import matplotlib.pyplot as plt

        # Check if input_batch is available
        if self.Input_batch is None:
            raise ValueError("Cannot visualize: no input_batch was provided during initialization")

        # The constructor accepts dict, list and tensor batches; normalize so
        # that all of them can be visualized.
        batch = self._batch_as_dict(self.Input_batch)
        img_tensor = batch[self.Image_key][image_index]
        labels = batch.get(self.Label_key, None)
        true_label = labels[image_index].item() if labels is not None else None

        # Ensure input is a tensor with correct shape
        if img_tensor.dim() == 3:
            img_tensor = img_tensor.unsqueeze(0)
        if img_tensor.dim() != 4:
            raise ValueError(f"Expected 4D tensor (batch, channels, height, width), got shape {img_tensor.shape}")

        # Detach before enabling grad: a clone of a tensor that already
        # requires grad is a non-leaf, and its ``.grad`` would stay None.
        img_tensor = img_tensor.clone().detach().requires_grad_(True)

        # Create a dummy label tensor of zeros for the forward pass
        dummy_label = torch.zeros(img_tensor.size(0), dtype=torch.long, device=img_tensor.device)
        output = self.Model(image=img_tensor, disease=dummy_label)
        pred_class = torch.argmax(output['y_prob']).item()

        # Backward pass on the predicted-class score
        self.Model.zero_grad()
        output['y_prob'][:, pred_class].sum().backward()

        # Get saliency map: max |grad| across channels
        saliency = torch.max(img_tensor.grad.abs(), dim=1)[0]

        # Add predicted class (and true label, when known) to title
        if id2label is not None:
            pred_label_str = id2label[pred_class]
            if true_label is None:
                label_text = f"Predicted: {pred_label_str}"
            else:
                label_text = f"True: {id2label[true_label]}, Predicted: {pred_label_str}"
            title = label_text if title is None else f"{title} - {label_text}"

        # Convert image to numpy for display
        img_np = img_tensor[0].detach().cpu().numpy()
        if img_np.shape[0] in [1, 3]:  # CHW to HWC
            img_np = np.transpose(img_np, (1, 2, 0))
        if img_np.shape[-1] == 1:
            img_np = img_np.squeeze(-1)

        # Convert saliency to numpy
        if saliency.dim() > 2:
            saliency = saliency[0]
        saliency_np = saliency.detach().cpu().numpy()

        # Create visualization
        plt.figure(figsize=(15, 7))
        plt.axis('off')
        plt.imshow(img_np, cmap='gray')
        plt.imshow(saliency_np, cmap='hot', alpha=alpha)
        if title:
            plt.title(title)
        plt.show()