pyhealth.interpret.methods.basic_gradient

Overview

The BasicGradientSaliencyMaps method computes gradient-based saliency maps for PyHealth’s image classification models. This helps identify which regions of medical images most influenced the model’s prediction by visualizing gradients of model outputs with respect to input pixels.

For a complete working example, see: examples/ChestXrayClassificationWithSaliency.ipynb

API Reference

class pyhealth.interpret.methods.basic_gradient.BasicGradientSaliencyMaps(model, input_batch=None, image_key='image', label_key='disease')

Bases: BaseInterpreter

Compute gradient-based saliency maps for image classification models.

This class generates saliency maps by computing gradients of model predictions with respect to input pixels, highlighting regions that most influenced the prediction. The saliency is computed by taking the maximum absolute gradient across color channels.

The method is particularly useful for:

  • Clinical interpretability: Understanding which image regions drove a diagnosis

  • Model debugging: Verifying the model focuses on clinically relevant features

  • Trust and transparency: Providing visual explanations for predictions

  • Error analysis: Comparing saliency maps for correct vs. incorrect predictions

Algorithm:
  1. Forward pass: Compute model predictions for the input batch

  2. Target selection: Use the predicted class (argmax of the predicted probabilities)

  3. Backward pass: Compute gradients of the predicted-class score with respect to the input pixels

  4. Saliency map: Take the absolute value of the gradients and the maximum across color channels
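The four steps above can be sketched in plain PyTorch, independent of PyHealth. This is a minimal sketch under stated assumptions: the model maps a (B, C, H, W) image tensor directly to (B, K) logits (PyHealth models take a batch dictionary instead), and the differentiated score is the softmax probability of the predicted class (using the raw logit is a common alternative). `gradient_saliency` is a hypothetical helper, not the library's implementation.

```python
import torch
import torch.nn as nn


def gradient_saliency(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Hypothetical vanilla-gradient saliency for a (B, C, H, W) -> (B, K) model.

    Returns a (B, H, W) tensor: max over channels of |d score_pred / d pixel|.
    """
    # Eval mode so samples do not interact (e.g. BatchNorm uses running stats).
    model.eval()
    # Leaf copy of the input that tracks gradients (needed for step 3).
    x = images.detach().clone().requires_grad_(True)

    # 1. Forward pass (assumption: the model returns logits).
    probs = torch.softmax(model(x), dim=1)

    # 2. Target selection: predicted class per sample.
    pred = probs.argmax(dim=1)

    # 3. Backward pass. Summing the per-sample scores gives every sample's
    #    gradient in one call, since sample i's score depends only on x[i].
    score = probs.gather(1, pred.unsqueeze(1)).sum()
    (grad,) = torch.autograd.grad(score, x)

    # 4. Saliency: absolute gradient, maximum over the channel dimension.
    return grad.abs().amax(dim=1)


# Usage with a tiny stand-in classifier (2 classes, 3-channel input).
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 2),
)
images = torch.randn(2, 3, 32, 32)
sal = gradient_saliency(model, images)
print(sal.shape)  # torch.Size([2, 32, 32])
```

Using `torch.autograd.grad` rather than `score.backward()` returns the input gradient directly without accumulating `.grad` on the model's parameters.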

Mathematical formula:

$$\text{saliency}(x, y) = \max_{c} \left| \frac{\partial\, \text{score}_{\text{pred}}}{\partial\, \text{pixel}_{x,y,c}} \right|$$

where c iterates over the color channels (three for RGB, one for grayscale).
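As a small worked example of the formula, the gradient values below are made up (not output from any model) for one image laid out as (C, H, W). The per-pixel saliency is the largest absolute gradient across channels; for a single-channel grayscale image the maximum is trivially the absolute gradient itself.

```python
import torch

# Made-up gradients d score_pred / d pixel for one RGB image, shape (C=3, H=2, W=2).
grad = torch.tensor([
    [[0.2, -0.1], [0.0, 0.3]],    # channel R
    [[-0.5, 0.05], [0.1, -0.4]],  # channel G
    [[0.1, 0.2], [-0.2, 0.1]],    # channel B
])

# saliency(x, y) = max over c of |grad[c, x, y]|, shape (H, W)
saliency = grad.abs().amax(dim=0)
# expected values: [[0.5, 0.2], [0.2, 0.4]]

# Grayscale: one channel, so the max over c is just the absolute gradient.
gray_grad = torch.tensor([[[-0.3, 0.1], [0.2, -0.6]]])  # shape (1, 2, 2)
gray_saliency = gray_grad.abs().amax(dim=0)
```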

Examples

Basic usage with a batch:

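import torch
import matplotlib.pyplot as plt
from pyhealth.interpret.methods.basic_gradient import BasicGradientSaliencyMaps

# `model` is assumed to be a trained PyHealth image classification model
# that reads images from 'image' and labels from 'disease' (the defaults).

# Create a batch of two 3-channel 224x224 images and their labels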
batch = {
    'image': torch.randn(2, 3, 224, 224),
    'disease': torch.tensor([0, 1])
}

# Compute saliency maps
saliency = BasicGradientSaliencyMaps(model, input_batch=batch)

# Visualize
saliency.visualize_saliency_map(
    plt,
    image_index=0,
    title="Saliency Map",
    id2label={0: "Normal", 1: "COVID"}
)

Using the attribute() interface:

# Initialize without a batch (`model` and `batch` are as in the previous example)
saliency = BasicGradientSaliencyMaps(model)

# Compute attributions for new data
attributions = saliency.attribute(**batch)
# Returns: {'image': tensor with saliency maps}

# Save to batch history
attributions = saliency.attribute(save_to_batch=True, **batch)

Note

  • Do not call within a torch.no_grad() context, because gradients are required

  • Works with any PyHealth image classification model

  • For best results, normalize input images the same way as during training
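The first note can be illustrated with plain PyTorch (no PyHealth involved): inside `torch.no_grad()` no autograd graph is recorded, so there is nothing to backpropagate to the pixels. The per-channel weighting below is a stand-in for a model, purely for illustration. Whether `BasicGradientSaliencyMaps` re-enables gradients internally is not stated here; `torch.enable_grad()` is the general PyTorch way to switch tracking back on locally if the caller is already inside `no_grad`.

```python
import torch

# Stand-in "model": per-channel weights applied to a 3-channel 8x8 image.
x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(3, 1, 1)

# Inside torch.no_grad() no graph is recorded.
with torch.no_grad():
    score_no_grad = (x * w).sum()
print(score_no_grad.requires_grad)  # False

# Backpropagating from a score without a graph raises a RuntimeError.
backward_failed = False
try:
    score_no_grad.backward()
except RuntimeError:
    backward_failed = True
print(backward_failed)  # True

# torch.enable_grad() re-enables tracking even inside an outer no_grad block.
with torch.no_grad():
    with torch.enable_grad():
        score_reenabled = (x * w).sum()
score_reenabled.backward()
print(x.grad.shape)  # torch.Size([1, 3, 8, 8])
```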

See also

  • examples/ChestXrayClassificationWithSaliency.ipynb: Complete tutorial

  • IntegratedGradients: Alternative attribution method

attribute(save_to_batch=False, **data)

Compute attribution scores for input features.

This method implements the BaseInterpreter interface by computing gradient-based saliency maps for the input images.

Parameters:
  • save_to_batch – If True, save results to Batch_saliency_maps (default: False)

  • **data – Input data dictionary containing the image key (‘image’ by default) and, optionally, the label key (‘disease’ by default)

Returns:

Dictionary with ‘saliency’ key mapping to saliency map tensor

Return type:

Dict[str, torch.Tensor]

get_gradient_saliency_maps()

Retrieve gradient saliency maps.

Returns:

Batch saliency map results

Return type:

list

visualize_saliency_map(plt, *, image_index, title=None, id2label=None, alpha=0.3)

Display an image with its saliency map overlay.

Parameters:
  • plt – matplotlib.pyplot instance

  • image_index – Index of image within batch

  • title – Optional title for the plot

  • id2label – Optional dictionary mapping class indices to labels

  • alpha – Transparency of saliency overlay (default: 0.3)
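The docstring does not describe how the overlay is drawn. The sketch below shows one common way to blend a saliency map over an image using an `alpha` transparency value. `overlay_saliency`, the "gray" and "hot" colormaps, the min-max normalization, and the use of `id2label` to build the title are all assumptions for illustration, not the library's implementation.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def overlay_saliency(ax, image, saliency, title=None, alpha=0.3):
    """Hypothetical helper: grayscale image with a saliency heatmap on top.

    image and saliency are (H, W) arrays; saliency holds non-negative scores.
    """
    ax.imshow(image, cmap="gray")
    # Min-max normalize to [0, 1] so the colormap range ignores gradient scale.
    s = saliency - saliency.min()
    if s.max() > 0:
        s = s / s.max()
    # alpha controls how strongly the heatmap covers the underlying image.
    ax.imshow(s, cmap="hot", alpha=alpha)
    if title is not None:
        ax.set_title(title)
    ax.axis("off")
    return ax


# Usage with random stand-in data and an id2label-style mapping.
rng = np.random.default_rng(0)
image = rng.random((32, 32))
saliency = rng.random((32, 32))
id2label = {0: "Normal", 1: "COVID"}
predicted_class = 1  # assumed prediction, for the title only

fig, ax = plt.subplots()
overlay_saliency(ax, image, saliency,
                 title=f"Saliency Map: {id2label[predicted_class]}", alpha=0.3)
```

Lower `alpha` values keep more of the underlying image visible; higher values emphasize the heatmap.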