pyhealth.metrics.interpretability
Interpretability metrics evaluate the faithfulness of feature attribution methods by measuring how model predictions change when important features are removed or retained.
Evaluator
- class pyhealth.metrics.interpretability.evaluator.Evaluator(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', sample_filter=None, positive_threshold=None)
Bases: object
High-level interface for evaluating interpretations.
This class provides a convenient API for computing multiple interpretability metrics at once, both on individual batches and across entire datasets.
- Parameters:
model (BaseModel) – PyHealth BaseModel to evaluate.
percentages (List[float]) – List of percentages to evaluate at. Default: [1, 5, 10, 20, 50].
ablation_strategy (str) – How to ablate features. Options:
  - ‘zero’: Set ablated features to 0.
  - ‘mean’: Set ablated features to the feature mean across the batch.
  - ‘noise’: Add Gaussian noise to ablated features.
  Default: ‘zero’.
sample_filter (Optional[Callable[[Tensor, str], Tensor]]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,) and contains the class probability used for filtering. For binary single-logit models this is P(class=1). For multiclass/multilabel models it is the gathered target-class probability. sample_classes is a tensor of SampleClass values:
  - SampleClass.POSITIVE: evaluate with attributions as-is.
  - SampleClass.NEGATIVE: evaluate with negated attributions.
  - SampleClass.IGNORE: exclude from evaluation.
  If None, default_sample_filter is used.
positive_threshold (Optional[float]) – Deprecated: this parameter will be removed in a future release. Use sample_filter with threshold_sample_filter() instead. Threshold for the positive class in binary classification. Default: None.
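The sample_filter contract above can be illustrated with a small custom filter. This is a sketch and not PyHealth's default_sample_filter or threshold_sample_filter(). The function band_filter and its low/high thresholds are hypothetical. SampleClassStandIn is a local stand-in, because the integer values of the real SampleClass are not documented here.

For binary models the filter treats class_probs as P(class=1), as described above:
- confident positives are marked POSITIVE;
- confident negatives are marked NEGATIVE, so their attributions are negated;
- the uncertain middle band is marked IGNORE.

```python
from enum import IntEnum

import torch


class SampleClassStandIn(IntEnum):
    # Local stand-in for pyhealth's SampleClass; the real integer values may differ.
    POSITIVE = 0
    NEGATIVE = 1
    IGNORE = 2


def band_filter(class_probs: torch.Tensor, classifier_type: str,
                low: float = 0.3, high: float = 0.7) -> torch.Tensor:
    """Hypothetical sample filter: (class_probs, classifier_type) -> sample_classes."""
    # Start with every sample ignored, then opt confident samples back in.
    result = torch.full_like(class_probs, int(SampleClassStandIn.IGNORE),
                             dtype=torch.long)
    if classifier_type == "binary":
        # Binary single-logit models: class_probs is P(class=1).
        result[class_probs >= high] = int(SampleClassStandIn.POSITIVE)
        # Confident class-0 predictions: evaluate with negated attributions.
        result[class_probs <= low] = int(SampleClassStandIn.NEGATIVE)
    else:
        # Multiclass/multilabel: class_probs is the target-class probability.
        result[class_probs >= high] = int(SampleClassStandIn.POSITIVE)
    return result
```

To pass a filter like this as Evaluator(model, sample_filter=band_filter), you would first swap the stand-in enum for the real SampleClass.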
Examples
>>> import torch
>>> from pyhealth.models import StageNet
>>> from pyhealth.metrics.interpretability import Evaluator
>>> from pyhealth.metrics.interpretability.utils import (
...     SampleClass,
...     threshold_sample_filter,
... )
>>>
>>> # Assumes `model` is a trained PyHealth model (e.g. StageNet)
>>> # Initialize evaluator with default filter
>>> evaluator = Evaluator(model)
>>>
>>> # Initialize with custom filter that ignores low-confidence samples
>>> def confident_filter(class_probs, classifier_type):
...     batch_size = class_probs.shape[0]
...     result = torch.full(
...         (batch_size,), SampleClass.POSITIVE,
...         dtype=torch.long, device=class_probs.device,
...     )
...     result[class_probs < 0.6] = SampleClass.IGNORE
...     return result
>>> evaluator = Evaluator(model, sample_filter=confident_filter)
>>>
>>> # Evaluate on a single batch
>>> inputs = {'conditions': torch.randn(32, 50)}
>>> attributions = {'conditions': torch.randn(32, 50)}
>>> batch_results = evaluator.evaluate(inputs, attributions)
>>> for metric, (scores, mask) in batch_results.items():
...     print(f"{metric}: {scores[mask].mean():.4f}")
>>>
>>> # Evaluate across entire dataset with an attribution method
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> ig = IntegratedGradients(model)
>>> results = evaluator.evaluate_attribution(test_loader, ig)
>>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}")
>>> print(f"Sufficiency: {results['sufficiency']:.4f}")
- evaluate(inputs, attributions, metrics=['comprehensiveness', 'sufficiency'], return_per_percentage=False, debug=False)
Evaluate multiple metrics at once.
- Parameters:
inputs (Dict[str, Tensor]) – Model inputs (batch).
attributions (Dict[str, Tensor]) – Attribution scores (batch).
metrics (List[str]) – List of metrics to compute. Options: [“comprehensiveness”, “sufficiency”].
return_per_percentage (bool) – If True, return a per-percentage breakdown. If False (default), return averaged scores.
debug (bool) – If True, print detailed debug information (only used when return_per_percentage=True).
- Returns:
If return_per_percentage=False (default): a dictionary mapping metric name -> (scores, valid_mask), where
  scores: raw metric scores for all samples
  valid_mask: binary mask indicating which samples are valid to average
If return_per_percentage=True: a dictionary mapping metric name -> Dict[percentage -> scores]. Example: {‘comprehensiveness’: {10: tensor(…), 20: …}}
Note
For binary classifiers, all samples are evaluated (both positive and negative predictions). To average a metric, use scores[valid_mask].mean().
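The scores[valid_mask].mean() idiom from the note can be wrapped to summarize a whole evaluate() result in one call. This is a minimal sketch; masked_mean and summarize are hypothetical helpers and not part of PyHealth. The .bool() conversion is a precaution, since the docs only describe valid_mask as a binary mask. Returning NaN when no sample is valid is a choice made by this sketch.

```python
import math

import torch


def masked_mean(scores: torch.Tensor, valid_mask: torch.Tensor) -> float:
    """Average scores over valid samples only; NaN if every sample is masked out."""
    mask = valid_mask.bool()  # accept 0/1 as well as bool masks
    if not mask.any():
        return math.nan
    return scores[mask].mean().item()


def summarize(results: dict) -> dict:
    """Collapse {metric: (scores, valid_mask)} into {metric: mean score}."""
    return {name: masked_mean(s, m) for name, (s, m) in results.items()}
```

For example, summarize(evaluator.evaluate(inputs, attributions)) would return one float per requested metric.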
Examples
>>> # Default: averaged scores
>>> results = evaluator.evaluate(inputs, attributions)
>>> comp_scores, valid_mask = results['comprehensiveness']
>>> print(f"Mean: {comp_scores[valid_mask].mean():.4f}")
>>>
>>> # Per-percentage breakdown
>>> detailed = evaluator.evaluate(
...     inputs, attributions, return_per_percentage=True
... )
>>> comp_10_pct = detailed['comprehensiveness'][10]
>>> suff_20_pct = detailed['sufficiency'][20]
- evaluate_attribution(dataloader, method, metrics=['comprehensiveness', 'sufficiency'])
Evaluate an attribution method across an entire dataset.
This method computes the average faithfulness metrics for an attribution approach across all batches in a dataloader. It automatically computes attributions using the provided method and evaluates them.
- Parameters:
dataloader – PyTorch DataLoader with test/validation data. Should yield batches compatible with the model.
method – Attribution method instance (e.g., IntegratedGradients). Must implement the BaseInterpreter interface with an attribute(**data) method.
metrics (List[str]) – List of metrics to compute. Options: [“comprehensiveness”, “sufficiency”]. Default: both metrics.
- Returns:
Dictionary mapping metric names to their average scores across the entire dataset. Samples marked IGNORE by the configured sample_filter are excluded from the average. Example: {‘comprehensiveness’: 0.345, ‘sufficiency’: 0.123}
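The text above does not say whether the dataset average weights every valid sample equally or averages per-batch means. The sketch below assumes the former. RunningMaskedMean is a hypothetical accumulator, not a PyHealth class. It sums valid scores and counts valid samples across batches.

```python
import math

import torch


class RunningMaskedMean:
    """Sample-weighted mean over batches that skips masked-out (IGNORE) samples."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, scores: torch.Tensor, valid_mask: torch.Tensor) -> None:
        mask = valid_mask.bool()
        # Only valid samples contribute to the running sum and count.
        self.total += scores[mask].sum().item()
        self.count += int(mask.sum().item())

    def compute(self) -> float:
        # NaN when no sample in the whole dataset was valid.
        return self.total / self.count if self.count else math.nan
```

Consider two batches: one with valid scores [1, 2] (a third sample ignored) and one with a single valid score [4]. This accumulator gives 7/3 ≈ 2.33, while averaging the two batch means would give 2.75. The two choices diverge whenever batches contain different numbers of valid samples.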
Note
For binary classifiers, both positive and negative samples can be evaluated. Negative samples are handled by negating the attribution scores before top-feature selection, which makes the probability drop equivalent to the drop in confidence for class 0. Use sample_filter to include or exclude whichever subsets you want in the dataset average.
Examples
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> from pyhealth.metrics.interpretability import Evaluator
>>>
>>> # Initialize evaluator and attribution method
>>> evaluator = Evaluator(model)
>>> ig = IntegratedGradients(model, use_embeddings=True)
>>>
>>> # Evaluate across test set
>>> results = evaluator.evaluate_attribution(
...     test_loader, ig, metrics=["comprehensiveness"]
... )
>>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}")
>>>
>>> # Compare multiple methods
>>> from pyhealth.interpret.methods import CheferRelevance
>>> chefer = CheferRelevance(model)
>>> ig_results = evaluator.evaluate_attribution(test_loader, ig)
>>> chefer_results = evaluator.evaluate_attribution(
...     test_loader, chefer
... )
>>> print("Method Comparison:")
>>> print(f" IG Comp: {ig_results['comprehensiveness']:.4f}")
>>> print(
...     f" Chefer Comp: "
...     f"{chefer_results['comprehensiveness']:.4f}"
... )
Functional API
- pyhealth.metrics.interpretability.evaluator.evaluate_attribution(model, dataloader, method, metrics=['comprehensiveness', 'sufficiency'], percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', sample_filter=None, positive_threshold=None)
Evaluate an attribution method across a dataset (functional API).
This is a convenience function that wraps the Evaluator class for simple one-off evaluations. For multiple evaluations with the same configuration, consider using the Evaluator class directly for better efficiency.
- Parameters:
model (BaseModel) – PyHealth BaseModel to evaluate.
dataloader – PyTorch DataLoader with test/validation data.
method – Attribution method instance (e.g., IntegratedGradients). Must implement the BaseInterpreter interface.
metrics (List[str]) – List of metrics to compute. Options: [“comprehensiveness”, “sufficiency”]. Default: both metrics.
percentages (List[float]) – List of percentages to evaluate at. Default: [1, 5, 10, 20, 50].
ablation_strategy (str) – How to ablate features. Options:
  - ‘zero’: Set ablated features to 0.
  - ‘mean’: Set ablated features to the feature mean across the batch.
  - ‘noise’: Add Gaussian noise to ablated features.
  Default: ‘zero’.
sample_filter (Optional[Callable[[Tensor, str], Tensor]]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,) and contains the probability for the predicted class (sigmoid/softmax output with the target class already applied). sample_classes is a tensor of SampleClass values:
  - SampleClass.POSITIVE: evaluate with attributions as-is.
  - SampleClass.NEGATIVE: evaluate with negated attributions.
  - SampleClass.IGNORE: exclude from evaluation.
  If None, default_sample_filter is used.
positive_threshold (Optional[float]) – Deprecated: this parameter will be removed in a future release. Use sample_filter with threshold_sample_filter() instead. Threshold for the positive class in binary classification. Default: None.
- Returns:
Dictionary mapping metric names to their average scores across the entire dataset. Averaging uses mask-based filtering to exclude IGNORE samples. Example: {‘comprehensiveness’: 0.345, ‘sufficiency’: 0.123}
Examples
>>> import torch
>>> from pyhealth.interpret.methods import IntegratedGradients
>>> from pyhealth.metrics.interpretability import (
...     evaluate_attribution
... )
>>> from pyhealth.metrics.interpretability.utils import SampleClass
>>>
>>> # Simple one-off evaluation
>>> ig = IntegratedGradients(model, use_embeddings=True)
>>> results = evaluate_attribution(
...     model, test_loader, ig,
...     metrics=["comprehensiveness"],
...     percentages=[10, 20, 50]
... )
>>> print(f"Comprehensiveness: {results['comprehensiveness']:.4f}")
>>>
>>> # Custom filter to ignore uncertain predictions
>>> def ignore_uncertain(class_probs, classifier_type):
...     batch_size = class_probs.shape[0]
...     result = torch.full(
...         (batch_size,), SampleClass.POSITIVE,
...         dtype=torch.long, device=class_probs.device,
...     )
...     result[class_probs < 0.7] = SampleClass.IGNORE
...     return result
>>> results = evaluate_attribution(
...     model, test_loader, ig,
...     sample_filter=ignore_uncertain,
... )
>>>
>>> # For comparing multiple methods efficiently, use Evaluator:
>>> from pyhealth.interpret.methods import CheferRelevance
>>> from pyhealth.metrics.interpretability import Evaluator
>>> chefer = CheferRelevance(model)
>>> evaluator = Evaluator(model, percentages=[10, 20, 50])
>>> ig_results = evaluator.evaluate_attribution(test_loader, ig)
>>> chefer_results = evaluator.evaluate_attribution(
...     test_loader, chefer
... )
Removal-Based Metrics
Base Class
- class pyhealth.metrics.interpretability.base.RemovalBasedMetric(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', *, sample_filter)
Bases: ABC
Abstract base class for removal-based interpretability metrics.
This class provides common functionality for computing faithfulness metrics by removing or retaining features based on their importance scores.
- Parameters:
model (BaseModel) – PyHealth BaseModel that accepts **kwargs and returns a dict with ‘y_prob’ or ‘logit’.
percentages (List[float]) – List of percentages to evaluate at. Default: [1, 5, 10, 20, 50].
ablation_strategy (str) – How to ablate features. Options:
  - ‘zero’: Set ablated features to 0.
  - ‘mean’: Set ablated features to the feature mean across the batch.
  - ‘noise’: Add Gaussian noise to ablated features.
  Default: ‘zero’.
sample_filter (Callable[[Tensor, str], Tensor]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,) and contains the probability for the predicted class (sigmoid/softmax output with the target class already applied). sample_classes is a tensor of SampleClass values:
  - SampleClass.POSITIVE: evaluate with attributions as-is.
  - SampleClass.NEGATIVE: evaluate with negated attributions.
  - SampleClass.IGNORE: exclude from evaluation.
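The three ablation_strategy options can be sketched on a dense (batch, num_features) tensor. The function ablate is hypothetical and not the PyHealth implementation. The noise_std parameter is also an assumption, since the docs only say Gaussian noise is added and do not give its scale. The 'mean' option here takes the per-feature mean over the batch dimension of the original input.

```python
import torch


def ablate(x: torch.Tensor, mask: torch.Tensor, strategy: str = "zero",
           noise_std: float = 1.0) -> torch.Tensor:
    """Return a copy of x with the positions where mask is True ablated."""
    x = x.float()
    if strategy == "zero":
        # 'zero': replace ablated entries with 0.
        replacement = torch.zeros_like(x)
    elif strategy == "mean":
        # 'mean': replace ablated entries with the per-feature mean across the batch.
        replacement = x.mean(dim=0, keepdim=True).expand_as(x)
    elif strategy == "noise":
        # 'noise': add Gaussian noise to ablated entries (noise_std is assumed).
        replacement = x + noise_std * torch.randn_like(x)
    else:
        raise ValueError(f"unknown ablation strategy: {strategy!r}")
    # Keep unmasked entries unchanged; torch.where returns a new tensor.
    return torch.where(mask, replacement, x)
```

With x = [[1, 2], [3, 4]] and a mask on the diagonal, 'zero' gives [[0, 2], [3, 0]] and 'mean' gives [[2, 2], [3, 3]].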
- compute(inputs, attributions, predicted_class=None, return_per_percentage=False, debug=False)
Compute metric across percentages.
- Parameters:
inputs (Dict[str, Tensor]) – Model inputs (batch).
attributions (Dict[str, Tensor]) – Attribution scores matching input shapes (batch).
predicted_class (Optional[Tensor]) – Optional pre-computed predicted classes (batch).
return_per_percentage (bool) – If True, return a dict mapping percentage -> scores. If False (default), return the score averaged across percentages.
debug (bool) – If True, print detailed probability information (only used when return_per_percentage=True).
- Returns:
If return_per_percentage=False (default): a tuple (metric_scores, valid_mask), where
  metric_scores: average scores across percentages, shape (batch_size,)
  valid_mask: binary mask indicating valid samples
If return_per_percentage=True: Dict[float, torch.Tensor] mapping percentage -> scores of shape (batch_size,). For binary classifiers, negative class samples have value 0.
Note
For binary classifiers, all samples are evaluated (both positive and negative predictions). For class 0 predictions, attributions are negated internally so that feature importance is measured relative to the predicted class.
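The effect of negating attributions for class 0 predictions can be seen in isolation. Flipping the sign before ranking selects the features that pushed the model toward class 0, not class 1. The function top_fraction_mask is hypothetical. The sketch assumes attributions flattened to (batch, num_features) and a ceiling rule for turning q% into a feature count, and neither is confirmed by the docs.

```python
import math

import torch


def top_fraction_mask(attr: torch.Tensor, percentage: float,
                      negate: torch.Tensor) -> torch.Tensor:
    """Boolean (batch, num_features) mask of each sample's top `percentage`% features.

    negate is a (batch,) bool tensor; True flips that sample's attributions
    (e.g. class-0 predictions) before ranking.
    """
    signed = torch.where(negate.unsqueeze(1), -attr, attr)
    # Assumed rounding: at least one feature, ceil(num_features * q / 100).
    k = max(1, math.ceil(attr.shape[1] * percentage / 100))
    top = signed.topk(k, dim=1).indices
    rows = torch.arange(attr.shape[0]).unsqueeze(1)
    mask = torch.zeros_like(attr, dtype=torch.bool)
    mask[rows, top] = True
    return mask
```

Take two samples with identical attributions [0.9, -0.8, 0.1, 0.0]. At 25% the first (not negated) selects feature 0. The second (negated) selects feature 1, the strongest evidence for class 0.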
Comprehensiveness
- class pyhealth.metrics.interpretability.comprehensiveness.ComprehensivenessMetric(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', *, sample_filter)
Bases: RemovalBasedMetric
Comprehensiveness metric for interpretability evaluation.
Measures the drop in predicted class probability when important features are REMOVED (ablated). Higher scores indicate more faithful interpretations.
- The metric is computed as:
$$\mathrm{COMP} = \frac{1}{|B|} \sum_{q \in B} \left[ p_{c(x)}(x) - p_{c(x)}\left(x \setminus x_{:q\%}\right) \right]$$
- Where:
$x$ is the original input
$x_{:q\%}$ are the top q% most important features
$x \setminus x_{:q\%}$ is the input with the top q% features removed (ablated)
$p_{c(x)}(\cdot)$ is the predicted probability for the original predicted class $c(x)$
$B$ is the set of percentages (default: {1, 5, 10, 20, 50})
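The COMP formula can be traced on a dense tensor with any probability function. This sketch does not reproduce how ComprehensivenessMetric handles PyHealth's dict inputs, sample filters or attribution negation. The names comprehensiveness and predict_proba are hypothetical. Zero ablation and ceiling rounding of q% are assumptions.

```python
import math
from typing import Callable, Sequence

import torch


def comprehensiveness(
    predict_proba: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    attr: torch.Tensor,
    percentages: Sequence[float] = (1, 5, 10, 20, 50),
) -> torch.Tensor:
    """Per-sample COMP for a (batch, features) input, using zero ablation."""
    probs = predict_proba(x)                        # (batch, num_classes)
    c = probs.argmax(dim=1, keepdim=True)           # original predicted class c(x)
    p_orig = probs.gather(1, c).squeeze(1)          # p_{c(x)}(x)
    rows = torch.arange(x.shape[0]).unsqueeze(1)
    drops = []
    for q in percentages:
        # Number of features in the top q% (ceil rounding is an assumption).
        k = max(1, math.ceil(x.shape[1] * q / 100))
        top = attr.topk(k, dim=1).indices           # positions of x_{:q%}
        x_removed = x.clone()
        x_removed[rows, top] = 0.0                  # x \ x_{:q%}
        p_removed = predict_proba(x_removed).gather(1, c).squeeze(1)
        drops.append(p_orig - p_removed)
    # Average the probability drop over the |B| percentages.
    return torch.stack(drops).mean(dim=0)
```

Consider a toy two-class model whose class-1 logit is the sum of the features, applied to x = [2, 0, 0, 0] with feature 0 ranked first. Removing the top features drops P(class=1) from sigmoid(2) ≈ 0.881 to 0.5, so COMP ≈ 0.381. A large drop means the attribution found the features the prediction depends on.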
Examples
>>> import torch
>>> from pyhealth.models import MLP
>>> from pyhealth.metrics.interpretability import (
...     ComprehensivenessMetric
... )
>>> from pyhealth.metrics.interpretability.utils import SampleClass
>>>
>>> # Assume we have a trained model
>>> model = MLP(dataset=dataset)
>>>
>>> # sample_filter is a required keyword-only argument; this one
>>> # evaluates every sample with its attributions as-is
>>> def keep_all(class_probs, classifier_type):
...     return torch.full(
...         class_probs.shape, SampleClass.POSITIVE,
...         dtype=torch.long, device=class_probs.device,
...     )
>>> comp = ComprehensivenessMetric(model, sample_filter=keep_all)
>>>
>>> # Prepare inputs and attributions
>>> inputs = {'conditions': torch.randn(32, 50)}
>>> attributions = {'conditions': torch.randn(32, 50)}
>>>
>>> # Compute metric
>>> scores, valid_mask = comp.compute(inputs, attributions)
>>> print(f"Mean comprehensiveness: {scores[valid_mask].mean():.3f}")
Mean comprehensiveness: 0.234
>>>
>>> # Get detailed scores per percentage
>>> detailed = comp.compute(
...     inputs, attributions, return_per_percentage=True
... )
>>> for pct, scores in detailed.items():
...     print(f" {pct}%: {scores.mean():.3f}")
 1%: 0.045
 5%: 0.123
 10%: 0.234
 20%: 0.345
 50%: 0.456
- compute(inputs, attributions, predicted_class=None, return_per_percentage=False, debug=False)
Compute metric across percentages. Inherited from RemovalBasedMetric; parameters, return values, and binary-classifier behavior are as documented for RemovalBasedMetric.compute().
Sufficiency
- class pyhealth.metrics.interpretability.sufficiency.SufficiencyMetric(model, percentages=[1, 5, 10, 20, 50], ablation_strategy='zero', *, sample_filter)
Bases: RemovalBasedMetric
Sufficiency metric for interpretability evaluation.
Measures the drop in predicted class probability when ONLY important features are KEPT (all others removed). Lower scores indicate more faithful interpretations.
- The metric is computed as:
$$\mathrm{SUFF} = \frac{1}{|B|} \sum_{q \in B} \left[ p_{c(x)}(x) - p_{c(x)}\left(x_{:q\%}\right) \right]$$
- Where:
$x$ is the original input
$x_{:q\%}$ are the top q% most important features (all others removed)
$p_{c(x)}(\cdot)$ is the predicted probability for the original predicted class $c(x)$
$B$ is the set of percentages (default: {1, 5, 10, 20, 50})
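SUFF mirrors COMP but keeps only the top q% features and zeroes the rest. The sketch below uses the same hypothetical conventions as a dense-tensor illustration: sufficiency and predict_proba are not PyHealth names, and zero ablation and ceiling rounding are assumed. It is not SufficiencyMetric's implementation.

```python
import math
from typing import Callable, Sequence

import torch


def sufficiency(
    predict_proba: Callable[[torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    attr: torch.Tensor,
    percentages: Sequence[float] = (1, 5, 10, 20, 50),
) -> torch.Tensor:
    """Per-sample SUFF for a (batch, features) input, zeroing non-top features."""
    probs = predict_proba(x)
    c = probs.argmax(dim=1, keepdim=True)           # original predicted class c(x)
    p_orig = probs.gather(1, c).squeeze(1)          # p_{c(x)}(x)
    rows = torch.arange(x.shape[0]).unsqueeze(1)
    drops = []
    for q in percentages:
        k = max(1, math.ceil(x.shape[1] * q / 100))
        top = attr.topk(k, dim=1).indices
        keep = torch.zeros_like(x, dtype=torch.bool)
        keep[rows, top] = True
        x_kept = torch.where(keep, x, torch.zeros_like(x))  # x_{:q%} only
        p_kept = predict_proba(x_kept).gather(1, c).squeeze(1)
        drops.append(p_orig - p_kept)
    # Small values mean the kept features alone nearly reproduce the prediction.
    return torch.stack(drops).mean(dim=0)
```

Use the same toy model as for COMP (class-1 logit equals the feature sum) with x = [2, 1, 0, 0]. Keeping only the top 25% gives sigmoid(2) instead of sigmoid(3). Keeping the top 50% loses nothing, so SUFF = (sigmoid(3) - sigmoid(2)) / 2 ≈ 0.036.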
Examples
>>> import torch
>>> from pyhealth.models import MLP
>>> from pyhealth.metrics.interpretability import SufficiencyMetric
>>> from pyhealth.metrics.interpretability.utils import SampleClass
>>>
>>> # Assume we have a trained model
>>> model = MLP(dataset=dataset)
>>>
>>> # sample_filter is a required keyword-only argument; this one
>>> # evaluates every sample with its attributions as-is
>>> def keep_all(class_probs, classifier_type):
...     return torch.full(
...         class_probs.shape, SampleClass.POSITIVE,
...         dtype=torch.long, device=class_probs.device,
...     )
>>> suff = SufficiencyMetric(model, sample_filter=keep_all)
>>>
>>> # Prepare inputs and attributions
>>> inputs = {'conditions': torch.randn(32, 50)}
>>> attributions = {'conditions': torch.randn(32, 50)}
>>>
>>> # Compute metric
>>> scores, valid_mask = suff.compute(inputs, attributions)
>>> print(f"Mean sufficiency: {scores[valid_mask].mean():.3f}")
Mean sufficiency: 0.089
>>>
>>> # Get detailed scores per percentage
>>> detailed = suff.compute(
...     inputs, attributions, return_per_percentage=True
... )
>>> for pct, scores in detailed.items():
...     print(f" {pct}%: {scores.mean():.3f}")
 1%: 0.234
 5%: 0.178
 10%: 0.089
 20%: 0.045
 50%: 0.012
- compute(inputs, attributions, predicted_class=None, return_per_percentage=False, debug=False)
Compute metric across percentages. Inherited from RemovalBasedMetric; parameters, return values, and binary-classifier behavior are as documented for RemovalBasedMetric.compute().
Utility Functions
- pyhealth.metrics.interpretability.utils.get_model_predictions(model, inputs, classifier_type, sample_filter=None, sample_class=None, target_class_idx=None)
Get model predictions, probabilities, and class-specific probabilities.
- Parameters:
model (BaseModel) – PyHealth BaseModel that returns a dict with ‘y_prob’ or ‘logit’.
classifier_type (str) – One of ‘binary’, ‘multiclass’, ‘multilabel’, ‘unknown’.
target_class_idx (Optional[Tensor]) – Optional pre-computed target class indices. Passing these keeps ablated runs consistent with the original predictions. If None, they are computed from the model outputs.
sample_filter (Optional[Callable[[Tensor, str], Tensor]]) – A callable that classifies each sample for evaluation, with signature (class_probs, classifier_type) -> sample_classes. class_probs has shape (batch_size,). For binary single-logit models it is P(class=1); otherwise it is the gathered target-class probability. sample_classes is a tensor of SampleClass values.
- Returns:
y_prob: all class probabilities
  Binary: shape (batch_size, 1), values are P(class=1)
  Multiclass: shape (batch_size, num_classes)
target_class_idx: target class indices, shape (batch_size,)
sample_classes: SampleClass values for each sample, shape (batch_size,)
- Return type:
Tuple of (y_prob, target_class_idx, sample_classes)
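The shapes above suggest how class_probs for the sample filter could be derived from y_prob. The sketch relies on three assumptions:
- binary predictions use a 0.5 threshold on P(class=1);
- otherwise the target class defaults to the argmax column;
- multilabel outputs are treated like multiclass.

target_class_probs is a hypothetical helper, not get_model_predictions itself. When target_class_idx is supplied it is reused, so ablated runs are scored against the original classes.

```python
from typing import Optional, Tuple

import torch


def target_class_probs(
    y_prob: torch.Tensor,
    classifier_type: str,
    target_class_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (target_class_idx, class_probs) for the shapes described above."""
    if classifier_type == "binary":
        # y_prob is (batch_size, 1) holding P(class=1).
        p1 = y_prob.squeeze(-1)
        if target_class_idx is None:
            # Assumed 0.5 decision threshold for the predicted class.
            target_class_idx = (p1 >= 0.5).long()
        # The sample filter receives P(class=1) for binary single-logit models.
        return target_class_idx, p1
    if target_class_idx is None:
        # Assumed: the target class is the highest-probability column.
        target_class_idx = y_prob.argmax(dim=1)
    # Gather each sample's target-class probability.
    class_probs = y_prob.gather(1, target_class_idx.unsqueeze(1)).squeeze(1)
    return target_class_idx, class_probs
```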