pyhealth.metrics.interpretability

Interpretability metrics evaluate the faithfulness of feature attribution methods by measuring how model predictions change when important features are removed or retained.

Evaluator

class pyhealth.metrics.interpretability.evaluator.Evaluator(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', sample_filter=None, positive_threshold=None)

Bases: object

High-level interface for evaluating interpretations.

This class provides a convenient API for computing multiple interpretability metrics at once, both on individual batches and across entire datasets.

Parameters:
  • model (BaseModel) – PyHealth BaseModel to evaluate

  • percentages (List[float]) – List of percentages to evaluate at. Default: [1, 5, 10, 20, 50].

  • ablation_strategy (str) – How to ablate features. Options: 'zero' sets ablated features to 0; 'mean' sets ablated features to the feature mean across the batch; 'noise' adds Gaussian noise to ablated features. Default: 'zero'.

  • sample_filter (Optional[Callable[[Tensor, str], Tensor]]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,) and contains the class probability used for filtering: for binary single-logit models this is P(class=1); for multiclass and multilabel models it is the gathered target-class probability. sample_classes is a tensor of SampleClass values: SampleClass.POSITIVE evaluates the sample with its attributions as-is; SampleClass.NEGATIVE evaluates it with negated attributions; SampleClass.IGNORE excludes it from evaluation. If None, default_sample_filter is used.

  • positive_threshold (Optional[float]) – Threshold for the positive class in binary classification. Default: None.

    Deprecated: this parameter will be removed in a future release. Use sample_filter with threshold_sample_filter() instead.
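The three ablation strategies can be illustrated on a plain nested-list batch. This is a minimal sketch, not PyHealth's implementation: the helper name `ablate`, the boolean-mask layout, the `noise_std` and `seed` parameters are hypothetical, and the 'mean' strategy is assumed to use the per-feature mean over the whole batch.

```python
import random
from typing import List


def ablate(batch: List[List[float]], mask: List[List[bool]],
           strategy: str = "zero", noise_std: float = 1.0,
           seed: int = 0) -> List[List[float]]:
    """Return a copy of `batch` with masked features ablated (hypothetical helper).

    batch: list of samples, each a list of feature values.
    mask:  same shape as batch; True marks a feature to ablate.
    """
    rng = random.Random(seed)
    n_features = len(batch[0])
    # Per-feature mean across the batch, used by the 'mean' strategy.
    feature_means = [
        sum(sample[j] for sample in batch) / len(batch)
        for j in range(n_features)
    ]
    out = []
    for sample, sample_mask in zip(batch, mask):
        new_sample = []
        for j, (value, drop) in enumerate(zip(sample, sample_mask)):
            if not drop:
                new_sample.append(value)
            elif strategy == "zero":
                new_sample.append(0.0)
            elif strategy == "mean":
                new_sample.append(feature_means[j])
            elif strategy == "noise":
                # Noise is added to the original value, not substituted for it.
                new_sample.append(value + rng.gauss(0.0, noise_std))
            else:
                raise ValueError(f"unknown ablation strategy: {strategy!r}")
        out.append(new_sample)
    return out


batch = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
mask = [[True, False, False], [False, False, True]]
print(ablate(batch, mask, "zero"))  # [[0.0, 2.0, 3.0], [3.0, 4.0, 0.0]]
print(ablate(batch, mask, "mean"))  # [[2.0, 2.0, 3.0], [3.0, 4.0, 4.0]]
```

Note that under 'mean' an ablated feature is replaced by a batch statistic, so its value depends on which other samples share the batch.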

Examples

>>> import torch
>>> from pyhealth.metrics.interpretability import Evaluator
>>> from pyhealth.metrics.interpretability.utils import (
...     SampleClass,
...     threshold_sample_filter,
... )
>>>
>>> # Assumes `model` is a trained PyHealth model (e.g. StageNet)
>>> # and `test_loader` yields batches compatible with it.
>>>
>>> # Initialize evaluator with the default filter
>>> evaluator = Evaluator(model)
>>>
>>> # Initialize with a custom filter that ignores samples whose
>>> # class probability is below 0.6
>>> def confident_filter(class_probs, classifier_type):
...     batch_size = class_probs.shape[0]
...     result = torch.full(
...         (batch_size,), SampleClass.POSITIVE,
...         dtype=torch.long, device=class_probs.device,
...     )
...     result[class_probs < 0.6] = SampleClass.IGNORE
...     return result
>>> evaluator = Evaluator(model, sample_filter=confident_filter)
>>>
>>> # Evaluate on a single batch
>>> inputs = {'conditions': torch.randn(32, 50)}
>>> attributions = {'conditions': torch.randn(32, 50)}
>>> batch_results = evaluator.evaluate(inputs, attributions)
>>> for metric, (scores, mask) in batch_results.items():
...     print(f"{metric}: {scores[mask].mean():.4f}")
>>>
>>> # Evaluate across an entire dataset with an attribution method
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> ig = IntegratedGradients(model)
>>> results = evaluator.evaluate_attribution(test_loader, ig)
>>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}")
>>> print(f"Sufficiency: {results['sufficiency']:.4f}")

evaluate(inputs, attributions, metrics=['comprehensiveness', 'sufficiency'], return_per_percentage=False, debug=False)

Evaluate multiple metrics at once.

Parameters:
  • inputs (Dict[str, Tensor]) – Model inputs (batch)

  • attributions (Dict[str, Tensor]) – Attribution scores (batch)

  • metrics (List[str]) – List of metrics to compute. Options: ["comprehensiveness", "sufficiency"]. Default: both metrics.

  • return_per_percentage (bool) – If True, return per-percentage breakdown. If False (default), return averaged scores.

  • debug (bool) – If True, print detailed debug information (only used when return_per_percentage=True)

Returns:

If return_per_percentage=False (default): dictionary mapping metric name -> (scores, valid_mask), where
  • scores: raw metric scores for all samples

  • valid_mask: binary mask indicating the valid samples to average

If return_per_percentage=True: dictionary mapping metric name -> Dict[percentage -> scores], for example {'comprehensiveness': {10: tensor(…), 20: …}}

Note

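For binary classifiers, samples with both positive and negative predictions are evaluated; samples marked IGNORE by sample_filter are excluded through valid_mask. To average a metric over the batch, use scores[valid_mask].mean().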
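Assuming the default (averaged) output is the per-sample mean of the per-percentage scores, as the 1/|B| average in the metric definitions suggests, the two return modes and the masked mean relate as sketched below. Plain lists stand in for tensors, and `average_over_percentages` and `masked_mean` are hypothetical helper names.

```python
from typing import Dict, List


def average_over_percentages(per_pct: Dict[float, List[float]]) -> List[float]:
    """Per-sample mean of per-percentage scores (the 1/|B| sum over q)."""
    columns = list(per_pct.values())
    n_samples = len(columns[0])
    return [sum(col[i] for col in columns) / len(columns)
            for i in range(n_samples)]


def masked_mean(scores: List[float], valid_mask: List[bool]) -> float:
    """Plain-list equivalent of scores[valid_mask].mean()."""
    kept = [s for s, ok in zip(scores, valid_mask) if ok]
    # An all-False mask has no valid samples; report NaN instead of failing.
    return sum(kept) / len(kept) if kept else float("nan")


# Hypothetical per-percentage scores for a batch of two samples.
per_pct = {10: [0.25, 0.5], 20: [0.75, 1.0]}
scores = average_over_percentages(per_pct)  # [0.5, 0.75]
print(masked_mean(scores, [True, False]))   # 0.5
```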

Examples

>>> # Default: averaged scores
>>> results = evaluator.evaluate(inputs, attributions)
>>> comp_scores, valid_mask = results['comprehensiveness']
>>> print(f"Mean: {comp_scores[valid_mask].mean():.4f}")
>>>
>>> # Per-percentage breakdown
>>> detailed = evaluator.evaluate(
...     inputs, attributions, return_per_percentage=True
... )
>>> comp_10_pct = detailed['comprehensiveness'][10]
>>> suff_20_pct = detailed['sufficiency'][20]

evaluate_attribution(dataloader, method, metrics=['comprehensiveness', 'sufficiency'])

Evaluate an attribution method across an entire dataset.

This method computes the average faithfulness metrics of an attribution method over all batches in a dataloader: for each batch it computes attributions with the provided method and then evaluates them.

Parameters:
  • dataloader – PyTorch DataLoader with test/validation data. Should yield batches compatible with the model.

  • method – Attribution method instance (e.g., IntegratedGradients). Must implement the BaseInterpreter interface with an attribute(**data) method.

  • metrics (List[str]) – List of metrics to compute. Options: ["comprehensiveness", "sufficiency"]. Default: both metrics.

Return type:

Dict[str, float]

Returns:

Dictionary mapping metric names to their average scores across the entire dataset. Samples marked IGNORE by the configured sample_filter are excluded from the average.

Example: {'comprehensiveness': 0.345, 'sufficiency': 0.123}

Note

For binary classifiers, both positive and negative samples can be evaluated. Negative samples are handled by negating the attribution scores before top-feature selection, which makes the probability drop equivalent to the drop in confidence for class 0. Use sample_filter to include or exclude whichever subsets you want in the dataset average.
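Dataset-level averaging with IGNORE samples excluded can be sketched with plain Python and stub callables. The documentation does not say whether scores are pooled per sample (as here) or averaged as a mean of per-batch means, so the pooling choice is an assumption; `dataset_average`, `fake_attribute` and `fake_evaluate` are hypothetical stand-ins for the method's attribute(**data) and Evaluator.evaluate.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def dataset_average(
    batches: Iterable[dict],
    attribute: Callable[..., dict],
    evaluate: Callable[[dict, dict], Dict[str, Tuple[List[float], List[bool]]]],
) -> Dict[str, float]:
    """Pool per-sample scores over all batches, skipping masked-out samples."""
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for batch in batches:
        attributions = attribute(**batch)        # attributions for this batch
        results = evaluate(batch, attributions)  # metric -> (scores, mask)
        for metric, (scores, mask) in results.items():
            kept = [s for s, valid in zip(scores, mask) if valid]
            totals[metric] = totals.get(metric, 0.0) + sum(kept)
            counts[metric] = counts.get(metric, 0) + len(kept)
    # A metric with no valid samples at all is reported as NaN.
    return {
        metric: totals[metric] / counts[metric] if counts[metric] else float("nan")
        for metric in totals
    }


# Stubs standing in for an attribution method and Evaluator.evaluate.
def fake_attribute(**batch):
    return {"x": batch["scores"]}


def fake_evaluate(batch, attributions):
    return {"comprehensiveness": (batch["scores"], batch["mask"])}


batches = [
    {"scores": [0.25, 0.5, 9.0], "mask": [True, True, False]},  # 9.0 is ignored
    {"scores": [0.75], "mask": [True]},
]
print(dataset_average(batches, fake_attribute, fake_evaluate))
# {'comprehensiveness': 0.5}
```

Pooling sums and counts gives every valid sample equal weight, so a batch in which most samples are IGNORE contributes less than a full batch would under a mean of batch means.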

Examples

>>> from pyhealth.interpret.methods import IntegratedGradients
>>> from pyhealth.metrics.interpretability import Evaluator
>>>
>>> # Initialize evaluator and attribution method
>>> evaluator = Evaluator(model)
>>> ig = IntegratedGradients(model, use_embeddings=True)
>>>
>>> # Evaluate across test set
>>> results = evaluator.evaluate_attribution(
...     test_loader, ig, metrics=["comprehensiveness"]
... )
>>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}")
>>>
>>> # Compare multiple methods
>>> from pyhealth.interpret.methods import CheferRelevance
>>> chefer = CheferRelevance(model)
>>> ig_results = evaluator.evaluate_attribution(test_loader, ig)
>>> chefer_results = evaluator.evaluate_attribution(
...     test_loader, chefer
... )
>>> print("Method Comparison:")
>>> print(f"  IG Comp: {ig_results['comprehensiveness']:.4f}")
>>> print(
...     f"  Chefer Comp: "
...     f"{chefer_results['comprehensiveness']:.4f}"
... )

Functional API

pyhealth.metrics.interpretability.evaluator.evaluate_attribution(model, dataloader, method, metrics=['comprehensiveness', 'sufficiency'], percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', sample_filter=None, positive_threshold=None)

Evaluate an attribution method across a dataset (functional API).

This is a convenience function that wraps the Evaluator class for simple one-off evaluations. When running several evaluations with the same configuration (for example, comparing attribution methods), create one Evaluator and reuse it, which is more efficient than calling this function repeatedly.

Parameters:
  • model (BaseModel) – PyHealth BaseModel to evaluate

  • dataloader – PyTorch DataLoader with test/validation data

  • method – Attribution method instance (e.g., IntegratedGradients). Must implement the BaseInterpreter interface.

  • metrics (List[str]) – List of metrics to compute. Options: ["comprehensiveness", "sufficiency"]. Default: both metrics.

  • percentages (List[float]) – List of percentages to evaluate at. Default: [1, 5, 10, 20, 50].

  • ablation_strategy (str) – How to ablate features. Options: 'zero' sets ablated features to 0; 'mean' sets ablated features to the feature mean across the batch; 'noise' adds Gaussian noise to ablated features. Default: 'zero'.

  • sample_filter (Optional[Callable[[Tensor, str], Tensor]]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,) and contains the class probability used for filtering: for binary single-logit models this is P(class=1); for multiclass and multilabel models it is the gathered target-class probability. sample_classes is a tensor of SampleClass values: SampleClass.POSITIVE evaluates the sample with its attributions as-is; SampleClass.NEGATIVE evaluates it with negated attributions; SampleClass.IGNORE excludes it from evaluation. If None, default_sample_filter is used.

  • positive_threshold (Optional[float]) – Threshold for the positive class in binary classification. Default: None.

    Deprecated: this parameter will be removed in a future release. Use sample_filter with threshold_sample_filter() instead.

Return type:

Dict[str, float]

Returns:

Dictionary mapping metric names to their average scores across the entire dataset. Averaging uses mask-based filtering to exclude IGNORE samples.

Example: {'comprehensiveness': 0.345, 'sufficiency': 0.123}
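To illustrate the sample_filter contract that replaces positive_threshold, here is a sketch of a threshold-style filter over plain lists rather than tensors. It is not the library's threshold_sample_filter(): the factory name `make_threshold_filter`, the `>=` comparison, the pass-through for non-binary models, and the local SampleClass enum with its integer values are all hypothetical.

```python
from enum import IntEnum
from typing import Callable, List


class SampleClass(IntEnum):
    # Hypothetical stand-in; PyHealth's real SampleClass values may differ.
    IGNORE = 0
    POSITIVE = 1
    NEGATIVE = 2


def make_threshold_filter(
    threshold: float = 0.5,
) -> Callable[[List[float], str], List[SampleClass]]:
    """Build a filter with the (class_probs, classifier_type) signature."""

    def sample_filter(class_probs: List[float],
                      classifier_type: str) -> List[SampleClass]:
        if classifier_type != "binary":
            # Target-class probability: keep every sample as-is.
            return [SampleClass.POSITIVE for _ in class_probs]
        # Binary single-logit: class_probs is P(class=1). At or above the
        # threshold, explain class 1 as-is; below it, explain class 0
        # through negated attributions.
        return [
            SampleClass.POSITIVE if p >= threshold else SampleClass.NEGATIVE
            for p in class_probs
        ]

    return sample_filter


f = make_threshold_filter(0.5)
print(f([0.9, 0.2, 0.5], "binary"))
```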

Examples

>>> import torch
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> from pyhealth.metrics.interpretability import (
...     evaluate_attribution
... )
>>> from pyhealth.metrics.interpretability.utils import SampleClass
>>>
>>> # Simple one-off evaluation
>>> ig = IntegratedGradients(model, use_embeddings=True)
>>> results = evaluate_attribution(
...     model, test_loader, ig,
...     metrics=["comprehensiveness"],
...     percentages=[10, 20, 50]
... )
>>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}")
>>>
>>> # Custom filter to ignore uncertain predictions
>>> def ignore_uncertain(class_probs, classifier_type):
...     batch_size = class_probs.shape[0]
...     result = torch.full(
...         (batch_size,), SampleClass.POSITIVE,
...         dtype=torch.long, device=class_probs.device,
...     )
...     result[class_probs < 0.7] = SampleClass.IGNORE
...     return result
>>> results = evaluate_attribution(
...     model, test_loader, ig,
...     sample_filter=ignore_uncertain,
... )
>>>
>>> # For comparing multiple methods efficiently, use Evaluator:
>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.metrics.interpretability import Evaluator
>>> chefer = CheferRelevance(model)
>>> evaluator = Evaluator(model, percentages=[10, 20, 50])
>>> ig_results = evaluator.evaluate_attribution(test_loader, ig)
>>> chefer_results = evaluator.evaluate_attribution(
...     test_loader, chefer
... )

Removal-Based Metrics

Base Class

class pyhealth.metrics.interpretability.base.RemovalBasedMetric(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', *, sample_filter)

Bases: ABC

Abstract base class for removal-based interpretability metrics.

This class provides common functionality for computing faithfulness metrics by removing or retaining features based on their importance scores.

Parameters:
  • model (BaseModel) – PyHealth BaseModel that accepts **kwargs and returns dict with ‘y_prob’ or ‘logit’.

  • percentages (List[float]) – List of percentages to evaluate at. Default: [1, 5, 10, 20, 50].

  • ablation_strategy (str) – How to ablate features. Options: 'zero' sets ablated features to 0; 'mean' sets ablated features to the feature mean across the batch; 'noise' adds Gaussian noise to ablated features. Default: 'zero'.

  • sample_filter (Callable[[Tensor, str], Tensor]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,) and contains the class probability used for filtering: for binary single-logit models this is P(class=1); for multiclass and multilabel models it is the gathered target-class probability. sample_classes is a tensor of SampleClass values: SampleClass.POSITIVE evaluates the sample with its attributions as-is; SampleClass.NEGATIVE evaluates it with negated attributions; SampleClass.IGNORE excludes it from evaluation.

compute(inputs, attributions, predicted_class=None, return_per_percentage=False, debug=False)

Compute metric across percentages.

Parameters:
  • inputs (Dict[str, Tensor]) – Model inputs (batch)

  • attributions (Dict[str, Tensor]) – Attribution scores matching input shapes (batch)

  • predicted_class (Optional[Tensor]) – Optional pre-computed predicted classes (batch)

  • return_per_percentage (bool) – If True, return dict mapping percentage -> scores. If False (default), return averaged score across percentages.

  • debug (bool) – If True, print detailed probability information (only used when return_per_percentage=True)

Returns:

If return_per_percentage=False (default): tuple of (metric_scores, valid_mask), where
  • metric_scores: scores averaged across percentages, shape (batch_size,)

  • valid_mask: binary mask indicating valid samples

If return_per_percentage=True: Dict[float, torch.Tensor] mapping percentage -> scores of shape (batch_size,). For binary classifiers, negative class samples have value 0.

Note

For binary classifiers, all samples are evaluated (both positive and negative predictions). For class 0 predictions, attributions are negated internally so that feature importance is measured relative to the predicted class.
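The negation step for class-0 predictions can be sketched as a top-q% selection over one sample's flattened attributions. This is illustrative only: `top_q_mask` is a hypothetical helper, and the ceiling-based conversion of q% to a feature count (at least one feature) and tie-breaking by sort order are assumptions, not documented behaviour.

```python
import math
from typing import List


def top_q_mask(attributions: List[float], q: float,
               negate: bool = False) -> List[bool]:
    """Mark the top q% features of one sample by attribution score.

    With negate=True (e.g. a NEGATIVE sample), scores are negated first,
    so the most negative attributions count as the most important.
    """
    scores = [-a for a in attributions] if negate else list(attributions)
    n = len(scores)
    # Assumed rounding: ceil, with at least one feature selected.
    k = max(1, math.ceil(n * q / 100))
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    chosen = set(ranked[:k])
    return [i in chosen for i in range(n)]


attr = [0.5, -2.0, 1.5, 0.1]
print(top_q_mask(attr, 50))               # [True, False, True, False]
print(top_q_mask(attr, 50, negate=True))  # [False, True, False, True]
```

Comprehensiveness would ablate the features marked True; sufficiency would ablate the rest.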

Comprehensiveness

class pyhealth.metrics.interpretability.comprehensiveness.ComprehensivenessMetric(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', *, sample_filter)

Bases: RemovalBasedMetric

Comprehensiveness metric for interpretability evaluation.

Measures the drop in predicted class probability when important features are REMOVED (ablated). Higher scores indicate more faithful interpretations.

The metric is computed as:

    $$\mathrm{COMP} = \frac{1}{|B|} \sum_{q \in B} \left[ p_{c(x)}(x) - p_{c(x)}\left(x \setminus x_{:q\%}\right) \right]$$

Where:
  • $x$ is the original input

  • $x_{:q\%}$ are the top q% most important features

  • $x \setminus x_{:q\%}$ is the input with the top q% features removed (ablated)

  • $p_{c(x)}(\cdot)$ is the predicted probability for the original predicted class $c(x)$

  • $B$ is the set of percentages (default: {1, 5, 10, 20, 50})
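A worked instance of the COMP formula for a single sample, using made-up probabilities rather than real model outputs. The function name and the dictionary layout (q mapped to the predicted-class probability after ablating the top q%) are illustrative assumptions.

```python
from typing import Dict


def comprehensiveness(p_orig: float, p_removed: Dict[float, float]) -> float:
    """COMP for one sample.

    p_orig:    probability of the originally predicted class on x.
    p_removed: maps each q in B to that class's probability on x with
               the top q% features ablated.
    """
    drops = [p_orig - p for p in p_removed.values()]
    return sum(drops) / len(drops)


# Hypothetical probabilities for one sample, B = {10, 20, 50}.
print(comprehensiveness(0.75, {10: 0.5, 20: 0.25, 50: 0.0}))  # 0.5
```

The drops are 0.25, 0.5 and 0.75, averaging 0.5. A larger average drop means removing the highlighted features hurt the prediction more, which is why a higher COMP indicates a more faithful attribution.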

Examples

>>> import torch
>>> from pyhealth.models import MLP
>>> from pyhealth.metrics.interpretability import (
...     ComprehensivenessMetric
... )
>>> # sample_filter is a required keyword-only argument;
>>> # the import location of default_sample_filter is assumed
>>> from pyhealth.metrics.interpretability.utils import (
...     default_sample_filter
... )
>>>
>>> # Assume `dataset` exists and the model has been trained on it
>>> model = MLP(dataset=dataset)
>>>
>>> # Initialize metric
>>> comp = ComprehensivenessMetric(
...     model, sample_filter=default_sample_filter
... )
>>>
>>> # Prepare inputs and attributions
>>> inputs = {'conditions': torch.randn(32, 50)}
>>> attributions = {'conditions': torch.randn(32, 50)}
>>>
>>> # Compute metric
>>> scores, valid_mask = comp.compute(inputs, attributions)
>>> print(f"Mean comprehensiveness: {scores[valid_mask].mean():.3f}")
Mean comprehensiveness: 0.234
>>>
>>> # Get detailed scores per percentage
>>> detailed = comp.compute(
...     inputs, attributions, return_per_percentage=True
... )
>>> for pct, pct_scores in detailed.items():
...     print(f"  {pct}%: {pct_scores.mean():.3f}")
  1%: 0.045
  5%: 0.123
  10%: 0.234
  20%: 0.345
  50%: 0.456

compute(inputs, attributions, predicted_class=None, return_per_percentage=False, debug=False)

Compute metric across percentages.

Parameters:
  • inputs (Dict[str, Tensor]) – Model inputs (batch)

  • attributions (Dict[str, Tensor]) – Attribution scores matching input shapes (batch)

  • predicted_class (Optional[Tensor]) – Optional pre-computed predicted classes (batch)

  • return_per_percentage (bool) – If True, return dict mapping percentage -> scores. If False (default), return averaged score across percentages.

  • debug (bool) – If True, print detailed probability information (only used when return_per_percentage=True)

Returns:

If return_per_percentage=False (default): tuple of (metric_scores, valid_mask), where
  • metric_scores: scores averaged across percentages, shape (batch_size,)

  • valid_mask: binary mask indicating valid samples

If return_per_percentage=True: Dict[float, torch.Tensor] mapping percentage -> scores of shape (batch_size,). For binary classifiers, negative class samples have value 0.

Note

For binary classifiers, all samples are evaluated (both positive and negative predictions). For class 0 predictions, attributions are negated internally so that feature importance is measured relative to the predicted class.

Sufficiency

class pyhealth.metrics.interpretability.sufficiency.SufficiencyMetric(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', *, sample_filter)

Bases: RemovalBasedMetric

Sufficiency metric for interpretability evaluation.

Measures the drop in predicted class probability when ONLY important features are KEPT (all others removed). Lower scores indicate more faithful interpretations.

The metric is computed as:

    $$\mathrm{SUFF} = \frac{1}{|B|} \sum_{q \in B} \left[ p_{c(x)}(x) - p_{c(x)}\left(x_{:q\%}\right) \right]$$

Where:
  • $x$ is the original input

  • $x_{:q\%}$ is the input restricted to the top q% most important features (all others removed)

  • $p_{c(x)}(\cdot)$ is the predicted probability for the original predicted class $c(x)$

  • $B$ is the set of percentages (default: {1, 5, 10, 20, 50})
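Sufficiency differs from comprehensiveness only in which features survive: the top q% are kept and everything else is ablated. The sketch below shows that keep-only input under the 'zero' strategy and a worked SUFF value with made-up probabilities; `keep_only_top` and `sufficiency` are hypothetical helper names, not PyHealth functions.

```python
from typing import Dict, List


def keep_only_top(features: List[float], top_mask: List[bool],
                  fill: float = 0.0) -> List[float]:
    """Sufficiency input: keep the top features and ablate all others
    (shown with the 'zero' strategy, i.e. fill=0.0)."""
    return [v if keep else fill for v, keep in zip(features, top_mask)]


def sufficiency(p_orig: float, p_kept: Dict[float, float]) -> float:
    """SUFF for one sample: mean probability drop when only the top q% are kept."""
    drops = [p_orig - p for p in p_kept.values()]
    return sum(drops) / len(drops)


print(keep_only_top([3.0, 1.0, 2.0], [True, False, False]))  # [3.0, 0.0, 0.0]
# Hypothetical probabilities for one sample, B = {10, 20, 50}.
print(sufficiency(0.75, {10: 0.5, 20: 0.625, 50: 0.75}))     # 0.125
```

A small SUFF means the kept features alone almost reproduce the original prediction, which is why lower sufficiency indicates a more faithful attribution.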

Examples

>>> import torch
>>> from pyhealth.models import MLP
>>> from pyhealth.metrics.interpretability import SufficiencyMetric
>>> # sample_filter is a required keyword-only argument;
>>> # the import location of default_sample_filter is assumed
>>> from pyhealth.metrics.interpretability.utils import (
...     default_sample_filter
... )
>>>
>>> # Assume `dataset` exists and the model has been trained on it
>>> model = MLP(dataset=dataset)
>>>
>>> # Initialize metric
>>> suff = SufficiencyMetric(
...     model, sample_filter=default_sample_filter
... )
>>>
>>> # Prepare inputs and attributions
>>> inputs = {'conditions': torch.randn(32, 50)}
>>> attributions = {'conditions': torch.randn(32, 50)}
>>>
>>> # Compute metric
>>> scores, valid_mask = suff.compute(inputs, attributions)
>>> print(f"Mean sufficiency: {scores[valid_mask].mean():.3f}")
Mean sufficiency: 0.089
>>>
>>> # Get detailed scores per percentage
>>> detailed = suff.compute(
...     inputs, attributions, return_per_percentage=True
... )
>>> for pct, pct_scores in detailed.items():
...     print(f"  {pct}%: {pct_scores.mean():.3f}")
  1%: 0.234
  5%: 0.178
  10%: 0.089
  20%: 0.045
  50%: 0.012

compute(inputs, attributions, predicted_class=None, return_per_percentage=False, debug=False)

Compute metric across percentages.

Parameters:
  • inputs (Dict[str, Tensor]) – Model inputs (batch)

  • attributions (Dict[str, Tensor]) – Attribution scores matching input shapes (batch)

  • predicted_class (Optional[Tensor]) – Optional pre-computed predicted classes (batch)

  • return_per_percentage (bool) – If True, return dict mapping percentage -> scores. If False (default), return averaged score across percentages.

  • debug (bool) – If True, print detailed probability information (only used when return_per_percentage=True)

Returns:

If return_per_percentage=False (default): tuple of (metric_scores, valid_mask), where
  • metric_scores: scores averaged across percentages, shape (batch_size,)

  • valid_mask: binary mask indicating valid samples

If return_per_percentage=True: Dict[float, torch.Tensor] mapping percentage -> scores of shape (batch_size,). For binary classifiers, negative class samples have value 0.

Note

For binary classifiers, all samples are evaluated (both positive and negative predictions). For class 0 predictions, attributions are negated internally so that feature importance is measured relative to the predicted class.

Utility Functions

pyhealth.metrics.interpretability.utils.get_model_predictions(model, inputs, classifier_type, sample_filter=None, sample_class=None, target_class_idx=None)

Get model predictions, probabilities, and class-specific probabilities.

Parameters:
  • model (BaseModel) – PyHealth BaseModel that returns dict with ‘y_prob’ or ‘logit’

  • inputs (Dict[str, Tensor]) – Model inputs dict

  • classifier_type (str) – One of ‘binary’, ‘multiclass’, ‘multilabel’, ‘unknown’

  • target_class_idx (Optional[Tensor]) – Pre-computed target class indices. Passing them keeps ablated runs consistent with the original predictions. If None, they are computed from the model outputs.

  • sample_filter (Optional[Callable[[Tensor, str], Tensor]]) – A callable that classifies each sample for evaluation. Signature: (class_probs, classifier_type) -> sample_classes where class_probs has shape (batch_size,). For binary single-logit models this is P(class=1); otherwise it is the gathered target-class probability. sample_classes is a tensor of SampleClass values.

Returns:

  • y_prob: All class probabilities
    • Binary: shape (batch_size, 1), values are P(class=1)

    • Multiclass: shape (batch_size, num_classes)

  • target_class_idx: Target class indices, shape (batch_size,)

  • sample_classes: SampleClass values for each sample, shape (batch_size,)

Return type:

Tuple of (y_prob, target_class_idx, sample_classes)
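How target_class_idx and the class_probs passed to a sample_filter relate to y_prob can be sketched for the binary and multiclass cases described above. This is a hedged reconstruction, not the function's source: the helper name is hypothetical, the 0.5 decision threshold for binary models and the argmax rule for multiclass models are assumptions, and multilabel and unknown classifiers are deliberately left out.

```python
from typing import List, Tuple


def target_and_class_probs(
    y_prob: List[List[float]], classifier_type: str
) -> Tuple[List[int], List[float]]:
    """Derive target class indices and the per-sample filter probabilities.

    binary:     each row holds one value, P(class=1); the target is 1 when
                P(class=1) >= 0.5 (assumed), else 0; class_probs stays P(class=1).
    multiclass: the target is the argmax; class_probs is the gathered
                probability of that target class.
    """
    if classifier_type == "binary":
        p1 = [row[0] for row in y_prob]
        target = [1 if p >= 0.5 else 0 for p in p1]
        return target, p1
    if classifier_type == "multiclass":
        target = [max(range(len(row)), key=row.__getitem__) for row in y_prob]
        gathered = [row[t] for row, t in zip(y_prob, target)]
        return target, gathered
    raise ValueError(f"not covered by this sketch: {classifier_type!r}")


print(target_and_class_probs([[0.75], [0.25]], "binary"))       # ([1, 0], [0.75, 0.25])
print(target_and_class_probs([[0.1, 0.7, 0.2]], "multiclass"))  # ([1], [0.7])
```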