# Source code for pyhealth.calib.predictionset.covariate.covariate_label

"""
Covariate Shift Adaptive Conformal Prediction.

This module implements conformal prediction with covariate shift correction
using likelihood ratio weighting. The implementation supports both:
1. KDE-based density estimation for automatic weight computation
2. User-provided custom weights for flexibility

The KDE-based correction approach is based on the CoDrug method, which uses
energy-based models and kernel density estimation to assess molecular densities
and construct weighted conformal prediction sets.

Papers:
    Tibshirani, Ryan J., Rina Foygel Barber, Emmanuel Candes, and
    Aaditya Ramdas. "Conformal prediction under covariate shift."
    Advances in neural information processing systems 32 (2019).
    https://arxiv.org/abs/1904.06019
    
    Laghuvarapu, Siddhartha, Zhen Lin, and Jimeng Sun.
    "Conformal Drug Property Prediction with Density Estimation under 
    Covariate Shift." NeurIPS 2023.
    https://arxiv.org/abs/2310.12033
"""

from typing import Callable, Dict, Optional, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset

from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.calibration.kcal.kde import RBFKernelMean
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.datasets import get_dataloader
from pyhealth.models import BaseModel

__all__ = ["CovariateLabel", "fit_kde"]


def fit_kde(
    cal_embeddings: np.ndarray,
    test_embeddings: np.ndarray,
    bandwidth: Optional[Union[float, str]] = "scott",
    kernel: str = "rbf",
) -> tuple[Callable, Callable]:
    """Fit KDEs on calibration and test embeddings using PyHealth's KDE.

    This implements the KDE-based density estimation approach from the CoDrug
    paper (Laghuvarapu et al., NeurIPS 2023) for computing likelihood ratios
    under covariate shift. The method uses kernel density estimation on both
    calibration and test embeddings to estimate p_test(x) / p_cal(x).

    This uses the PyHealth torch-based RBF kernel (``RBFKernelMean``). Each
    returned callable averages kernel values between the query points and the
    stored reference embeddings.

    Reference:
        Laghuvarapu, S., Lin, Z., & Sun, J. (2023).
        Conformal Drug Property Prediction with Density Estimation under
        Covariate Shift. NeurIPS 2023.
        https://arxiv.org/abs/2310.12033

    Args:
        cal_embeddings: Calibration embeddings, array-like of shape
            (n_cal_samples, embedding_dim). Converted to a contiguous
            float32 array.
        test_embeddings: Test embeddings, array-like of shape
            (n_test_samples, embedding_dim). Converted to a contiguous
            float32 array.
        bandwidth: Bandwidth for KDE. Can be:
            - "scott": ``n_samples ** (-1 / (embedding_dim + 4))``, computed
              separately for each set. Note this is the bare Scott factor;
              it is not rescaled by the spread of the data.
            - float: Use the specified bandwidth for both KDEs.
        kernel: Kernel type. Currently only "rbf" is supported.

    Returns:
        Tuple of (kde_cal, kde_test). Each is a callable that takes embeddings
        (numpy array, array-like, or torch tensor; 1-D input is treated as a
        single query point) and returns a numpy array of density estimates,
        one per query point.

    Raises:
        ValueError: If ``kernel`` is not "rbf", if ``bandwidth`` is an unknown
            string, if an embedding set is not 2-D or is empty, or if the two
            sets have different embedding dimensions.

    Examples:
        >>> import numpy as np
        >>> from pyhealth.calib.predictionset.covariate import fit_kde
        >>>
        >>> # Extract embeddings from your model
        >>> cal_embeddings = np.random.randn(100, 64)
        >>> test_embeddings = np.random.randn(50, 64)
        >>>
        >>> # Fit KDEs
        >>> kde_cal, kde_test = fit_kde(cal_embeddings, test_embeddings)
        >>>
        >>> # Use in CovariateLabel
        >>> from pyhealth.calib.predictionset.covariate import (
        ...     CovariateLabel)
        >>> predictor = CovariateLabel(
        ...     model=model,
        ...     alpha=0.1,
        ...     kde_cal=kde_cal,
        ...     kde_test=kde_test
        ... )
    """
    if kernel != "rbf":
        raise ValueError(f"Only 'rbf' kernel supported, got {kernel}")

    def _as_reference(embeddings, name):
        """Validate one embedding set and return it as a contiguous float32 array."""
        # ascontiguousarray also accepts lists and fixes negative strides,
        # both of which torch.from_numpy rejects.
        arr = np.ascontiguousarray(embeddings, dtype=np.float32)
        if arr.ndim != 2:
            raise ValueError(
                f"{name} must be 2-D (n_samples, embedding_dim), "
                f"got shape {arr.shape}"
            )
        if arr.shape[0] == 0:
            raise ValueError(f"{name} must contain at least one sample")
        return arr

    def _get_bandwidth(embeddings, bw):
        """Resolve a bandwidth rule name (e.g. "scott") to a number."""
        if isinstance(bw, str):
            n_samples, n_features = embeddings.shape
            if bw == "scott":
                return n_samples ** (-1.0 / (n_features + 4))
            raise ValueError(f"Unknown bandwidth method: {bw}")
        return bw

    def _make_density(kern, reference):
        """Build a density callable that averages ``kern`` over ``reference``."""

        def density(data):
            if isinstance(data, torch.Tensor):
                # Match the reference's dtype/device so tensors of any float
                # precision (or on another device) can be queried.
                query = data.detach().to(
                    device=reference.device, dtype=reference.dtype
                )
            else:
                query = torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32))
            if query.ndim == 1:
                query = query.unsqueeze(0)
            with torch.no_grad():
                # expected shape (n_query, n_reference)
                K = kern(query, reference)
                return K.mean(dim=1).cpu().numpy()

        return density

    cal_np = _as_reference(cal_embeddings, "cal_embeddings")
    test_np = _as_reference(test_embeddings, "test_embeddings")
    if cal_np.shape[1] != test_np.shape[1]:
        raise ValueError(
            "cal_embeddings and test_embeddings must have the same embedding "
            f"dimension, got {cal_np.shape[1]} and {test_np.shape[1]}"
        )

    kern_cal = RBFKernelMean(h=_get_bandwidth(cal_np, bandwidth))
    kern_test = RBFKernelMean(h=_get_bandwidth(test_np, bandwidth))

    kde_cal = _make_density(kern_cal, torch.from_numpy(cal_np))
    kde_test = _make_density(kern_test, torch.from_numpy(test_np))
    return kde_cal, kde_test
def _compute_likelihood_ratio( kde_test: Callable, kde_cal: Callable, data: np.ndarray ) -> np.ndarray: """Compute likelihood ratio for covariate shift correction. Args: kde_test: Density estimator fitted on test distribution kde_cal: Density estimator fitted on calibration distribution data: Input data to compute likelihood ratio for Returns: Likelihood ratios (test density / calibration density) """ test_density = kde_test(data) cal_density = kde_cal(data) # Add small epsilon to avoid division by zero return test_density / (cal_density + 1e-10) def _query_weighted_quantile( scores: np.ndarray, alpha: float, weights: np.ndarray ) -> float: """Compute weighted quantile of scores. Args: scores: Array of conformity scores alpha: Quantile level (between 0 and 1) weights: Weights for each score Returns: The weighted alpha-quantile of scores """ # Sort scores and corresponding weights sorted_indices = np.argsort(scores) sorted_scores = scores[sorted_indices] sorted_weights = weights[sorted_indices] # Compute cumulative weights cum_weights = np.cumsum(sorted_weights) / np.sum(sorted_weights) # Find the index where cumulative weight exceeds alpha idx = np.searchsorted(cum_weights, alpha, side="left") # Handle edge cases if idx >= len(sorted_scores): idx = len(sorted_scores) - 1 return sorted_scores[idx]
class CovariateLabel(SetPredictor):
    """Covariate shift adaptive conformal prediction for multiclass.

    This extends the LABEL method to handle covariate shift between calibration
    and test distributions using likelihood ratio weighting. Calibration
    examples are reweighted by (an estimate of) p_test(x) / p_cal(x) before the
    conformity-score threshold is chosen. This follows the weighted conformal
    approach of Tibshirani et al. for covariate shift.

    The default KDE-based approach follows the CoDrug method (Laghuvarapu et
    al., NeurIPS 2023), which uses kernel density estimation on embeddings to
    compute likelihood ratios. Alternatively, users can provide custom weights
    directly for more flexibility (e.g., from importance sampling, propensity
    scores, or domain-specific methods).

    The conformity score of a calibration example is the predicted probability
    of its true class. At inference time a class is placed in the prediction
    set when its probability is strictly greater than the calibrated threshold.

    Papers:
        Tibshirani, Ryan J., Rina Foygel Barber, Emmanuel Candes, and
        Aaditya Ramdas. "Conformal prediction under covariate shift."
        Advances in neural information processing systems 32 (2019).
        https://arxiv.org/abs/1904.06019

        Laghuvarapu, Siddhartha, Zhen Lin, and Jimeng Sun.
        "Conformal Drug Property Prediction with Density Estimation under
        Covariate Shift." NeurIPS 2023.
        https://arxiv.org/abs/2310.12033

    Args:
        model: A trained base model (must have ``mode == "multiclass"``).
        alpha: Target mis-coverage rate(s). Can be:
            - scalar: marginal coverage P(Y not in C(X)) <= alpha
            - array: class-conditional P(Y not in C(X) | Y=k) <= alpha[k]
        kde_test: Optional density estimator fitted on test distribution.
            Should be a callable that takes embeddings (numpy array) and
            returns density estimates. Can be obtained via fit_kde().
            Used for KDE-based likelihood ratio weighting (CoDrug approach).
            Only used when ``kde_cal`` is also given.
        kde_cal: Optional density estimator fitted on calibration
            distribution. Should be a callable that takes embeddings (numpy
            array) and returns density estimates. Used for KDE-based
            likelihood ratio weighting (CoDrug approach). Only used when
            ``kde_test`` is also given.
        debug: Whether to use debug mode (processes fewer samples for faster
            iteration).
        **kwargs: Forwarded to ``SetPredictor.__init__``.

    Examples:
        **Example 1: KDE-based approach (CoDrug method)**

        >>> from pyhealth.datasets import ISRUCDataset
        >>> from pyhealth.datasets import split_by_patient, get_dataloader
        >>> from pyhealth.models import SparcNet
        >>> from pyhealth.tasks import sleep_staging_isruc_fn
        >>> from pyhealth.calib.predictionset.covariate import (
        ...     CovariateLabel, fit_kde)
        >>> import numpy as np
        >>>
        >>> # Prepare data
        >>> sleep_ds = ISRUCDataset("/data/ISRUC-I").set_task(
        ...     sleep_staging_isruc_fn)
        >>> train_data, val_data, test_data = split_by_patient(
        ...     sleep_ds, [0.6, 0.2, 0.2])
        >>>
        >>> # Train model
        >>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
        ...                  label_key="label", mode="multiclass")
        >>> # ... training code ...
        >>>
        >>> # Extract embeddings (example - adjust for your model)
        >>> def extract_embeddings(model, dataset):
        ...     loader = get_dataloader(dataset, batch_size=32,
        ...                             shuffle=False)
        ...     all_embs = []
        ...     for batch in loader:
        ...         batch['embed'] = True
        ...         output = model(**batch)
        ...         all_embs.append(output['embed'].cpu().numpy())
        ...     return np.concatenate(all_embs, axis=0)
        >>>
        >>> cal_embs = extract_embeddings(model, val_data)
        >>> test_embs = extract_embeddings(model, test_data)
        >>>
        >>> # KDE-based approach: automatically compute weights
        >>> cal_model = CovariateLabel(model, alpha=0.1)
        >>> cal_model.calibrate(cal_dataset=val_data,
        ...                     cal_embeddings=cal_embs,
        ...                     test_embeddings=test_embs)
        >>>
        >>> # Evaluate
        >>> test_dl = get_dataloader(test_data, batch_size=32,
        ...                          shuffle=False)
        >>> from pyhealth.trainer import Trainer, get_metrics_fn
        >>> y_true_all, y_prob_all, _, extra_output = Trainer(
        ...     model=cal_model).inference(test_dl,
        ...     additional_outputs=['y_predset'])
        >>> print(get_metrics_fn(cal_model.mode)(
        ...     y_true_all, y_prob_all, metrics=['accuracy', 'miscoverage_ps'],
        ...     y_predset=extra_output['y_predset']))

        **Example 2: Custom weights approach**

        >>> # If you have your own covariate shift correction method
        >>> # (e.g., importance sampling, propensity scores, etc.)
        >>> def compute_custom_weights(cal_data, test_data):
        ...     # Your custom weight computation
        ...     # Should return weights proportional to p_test(x) / p_cal(x)
        ...     return weights  # shape: (n_cal,)
        >>>
        >>> custom_weights = compute_custom_weights(val_data, test_data)
        >>> cal_model = CovariateLabel(model, alpha=0.1)
        >>> cal_model.calibrate(cal_dataset=val_data,
        ...                     cal_weights=custom_weights)
    """

    def __init__(
        self,
        model: BaseModel,
        alpha: Union[float, np.ndarray],
        kde_test: Optional[Callable] = None,
        kde_cal: Optional[Callable] = None,
        debug: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(model, **kwargs)

        if model.mode != "multiclass":
            raise NotImplementedError(
                "CovariateLabel only supports multiclass classification"
            )
        self.mode = self.model.mode

        # Freeze model parameters: calibration must not update the base model.
        for param in model.parameters():
            param.requires_grad = False

        self.model.eval()

        self.device = model.device
        self.debug = debug

        # Any scalar (float, int, numpy scalar, 0-d array) selects marginal
        # coverage; anything with a dimension is a per-class alpha vector.
        if isinstance(alpha, float):
            self.alpha = alpha
        else:
            alpha_arr = np.asarray(alpha)
            self.alpha = float(alpha_arr) if alpha_arr.ndim == 0 else alpha_arr

        # Pre-fitted density estimators are only usable as a pair.
        if (kde_test is None) != (kde_cal is None):
            import warnings

            warnings.warn(
                "Only one of kde_test / kde_cal was provided; pre-fitted KDEs "
                "are used only when both are given, so it will be ignored.",
                stacklevel=2,
            )
        if kde_test is not None and kde_cal is not None:
            self.kde_test = kde_test
            self.kde_cal = kde_cal
        else:
            self.kde_test = None
            self.kde_cal = None

        # Set during calibration.
        self.t = None
        self._sum_cal_weights = None

    def calibrate(
        self,
        cal_dataset: IterableDataset,
        cal_embeddings: Optional[np.ndarray] = None,
        test_embeddings: Optional[np.ndarray] = None,
        cal_weights: Optional[np.ndarray] = None,
    ):
        """Calibrate the thresholds with covariate shift correction.

        This method supports three approaches for handling covariate shift:

        1. **KDE-based (CoDrug approach)**: Provide cal_embeddings and
           test_embeddings. If no KDEs were given at construction, KDEs are
           fitted with ``fit_kde`` and stored on the instance (later calls
           reuse them).

        2. **Custom weights**: Directly provide cal_weights computed from your
           own covariate shift correction method (e.g., importance sampling,
           propensity scores, discriminator-based methods, etc.).

        3. **Pre-fitted KDEs**: Provide kde_test and kde_cal during
           initialization along with cal_embeddings here.

        Args:
            cal_dataset: Calibration set.
            cal_embeddings: Optional pre-computed calibration embeddings of
                shape (n_cal, embedding_dim). Its length must equal the number
                of calibration predictions (fewer in debug mode).
            test_embeddings: Optional pre-computed test embeddings of shape
                (n_test, embedding_dim). Used with cal_embeddings for KDE
                fitting when KDEs are not already set.
            cal_weights: Optional custom weights for calibration samples of
                shape (n_cal,). If provided, these weights are used directly
                instead of KDE likelihood ratios. Weights should represent
                importance weights or likelihood ratios p_test(x) / p_cal(x).
                They are normalized internally.

        Raises:
            ValueError: If no weighting source is available, if weights or
                embeddings do not match the calibration set size, if weights
                are negative/non-finite/sum to zero, or if a per-class alpha
                has fewer entries than there are classes.
            NotImplementedError: If pre-fitted KDEs are set but no
                cal_embeddings are given.

        Note:
            You must provide ONE of:
            1. cal_weights (custom weights), OR
            2. kde_test and kde_cal during initialization, plus cal_embeddings
               here, OR
            3. cal_embeddings and test_embeddings here

        Examples:
            >>> # Approach 1: KDE-based (CoDrug)
            >>> model.calibrate(cal_dataset, cal_embeddings, test_embeddings)
            >>>
            >>> # Approach 2: Custom weights (e.g., from importance sampling)
            >>> custom_weights = compute_importance_weights(cal_data, test_data)
            >>> model.calibrate(cal_dataset, cal_weights=custom_weights)
        """
        # Get predictions and true labels first.
        cal_dataset_dict = prepare_numpy_dataset(
            self.model,
            cal_dataset,
            ["y_prob", "y_true"],
            debug=self.debug,
        )
        y_prob = cal_dataset_dict["y_prob"]
        y_true = cal_dataset_dict["y_true"]
        N, K = y_prob.shape

        likelihood_ratios = self._resolve_likelihood_ratios(
            N, cal_embeddings, test_embeddings, cal_weights
        )

        # Normalize weights (validated above to have a positive, finite sum).
        self._sum_cal_weights = np.sum(likelihood_ratios)
        weights = likelihood_ratios / self._sum_cal_weights

        # Conformity score = predicted probability of the true class.
        conformity_scores = y_prob[np.arange(N), y_true]

        if isinstance(self.alpha, float):
            # Marginal coverage: single threshold.
            t = _query_weighted_quantile(conformity_scores, self.alpha, weights)
        else:
            # Class-conditional coverage: one threshold per class.
            if len(self.alpha) < K:
                raise ValueError(
                    f"alpha has {len(self.alpha)} entries but the model "
                    f"predicts {K} classes"
                )
            t = []
            for k in range(K):
                mask = y_true == k
                class_weights = weights[mask]
                if np.sum(class_weights) > 0:
                    # The quantile helper normalizes weights itself.
                    t_k = _query_weighted_quantile(
                        conformity_scores[mask], self.alpha[k], class_weights
                    )
                else:
                    # No (weighted) calibration evidence: include the class.
                    t_k = -np.inf
                t.append(t_k)

        self.t = torch.tensor(t, device=self.device)

    def _resolve_likelihood_ratios(
        self,
        n_cal: int,
        cal_embeddings: Optional[np.ndarray],
        test_embeddings: Optional[np.ndarray],
        cal_weights: Optional[np.ndarray],
    ) -> np.ndarray:
        """Return validated, unnormalized calibration weights of shape (n_cal,)."""
        if cal_weights is not None:
            # Use custom weights provided by user.
            if len(cal_weights) != n_cal:
                raise ValueError(
                    f"cal_weights must have length {n_cal} (size of calibration set), "
                    f"got {len(cal_weights)}"
                )
            ratios = np.asarray(cal_weights, dtype=np.float64)
            print("Using custom calibration weights")
        else:
            # KDE-based approach (CoDrug method).
            if self.kde_test is None or self.kde_cal is None:
                if cal_embeddings is None or test_embeddings is None:
                    raise ValueError(
                        "Must provide ONE of:\n"
                        " 1. cal_weights (custom weights), OR\n"
                        " 2. kde_test and kde_cal during __init__, OR\n"
                        " 3. cal_embeddings and test_embeddings during calibrate()"
                    )
            if cal_embeddings is None:
                # Pre-fitted KDEs but no embeddings to evaluate them on.
                raise NotImplementedError(
                    "Automatic embedding extraction not yet supported. "
                    "Please provide cal_embeddings and test_embeddings."
                )
            if len(cal_embeddings) != n_cal:
                raise ValueError(
                    f"cal_embeddings has {len(cal_embeddings)} rows but the "
                    f"calibration set produced {n_cal} predictions (debug mode "
                    "evaluates fewer samples)"
                )
            if self.kde_test is None or self.kde_cal is None:
                print("Fitting KDEs on provided embeddings (CoDrug approach)...")
                self.kde_cal, self.kde_test = fit_kde(cal_embeddings, test_embeddings)

            print("Computing likelihood ratios via KDE...")
            ratios = np.asarray(
                _compute_likelihood_ratio(self.kde_test, self.kde_cal, cal_embeddings),
                dtype=np.float64,
            )

        ratios = ratios.ravel()
        if ratios.size != n_cal:
            raise ValueError(
                f"Expected {n_cal} calibration weights, got {ratios.size}"
            )
        if not np.all(np.isfinite(ratios)):
            raise ValueError("Calibration weights must be finite (no NaN or inf).")
        if np.any(ratios < 0):
            raise ValueError("Calibration weights must be non-negative.")
        if np.sum(ratios) <= 0:
            raise ValueError(
                "Calibration weights sum to zero and cannot be normalized. If "
                "using KDE weights, the density estimates may have underflowed; "
                "consider a larger bandwidth in fit_kde()."
            )
        return ratios

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation with prediction set construction.

        Returns:
            Dictionary with all results from base model, plus:
                - y_predset: Boolean tensor, True for classes whose predicted
                  probability is strictly greater than the calibrated
                  threshold.

        Raises:
            RuntimeError: If called before ``calibrate``.
        """
        if self.t is None:
            raise RuntimeError(
                "CovariateLabel must be calibrated (call calibrate()) before use."
            )
        pred = self.model(**kwargs)
        # Construct prediction set by thresholding probabilities.
        pred["y_predset"] = pred["y_prob"] > self.t
        return pred
if __name__ == "__main__":
    # Demonstration of three approaches for covariate shift correction:
    #   1. Embeddings approach: automatic KDE computation (CoDrug method)
    #   2. Pre-fitted KDEs approach: the user supplies KDE estimators
    #   3. Custom weights approach: the user supplies importance weights
    from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
    from pyhealth.models import SparcNet
    from pyhealth.tasks import sleep_staging_isruc_fn
    from pyhealth.trainer import Trainer, get_metrics_fn

    RULE = "=" * 80

    def show_banner(title, *notes):
        """Print a section title framed by rules, then any explanatory notes."""
        print(RULE)
        print(title)
        print(RULE)
        for note in notes:
            print(note)

    def evaluate(predictor, loader):
        """Run inference with prediction sets and compute accuracy/miscoverage."""
        labels, probs, _, extras = Trainer(model=predictor).inference(
            loader, additional_outputs=["y_predset"]
        )
        metric_fn = get_metrics_fn(predictor.mode)
        return metric_fn(
            labels,
            probs,
            metrics=["accuracy", "miscoverage_ps"],
            y_predset=extras["y_predset"],
        )

    def extract_embeddings(model, dataset):
        """Return mock Gaussian embeddings of shape (len(dataset), 64).

        A real implementation would iterate
        ``get_dataloader(dataset, batch_size=32, shuffle=False)``, call the
        model with ``batch['embed'] = True`` and concatenate
        ``output['embed'].cpu().numpy()`` across batches.
        """
        return np.random.randn(len(dataset), 64)

    # Data and (untrained) model setup.
    sleep_ds = ISRUCDataset("/srv/local/data/trash", dev=True).set_task(
        sleep_staging_isruc_fn
    )
    train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
    model = SparcNet(
        dataset=sleep_ds,
        feature_keys=["signal"],
        label_key="label",
        mode="multiclass",
    )
    # ... Train the model here ...

    # --- Approach 1: embeddings -> automatic KDE ---------------------------
    show_banner(
        "APPROACH 1: Embeddings (Automatic KDE - CoDrug Method)",
        "This approach automatically computes KDEs from embeddings.",
        "Best for: When you have model embeddings and want automatic density estimation.\n",
    )
    cal_embeddings = extract_embeddings(model, val_data)
    test_embeddings = extract_embeddings(model, test_data)

    cal_model_1 = CovariateLabel(model, alpha=0.1)
    cal_model_1.calibrate(
        cal_dataset=val_data,
        cal_embeddings=cal_embeddings,
        test_embeddings=test_embeddings,
    )
    test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
    metrics_1 = evaluate(cal_model_1, test_dl)
    print(f"Results: {metrics_1}\n")

    # --- Approach 2: pre-fitted KDEs ---------------------------------------
    show_banner(
        "APPROACH 2: Pre-fitted KDEs",
        "This approach uses pre-computed KDE estimators.",
        "Best for: When you want control over KDE parameters or reuse KDEs.\n",
    )
    kde_cal, kde_test = fit_kde(
        cal_embeddings, test_embeddings, bandwidth=0.5, kernel="rbf"
    )
    cal_model_2 = CovariateLabel(model, alpha=0.1, kde_test=kde_test, kde_cal=kde_cal)
    # Embeddings are still needed to evaluate the likelihood ratios.
    cal_model_2.calibrate(cal_dataset=val_data, cal_embeddings=cal_embeddings)
    metrics_2 = evaluate(cal_model_2, test_dl)
    print(f"Results: {metrics_2}\n")

    # --- Approach 3: user-supplied weights ---------------------------------
    show_banner(
        "APPROACH 3: Custom Weights",
        "This approach uses user-provided importance weights.",
        "Best for: Alternative covariate shift methods (importance sampling,",
        "          propensity scores, discriminator-based, domain-specific).\n",
    )
    # Uniform weights (i.e. no shift correction). Alternatives include
    # discriminator-based weights, propensity scores, domain adaptation or
    # energy-based models, or e.g. np.exp(-distances) for a distance measure.
    custom_weights = np.ones(len(val_data))
    print(f"Using custom weights (shape: {custom_weights.shape})")

    cal_model_3 = CovariateLabel(model, alpha=0.1)
    cal_model_3.calibrate(cal_dataset=val_data, cal_weights=custom_weights)
    metrics_3 = evaluate(cal_model_3, test_dl)
    print(f"Results: {metrics_3}\n")

    # --- Summary -----------------------------------------------------------
    show_banner("SUMMARY")
    print("Approach 1 (Embeddings):    ", metrics_1)
    print("Approach 2 (Pre-fitted KDEs):", metrics_2)
    print("Approach 3 (Custom Weights): ", metrics_3)
    print("\nAll three approaches are valid and can be chosen based on your needs!")
    print("- Use Approach 1 for simplicity with embeddings (CoDrug method)")
    print("- Use Approach 2 for fine-grained control over KDE parameters")
    print("- Use Approach 3 for alternative covariate shift correction methods")