Source code for pyhealth.metrics.interpretability.base

"""Base class for removal-based interpretability metrics.

This module provides the abstract base class for removal-based faithfulness
metrics like Comprehensiveness and Sufficiency.
"""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import torch

from pyhealth.models import BaseModel

from .utils import (
    SampleClass,
    SampleFilterFn,
    get_model_predictions,
)


class RemovalBasedMetric(ABC):
    """Abstract base class for removal-based interpretability metrics.

    This class provides common functionality for computing faithfulness
    metrics by removing or retaining features based on their importance
    scores.

    Args:
        model: PyHealth BaseModel that accepts **kwargs and returns dict with
            'y_prob' or 'logit'.
        percentages: List of percentages to evaluate at.
            Default: [1, 5, 10, 20, 50]. The list is copied, so later changes
            to the caller's list do not affect the metric.
        ablation_strategy: How to ablate features. Options:
            - 'zero': Set ablated features to 0
            - 'mean': Set ablated features to feature mean across batch
            - 'noise': Add Gaussian noise to ablated features
            Default: 'zero'.
        sample_filter: A callable that classifies each sample for evaluation.
            Signature: (class_probs, classifier_type) -> sample_classes
            where class_probs has shape (batch_size,) and contains the
            probability for the predicted class (sigmoid/softmax output with
            target class already applied), and sample_classes is a tensor of
            SampleClass values.
            - SampleClass.POSITIVE: evaluate with attributions as-is
            - SampleClass.NEGATIVE: evaluate with negated attributions
            - SampleClass.IGNORE: exclude from evaluation
    """

    # Immutable default so that instances can never share (and accidentally
    # mutate) one default list object.
    _DEFAULT_PERCENTAGES = (1, 5, 10, 20, 50)

    # Integer dtypes treated as discrete codes (torch.int == torch.int32 and
    # torch.long == torch.int64).
    _DISCRETE_DTYPES = (torch.int32, torch.int64)

    def __init__(
        self,
        model: "BaseModel",
        percentages: Optional[List[float]] = None,
        ablation_strategy: str = "zero",
        *,
        sample_filter: "SampleFilterFn",
    ):
        self.model = model
        # Always store a private copy of the percentages.
        if percentages is None:
            self.percentages = list(self._DEFAULT_PERCENTAGES)
        else:
            self.percentages = list(percentages)
        self.ablation_strategy = ablation_strategy
        self._sample_filter = sample_filter
        self.model.eval()

        # Detect classifier type from model
        self._detect_classifier_type()

    def _mark_unknown(self, *messages: str) -> None:
        """Record an undetectable classifier type and print the messages."""
        self.classifier_type = "unknown"
        self.num_classes = None
        for message in messages:
            print(message)

    def _lookup_output_size(self, label_key) -> Optional[int]:
        """Return ``processor.size()`` for ``label_key``, or None if missing.

        None is returned when the dataset has no ``output_processors``
        mapping or when that mapping has no processor for ``label_key``.
        """
        processors = getattr(self.model.dataset, "output_processors", None)
        if processors is None:
            return None
        processor = processors.get(label_key)
        return processor.size() if processor is not None else None

    def _detect_classifier_type(self):
        """Detect classifier type from model's output schema.

        Sets self.classifier_type and self.num_classes based on model.

        Expected model types:
            - PyHealth BaseModel with dataset.output_schema
            - Custom models following same interface

        Classifier types:
            - binary: Binary classification, output [batch, 1]
            - multiclass: Multi-class classification, output [batch, C]
            - multilabel: Multi-label classification, output [batch, L]

        When detection fails, classifier_type is set to "unknown" and
        num_classes to None; a warning is printed instead of raising.
        """
        # Check if model is a PyHealth BaseModel with dataset
        if not hasattr(self.model, "dataset") or not hasattr(
            self.model.dataset, "output_schema"
        ):
            self._mark_unknown(
                "[RemovalBasedMetric] WARNING: Cannot detect type",
                " - Model missing dataset.output_schema",
                " - Expected: PyHealth BaseModel or compatible",
            )
            return

        # Get output schema
        output_schema = self.model.dataset.output_schema
        if not output_schema:
            self._mark_unknown("[RemovalBasedMetric] WARNING: Empty output_schema")
            return

        # Use first label key (most common case)
        label_key = next(iter(output_schema))
        schema_entry = output_schema[label_key]

        # Use BaseModel's _resolve_mode if available, else manual check
        if hasattr(self.model, "_resolve_mode"):
            try:
                mode = self.model._resolve_mode(schema_entry)
            except Exception as e:
                # Deliberately best-effort: any failure to resolve the mode
                # degrades to "unknown" rather than aborting construction.
                self._mark_unknown(f"[RemovalBasedMetric] WARNING: {e}")
                return
        elif isinstance(schema_entry, str):
            # Fallback: schema entry given as a string
            mode = schema_entry.lower()
        elif hasattr(schema_entry, "__name__"):
            # Fallback: schema entry given as a class/function
            mode = schema_entry.__name__.lower()
        else:
            mode = "unknown"

        # Set classifier type based on mode
        if mode == "binary":
            self.classifier_type = "binary"
            self.num_classes = 2
            print("[RemovalBasedMetric] Detected BINARY classifier")
            print(" - Output shape: [batch, 1] with P(class=1)")
            print(" - Evaluates both positive and negative predictions")
        elif mode == "multiclass":
            self.classifier_type = "multiclass"
            # Get num_classes from processor
            self.num_classes = self._lookup_output_size(label_key)
            print("[RemovalBasedMetric] Detected MULTICLASS classifier")
            print(f" - Num classes: {self.num_classes}")
            print(f" - Output shape: [batch, {self.num_classes}]")
        elif mode == "multilabel":
            self.classifier_type = "multilabel"
            # Get num_labels from processor
            self.num_classes = self._lookup_output_size(label_key)
            print("[RemovalBasedMetric] Detected MULTILABEL classifier")
            print(f" - Num labels: {self.num_classes}")
            print(f" - Output shape: [batch, {self.num_classes}]")
            print(" - NOTE: Multilabel support not fully tested")
        else:
            self._mark_unknown(
                "[RemovalBasedMetric] WARNING: Unknown classifier",
                f" - Mode detected: {mode}",
            )

    @abstractmethod
    def _create_ablated_inputs(
        self,
        inputs: Dict[str, torch.Tensor],
        masks: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Create ablated version of inputs based on masks.

        Args:
            inputs: Original model inputs
            masks: Binary masks (1=keep/remove depending on metric)

        Returns:
            Ablated inputs with same structure as inputs
        """
        pass

    def _compute_threshold_and_mask(
        self,
        attributions: Dict[str, torch.Tensor],
        percentage: float,
    ) -> Dict[str, torch.Tensor]:
        """Compute binary masks for top-percentage features using quantiles.

        Masks are computed independently for every sample (first dimension)
        of every attribution tensor.

        Args:
            attributions: Attribution scores
            percentage: Percentage of features to select (e.g., 10 for top 10%)

        Returns:
            Dictionary mapping feature_key to binary mask (1 for
            top-percentage), with the same shape as the attribution tensor.
        """
        # Convert percentage to quantile (e.g., top 10% = 90th percentile)
        # percentage = 10 means top 10%, so quantile = 1 - 0.10 = 0.90
        quantile = 1.0 - (percentage / 100.0)

        masks = {}
        for key, attr in attributions.items():
            mask = torch.zeros_like(attr)

            for i in range(attr.shape[0]):
                attr_sample = attr[i].flatten()
                num_features = attr_sample.numel()

                # A sample without features has nothing to select, and
                # min()/max() would raise on an empty tensor.
                if num_features == 0:
                    continue

                if attr_sample.min() == attr_sample.max():
                    # Edge case: all values are the same, so a threshold
                    # would select everything. Select approximately the
                    # right percentage of features (the leading ones).
                    num_to_select = max(1, int(num_features * percentage / 100.0))
                    mask_flat = torch.zeros_like(attr_sample)
                    mask_flat[:num_to_select] = 1.0
                else:
                    # "higher" interpolation: when the quantile falls between
                    # two values the larger one becomes the threshold; every
                    # value tied with the threshold is selected (>=).
                    threshold = torch.quantile(
                        attr_sample, quantile, interpolation="higher"
                    )
                    mask_flat = (attr_sample >= threshold).float()

                mask[i] = mask_flat.reshape(attr[i].shape)

            masks[key] = mask

        return masks

    def _ablate_continuous(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply ``self.ablation_strategy`` to a continuous tensor.

        Args:
            x: Values to ablate.
            mask: Binary mask, 1 where the value must be ablated.

        Returns:
            New tensor with masked positions replaced per the strategy.

        Raises:
            ValueError: If ``self.ablation_strategy`` is not one of
                'zero', 'mean' or 'noise'.
        """
        if self.ablation_strategy == "zero":
            # Set ablated features to 0
            return x * (1 - mask)
        if self.ablation_strategy == "mean":
            # Set ablated features to mean across batch
            x_mean = x.mean(dim=0, keepdim=True)
            return x * (1 - mask) + x_mean * mask
        if self.ablation_strategy == "noise":
            # Replace ablated features with Gaussian noise scaled by the
            # standard deviation of the whole tensor
            noise = torch.randn_like(x) * x.std()
            return x * (1 - mask) + noise * mask
        raise ValueError(f"Unknown ablation strategy: {self.ablation_strategy}")

    @staticmethod
    def _ablate_discrete(x_values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Zero out masked integer codes, never emptying a whole sample.

        Args:
            x_values: Integer-coded values (0 is treated as padding).
            mask: Binary mask, 1 where the code must be ablated.

        Returns:
            New tensor in which masked codes are 0. If that would leave a
            sample with sum 0 although it had non-zero codes, its first
            non-zero code is restored.
        """
        # Where mask=1 (ablate): set to 0 (padding index)
        # Where mask=0 (keep): preserve original value
        ablated_values = x_values * (1 - mask).long()

        # Safety: prevent complete ablation of sequences.
        # Complete ablation (all zeros) causes issues in StageNet:
        #   - All embeddings become zero (padding_idx=0)
        #   - Mask becomes all zeros
        #   - get_last_visit() tries to index with -1
        # Solution: keep at least one non-zero element
        for b in range(ablated_values.shape[0]):
            if ablated_values[b].sum() == 0:
                non_zero_mask = x_values[b] != 0
                if non_zero_mask.any():
                    # Keep first non-zero element
                    first_idx = tuple(non_zero_mask.nonzero()[0])
                    ablated_values[b][first_idx] = x_values[b][first_idx]
        return ablated_values

    def _apply_ablation(
        self,
        inputs: Dict[str, torch.Tensor],
        masks: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Apply ablation strategy to create modified inputs.

        Handling per input value:
            - tuple with >= 2 items (e.g., StageNet's (time, values)): only
              item 1 is ablated; integer values are zeroed as discrete codes,
              other values follow the ablation strategy. Without a mask the
              tuple is passed through unchanged.
            - shorter tuple or non-tensor value: passed through unchanged.
            - tensor without a mask: cloned.
            - tensor with a mask: ablated per the ablation strategy.

        Args:
            inputs: Original inputs
            masks: Binary masks indicating which features to ablate (1=ablate)

        Returns:
            Modified inputs with ablation applied

        Raises:
            ValueError: If a continuous value must be ablated and
                ``self.ablation_strategy`` is unknown.
        """
        ablated_inputs = {}

        for key, x in inputs.items():
            # Handle tuple inputs (e.g., StageNet's (time, values) format)
            if isinstance(x, tuple):
                if len(x) < 2 or key not in masks:
                    # Unexpected tuple length or no mask: keep unchanged
                    ablated_inputs[key] = x
                    continue

                time_info, x_values = x[0], x[1]
                mask = masks[key]

                if x_values.dtype in self._DISCRETE_DTYPES:
                    ablated_values = self._ablate_discrete(x_values, mask)
                else:
                    ablated_values = self._ablate_continuous(x_values, mask)

                # Reconstruct tuple with ablated values
                ablated_inputs[key] = (time_info, ablated_values) + x[2:]
                continue

            # Skip non-tensor, non-tuple inputs (like lists, strings)
            if not isinstance(x, torch.Tensor):
                ablated_inputs[key] = x
                continue

            # If no mask for this key, keep unchanged
            if key not in masks:
                ablated_inputs[key] = x.clone()
                continue

            ablated_inputs[key] = self._ablate_continuous(x, masks[key])

        return ablated_inputs

    @staticmethod
    def _signed_probs(probs: torch.Tensor, neg_mask: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``probs`` negated where ``neg_mask`` is True.

        The input tensor is left untouched, so the result can be computed
        repeatedly without the sign flips accumulating.
        """
        signed = probs.clone()
        signed[neg_mask] = -signed[neg_mask]
        return signed

    def _print_debug_header(self, y_probs, target_class_idx, sample_class):
        """Print the batch-level debug summary emitted once per compute()."""
        print(f"\n{'='*80}")
        print(f"[DEBUG] compute for {self.__class__.__name__}")
        print(f"{'='*80}")
        print(f"Batch size: {y_probs.shape[0]}")
        print(f"Classifier type: {self.classifier_type}")

        if self.classifier_type == "binary":
            num_pos = (sample_class == SampleClass.POSITIVE).sum().item()
            num_neg = (sample_class == SampleClass.NEGATIVE).sum().item()
            print(f"Positive class samples: {num_pos}")
            print(f"Negative class samples: {num_neg}")
            print("NOTE: Evaluating BOTH positive and negative predictions")

        print("\nOriginal probs for predicted class:")
        for i, prob in enumerate(y_probs):
            cls = target_class_idx[i].item()
            print(f" Sample {i} [class={cls}]: {prob.item():.6f}")

    def _print_debug_percentage(
        self,
        percentage,
        ablated_probs,
        original_class_probs,
        ablated_class_probs,
        prob_drop,
        target_class_idx,
        val_mask,
    ):
        """Print the per-percentage debug details for one ablation level."""
        print(f"\n{'-'*80}")
        print(f"Percentage: {percentage}%")
        print(f"{'-'*80}")
        print(f"Ablated probs shape: {ablated_probs.shape}")

        if self.classifier_type == "binary":
            print("\nAblated probabilities P(class=1):")
            for i, prob in enumerate(ablated_probs):
                cls = target_class_idx[i].item()
                print(f" Sample {i} [class={cls}]: {prob.item():.6f}")
        else:
            print("\nAblated probabilities (all classes):")
            for i, probs in enumerate(ablated_probs):
                print(f" Sample {i}: {probs.tolist()}")

        print("\nProbability drops (original - ablated):")
        for i, drop in enumerate(prob_drop):
            orig = original_class_probs[i].item()
            abl = ablated_class_probs[i].item()
            cls = target_class_idx[i].item()
            print(
                f" Sample {i} [class={cls}]: {drop.item():.6f} "
                f"({orig:.6f} - {abl:.6f})"
            )

        # Check for unexpected negative values. A dedicated local name is
        # used so the NEGATIVE-sample mask of compute() is never clobbered.
        negative_drops = prob_drop[val_mask] < 0
        if negative_drops.any():
            neg_count = negative_drops.sum().item()
            print(f"\n⚠ WARNING: {neg_count} negative detected!")
            print(" Negative values mean ablation INCREASED " "confidence,")
            print(" which suggests:")
            if self.__class__.__name__ == "ComprehensivenessMetric":
                print(" - Removed features were HARMING " "predictions")
                print(" - Attribution may have opposite signs")
            else:  # SufficiencyMetric
                print(" - Kept features performed WORSE than full")
                print(" - Attribution quality may be poor")

    def compute(
        self,
        inputs: Dict[str, torch.Tensor],
        attributions: Dict[str, torch.Tensor],
        predicted_class: Optional[torch.Tensor] = None,
        return_per_percentage: bool = False,
        debug: bool = False,
    ):
        """Compute metric across percentages.

        Args:
            inputs: Model inputs (batch)
            attributions: Attribution scores matching input shapes (batch)
            predicted_class: Optional pre-computed predicted classes (batch).
                NOTE(review): accepted for interface compatibility but not
                read by this method; the target classes always come from the
                model's predictions on ``inputs``.
            return_per_percentage: If True, return dict mapping
                percentage -> scores. If False (default), return averaged
                score across percentages.
            debug: If True, print detailed probability information
                (only used when return_per_percentage=True)

        Returns:
            If return_per_percentage=False (default):
                Tuple of (metric_scores, valid_mask):
                - metric_scores: Average scores across percentages,
                  shape (batch_size,)
                - valid_mask: Boolean mask, False for samples classified as
                  SampleClass.IGNORE
            If return_per_percentage=True:
                Dict[float, torch.Tensor]: Maps percentage -> scores
                (batch_size,). Samples classified as SampleClass.IGNORE have
                value 0.

        Note:
            For samples classified as SampleClass.NEGATIVE, attributions are
            negated internally so that feature importance is measured
            relative to the predicted class, and the sign of the probabilities
            is flipped before taking the difference.
        """
        # Get original predictions (returns 3 values)
        y_probs, target_class_idx, sample_class = get_model_predictions(
            model=self.model,
            inputs=inputs,
            classifier_type=self.classifier_type,
            sample_filter=self._sample_filter,
        )

        batch_size = y_probs.shape[0]

        # Validity mask: IGNORE samples excluded
        val_mask = sample_class != SampleClass.IGNORE

        # For NEGATIVE samples, negate attributions so that
        # "top features" become those most important for the predicted
        # class (features with low class-1 attribution support class 0).
        neg_mask = sample_class == SampleClass.NEGATIVE
        if neg_mask.any():
            attributions = {
                key: torch.where(
                    # Broadcast the per-sample mask over the feature dims
                    neg_mask.view(-1, *([1] * (attr.dim() - 1))),
                    -attr,
                    attr,
                )
                for key, attr in attributions.items()
            }

        # Debug output (if requested and returning per percentage)
        verbose = debug and return_per_percentage
        if verbose:
            self._print_debug_header(y_probs, target_class_idx, sample_class)

        # Sign-adjust the original probabilities once, on a copy. Doing this
        # in place inside the loop would flip the sign again on every
        # percentage and corrupt all but the first result.
        original_class_probs = self._signed_probs(y_probs, neg_mask)

        results = {}  # per-percentage scores
        metric_scores = torch.zeros(batch_size, device=y_probs.device)

        # Compute metrics across all percentages
        for percentage in self.percentages:
            # Compute masks for this percentage
            masks = self._compute_threshold_and_mask(attributions, percentage)

            # Create ablated inputs (subclass-specific)
            ablated_inputs = self._create_ablated_inputs(inputs, masks)

            # Get predictions on ablated inputs, reusing the predicted
            # classes and sample classes of the original inputs so that both
            # probabilities refer to the same class.
            ablated_probs, _, _ = get_model_predictions(
                model=self.model,
                inputs=ablated_inputs,
                target_class_idx=target_class_idx,
                sample_class=sample_class,
                classifier_type=self.classifier_type,
            )

            # Compute probability drop (IGNORE samples stay at 0)
            ablated_class_probs = self._signed_probs(ablated_probs, neg_mask)
            prob_drop = torch.zeros(batch_size, device=y_probs.device)
            prob_drop[val_mask] = (
                original_class_probs[val_mask] - ablated_class_probs[val_mask]
            )

            # Debug output for this percentage
            if verbose:
                self._print_debug_percentage(
                    percentage,
                    ablated_probs,
                    original_class_probs,
                    ablated_class_probs,
                    prob_drop,
                    target_class_idx,
                    val_mask,
                )

            if return_per_percentage:
                results[percentage] = prob_drop
            else:
                # Accumulate for averaging
                metric_scores = metric_scores + prob_drop

        # Return appropriate format
        if return_per_percentage:
            return results

        # Average across percentages
        metric_scores = metric_scores / len(self.percentages)
        return metric_scores, val_mask