pyhealth.interpret.methods.basic_gradient
Overview
The BasicGradientSaliencyMaps method computes gradient-based saliency maps for PyHealth’s image classification models. This helps identify which regions of medical images most influenced the model’s prediction by visualizing gradients of model outputs with respect to input pixels.
For a complete working example, see:
examples/ChestXrayClassificationWithSaliency.ipynb
API Reference
- class pyhealth.interpret.methods.basic_gradient.BasicGradientSaliencyMaps(model, input_batch=None, image_key='image', label_key='disease')
Bases: BaseInterpreter
Compute gradient-based saliency maps for image classification models.
This class generates saliency maps by computing gradients of model predictions with respect to input pixels, highlighting regions that most influenced the prediction. The saliency is computed by taking the maximum absolute gradient across color channels.
The method is particularly useful for:
Clinical interpretability: Understanding which image regions drove a diagnosis
Model debugging: Verifying the model focuses on clinically relevant features
Trust and transparency: Providing visual explanations for predictions
Error analysis: Comparing saliency maps for correct vs. incorrect predictions
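For the error-analysis use case, a minimal sketch of grouping saliency maps by prediction correctness follows. `split_by_correctness` is a hypothetical helper, not part of PyHealth; it assumes saliency maps shaped (N, H, W) (one map per image, as produced by the max-over-channels step), class logits shaped (N, num_classes), and integer labels shaped (N,).

```python
import torch


def split_by_correctness(saliency, logits, labels):
    """Hypothetical helper: mean saliency map for correct vs. incorrect predictions.

    saliency: (N, H, W), logits: (N, num_classes), labels: (N,).
    Returns (correct_mean, incorrect_mean); a group with no images yields None.
    """
    correct = logits.argmax(dim=1) == labels  # boolean mask of shape (N,)

    def mean_or_none(maps):
        return maps.mean(dim=0) if maps.shape[0] > 0 else None

    return mean_or_none(saliency[correct]), mean_or_none(saliency[~correct])


# Toy data: four constant 2x2 maps with values 1, 2, 3, 4.
saliency = torch.stack([torch.full((2, 2), v) for v in (1.0, 2.0, 3.0, 4.0)])
logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [1.0, 0.0], [0.0, 1.0]])
labels = torch.tensor([0, 1, 1, 0])  # predictions are [0, 1, 0, 1]
correct_mean, wrong_mean = split_by_correctness(saliency, logits, labels)
```

Averaging within each group is only one choice; inspecting individual maps side by side is often more informative for small error sets.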
- Algorithm:
Forward pass: Compute model predictions for input batch
Target selection: Use predicted class (argmax of probabilities)
Backward pass: Compute gradients with respect to input pixels
Saliency map: Take absolute value and max across color channels
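The four steps above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the PyHealth implementation: `gradient_saliency` is a hypothetical helper, the toy model is assumed to map a (N, C, H, W) tensor directly to class logits, and the sketch differentiates the predicted class's probability (differentiating the logit instead is a common variant).

```python
import torch
import torch.nn as nn


def gradient_saliency(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: vanilla gradient saliency for images shaped (N, C, H, W).

    Assumes model(images) returns logits of shape (N, num_classes).
    """
    model.eval()
    # Copy the input so gradients are tracked with respect to the pixels.
    x = images.detach().clone().requires_grad_(True)

    # 1. Forward pass.
    probs = torch.softmax(model(x), dim=1)

    # 2. Target selection: predicted class (argmax of probabilities).
    target = probs.argmax(dim=1)

    # 3. Backward pass. Summing the selected scores gives per-image gradients
    #    in one call, as long as the model treats images independently.
    probs.gather(1, target.unsqueeze(1)).sum().backward()

    # 4. Saliency map: absolute gradient, max over the channel axis -> (N, H, W).
    return x.grad.abs().amax(dim=1)


# Usage with a small stand-in CNN (not a PyHealth model).
torch.manual_seed(0)
toy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2),
)
sal = gradient_saliency(toy, torch.randn(2, 3, 32, 32))
```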
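- Mathematical formula:
$$\mathrm{saliency}(x, y) = \max_{c} \left| \frac{\partial\, \mathrm{score}_{\mathrm{predicted}}}{\partial\, \mathrm{pixel}_{x,y,c}} \right|$$
where c iterates over the color channels (three for RGB, one for grayscale).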
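A small worked check of the formula: for a linear score, score = Σ weights[c, x, y] · pixel[c, x, y], the gradient with respect to each pixel is exactly the corresponding weight, so the saliency at (x, y) is the largest absolute weight across channels. The weights below are made-up illustration values, not taken from any model.

```python
import torch

# Made-up weights of a linear score, shape (C=3, H=2, W=2).
weights = torch.tensor([
    [[0.5, -2.0], [1.0, 0.0]],
    [[-1.5, 0.25], [0.0, 3.0]],
    [[0.1, 1.0], [-0.5, -0.5]],
])
pixels = torch.rand(3, 2, 2, requires_grad=True)

# d(score)/d(pixel[c, x, y]) == weights[c, x, y] for this linear score.
score = (weights * pixels).sum()
score.backward()

# max over channels of the absolute gradient -> one value per (x, y).
saliency = pixels.grad.abs().amax(dim=0)
print(saliency)  # tensor([[1.5000, 2.0000], [1.0000, 3.0000]])
```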
Examples
Basic usage with a batch:
```python
import torch
import matplotlib.pyplot as plt
from pyhealth.interpret.methods.basic_gradient import BasicGradientSaliencyMaps

# `model` is a trained PyHealth image classification model.

# Create a batch
batch = {
    'image': torch.randn(2, 3, 224, 224),
    'disease': torch.tensor([0, 1]),
}

# Compute saliency maps
saliency = BasicGradientSaliencyMaps(model, input_batch=batch)

# Visualize
saliency.visualize_saliency_map(
    plt,
    image_index=0,
    title="Saliency Map",
    id2label={0: "Normal", 1: "COVID"},
)
```
Using the attribute() interface:
```python
# Initialize without a batch
saliency = BasicGradientSaliencyMaps(model)

# Compute attributions for new data
attributions = saliency.attribute(**batch)
# Returns: {'image': tensor with saliency maps}

# Save to batch history
attributions = saliency.attribute(save_to_batch=True, **batch)
```
Note
Do not call within a torch.no_grad() context, because computing the saliency map requires gradients.
Works with any PyHealth image classification model.
For best results, normalize input images the same way they were normalized during training.
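A minimal PyTorch illustration of why the first note matters: under `torch.no_grad()` autograd records no graph, so there is nothing to backpropagate through. The linear score here is a stand-in for a model output, not PyHealth code.

```python
import torch

x = torch.randn(1, 3, 8, 8, requires_grad=True)
weights = torch.randn(3, 8, 8)  # stand-in for a model's score function

# Inside no_grad, the score is not connected to x in the autograd graph.
with torch.no_grad():
    score_no_grad = (x[0] * weights).sum()
print(score_no_grad.requires_grad)  # False

failed = False
try:
    score_no_grad.backward()  # no graph to differentiate through
except RuntimeError:
    failed = True

# Outside no_grad, the same computation yields pixel gradients.
score = (x[0] * weights).sum()
score.backward()
print(x.grad.shape)  # torch.Size([1, 3, 8, 8])
```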
See also
examples/ChestXrayClassificationWithSaliency.ipynb: Complete tutorial
IntegratedGradients: Alternative attribution method
- attribute(save_to_batch=False, **data)
Compute attribution scores for input features.
This method implements the BaseInterpreter interface by computing gradient-based saliency maps for the input images.
- Parameters:
save_to_batch – If True, save results to Batch_saliency_maps (default: False)
**data – Input data passed as keyword arguments (for example, an unpacked batch dictionary) containing ‘image’ and optionally ‘disease’
- Returns:
Dictionary with ‘saliency’ key mapping to saliency map tensor
- Return type:
Dict[str, torch.Tensor]
- get_gradient_saliency_maps()
Retrieve gradient saliency maps.
- Returns:
Batch saliency map results
- Return type:
- visualize_saliency_map(plt, *, image_index, title=None, id2label=None, alpha=0.3)
Display an image with its saliency map overlay.
- Parameters:
plt – the matplotlib.pyplot module
image_index – Index of image within batch
title – Optional title for the plot
id2label – Optional dictionary mapping class indices to labels
alpha – Transparency of saliency overlay (default: 0.3)
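One plausible way to build such an overlay with matplotlib, sketched under assumptions: `overlay_saliency` is a hypothetical helper, the base image is drawn from its first channel in grayscale, the saliency map is min-max rescaled to [0, 1], and the `jet` colormap is an arbitrary choice. This is not the PyHealth implementation; it only shows how an `alpha` value like the default 0.3 controls the overlay's transparency.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch


def overlay_saliency(plt, image, saliency, title=None, alpha=0.3):
    """Hypothetical helper: draw image (C, H, W) with saliency (H, W) on top."""
    fig, ax = plt.subplots()
    # Base layer: first channel in grayscale (an assumption of this sketch).
    ax.imshow(image[0].detach().cpu().numpy(), cmap="gray")
    # Rescale saliency to [0, 1] so the colormap range is stable.
    sal = saliency.detach().cpu()
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)
    # Top layer: alpha=0.3 keeps the underlying image mostly visible.
    ax.imshow(sal.numpy(), cmap="jet", alpha=alpha)
    if title:
        ax.set_title(title)
    ax.axis("off")
    return fig, ax


fig, ax = overlay_saliency(plt, torch.rand(3, 16, 16), torch.rand(16, 16), title="Saliency Map")
```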