"""
Neighborhood Conformal Prediction (NCP).
"""
from typing import Dict, Optional, Union
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import IterableDataset
from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.predictionset.base_conformal import _query_weighted_quantile
from pyhealth.calib.utils import extract_embeddings, prepare_numpy_dataset
from pyhealth.models import BaseModel
__all__ = ["NeighborhoodLabel"]
[docs]class NeighborhoodLabel(SetPredictor):
"""Neighborhood Conformal Prediction (NCP) for multiclass classification.
Reference:
Ghosh, S., Belkhouja, T., Yan, Y., & Doppa, J. R. (2023).
Improving Uncertainty Quantification of Deep Classifiers via
Neighborhood Conformal Prediction.
Args:
model: A trained base model that supports embedding extraction
(must support `embed=True` in forward pass).
alpha: Target miscoverage rate; marginal coverage P(Y not in C(X)) <= alpha.
k_neighbors: Number of nearest calibration neighbors. Default 50.
lambda_L: Temperature for exponential weights; smaller => more localization.
Default 100.0.
debug: If True, process fewer samples for faster iteration.
Examples:
>>> from pyhealth.datasets import TUEVDataset, split_by_sample_conformal
>>> from pyhealth.datasets import get_dataloader
>>> from pyhealth.models import ContraWR
>>> from pyhealth.tasks import EEGEventsTUEV
>>> from pyhealth.calib.predictionset.cluster import NeighborhoodLabel
>>> from pyhealth.calib.utils import extract_embeddings
>>> from pyhealth.trainer import Trainer, get_metrics_fn
>>>
>>> dataset = TUEVDataset(root="path/to/tuev")
>>> sample_dataset = dataset.set_task(EEGEventsTUEV())
>>> train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
... sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42
... )
>>> model = ContraWR(dataset=sample_dataset)
>>> cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32)
>>> ncp = NeighborhoodLabel(model=model, alpha=0.1, k_neighbors=50)
>>> ncp.calibrate(cal_dataset=cal_ds, cal_embeddings=cal_embeddings)
>>> test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
>>> y_true, y_prob, _, extra = Trainer(model=ncp).inference(
... test_loader, additional_outputs=["y_predset"]
... )
>>> metrics = get_metrics_fn(ncp.mode)(
... y_true, y_prob, metrics=["accuracy", "miscoverage_ps"],
... y_predset=extra["y_predset"]
... )
"""
[docs] def __init__(
self,
model: BaseModel,
alpha: float,
k_neighbors: int = 50,
lambda_L: float = 100.0,
debug: bool = False,
**kwargs,
) -> None:
super().__init__(model, **kwargs)
if model.mode != "multiclass":
raise NotImplementedError(
"NeighborhoodLabel only supports multiclass classification"
)
self.mode = self.model.mode
for param in model.parameters():
param.requires_grad = False
self.model.eval()
self.device = model.device
self.debug = debug
if not (0.0 < alpha < 1.0):
raise ValueError(f"alpha must be in (0, 1), got {alpha!r}")
self.alpha = float(alpha)
if not isinstance(k_neighbors, int) or k_neighbors <= 0:
raise ValueError(
f"k_neighbors must be a positive integer, got {k_neighbors!r}"
)
self.k_neighbors = k_neighbors
self.lambda_L = float(lambda_L)
self.cal_embeddings_ = None
self.cal_conformity_scores_ = None
self.alpha_tilde_ = None
self._nn = None
[docs] def calibrate(
self,
cal_dataset: IterableDataset,
cal_embeddings: Optional[np.ndarray] = None,
batch_size: int = 32,
) -> None:
"""Calibrate NCP steps:
Step 1: For each calibration point i, compute Q̃^NCP (weighted quantile)
over its k-NN in calibration using weights.
Step 2: Find ã^NCP(α) = largest ã such that empirical coverage on the
calibration set is >= 1-α; store as alpha_tilde_ for use at test time.
Args:
cal_dataset: Calibration dataset (for labels and predictions if
cal_embeddings not provided).
cal_embeddings: Optional precomputed calibration embeddings
(n_cal, embedding_dim). If None, extracted from cal_dataset.
batch_size: Batch size for embedding extraction when cal_embeddings
is not provided.
"""
cal_dict = prepare_numpy_dataset(
self.model,
cal_dataset,
["y_prob", "y_true"],
debug=self.debug,
)
y_prob = cal_dict["y_prob"]
y_true = cal_dict["y_true"]
N = y_prob.shape[0]
if cal_embeddings is None:
cal_embeddings = extract_embeddings(
self.model, cal_dataset, batch_size=batch_size, device=self.device
)
else:
cal_embeddings = np.asarray(cal_embeddings)
if cal_embeddings.shape[0] != N:
raise ValueError(
f"cal_embeddings length {cal_embeddings.shape[0]} must match "
f"cal_dataset size {N}"
)
conformity_scores = y_prob[np.arange(N), y_true]
k = min(self.k_neighbors, N)
self._nn = NearestNeighbors(n_neighbors=k, metric="euclidean").fit(
np.atleast_2d(cal_embeddings)
)
self.cal_embeddings_ = np.atleast_2d(cal_embeddings)
self.cal_conformity_scores_ = np.asarray(conformity_scores, dtype=np.float64)
# this is the ncp calibration step
distances_cal, indices_cal = self._nn.kneighbors(
self.cal_embeddings_, n_neighbors=k
)
cal_weights = np.exp(-distances_cal / self.lambda_L)
cal_weights = cal_weights / cal_weights.sum(axis=1, keepdims=True)
def _empirical_coverage(alpha_tilde_cand: float) -> float:
t_all = np.zeros(N, dtype=np.float64)
for i in range(N):
t_all[i] = _query_weighted_quantile(
self.cal_conformity_scores_[indices_cal[i]],
alpha_tilde_cand,
cal_weights[i],
)
return float(np.mean(self.cal_conformity_scores_ >= t_all))
low, high = 0.0, 1.0
for _ in range(50):
mid = (low + high) / 2
if _empirical_coverage(mid) >= 1.0 - self.alpha:
low = mid
else:
high = mid
self.alpha_tilde_ = float(low)
[docs] def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""Forward with NCP: per-sample weighted quantile threshold."""
if (
self.cal_embeddings_ is None
or self.cal_conformity_scores_ is None
or self.alpha_tilde_ is None
):
raise RuntimeError(
"NeighborhoodLabel must be calibrated before inference. "
"Call calibrate() first."
)
pred = self.model(**{**kwargs, "embed": True})
if "embed" not in pred:
raise ValueError(
f"Model {type(self.model).__name__} does not return "
"embeddings. Ensure it supports embed=True in forward()."
)
test_emb = pred["embed"].detach().cpu().numpy()
test_emb = np.atleast_2d(test_emb)
batch_size = test_emb.shape[0]
n_cal = self.cal_conformity_scores_.shape[0]
k = min(self.k_neighbors, n_cal)
distances, indices = self._nn.kneighbors(test_emb, n_neighbors=k)
thresholds = np.zeros(batch_size, dtype=np.float64)
for i in range(batch_size):
w = np.exp(-distances[i] / self.lambda_L)
w = w / np.sum(w)
scores_i = self.cal_conformity_scores_[indices[i]]
thresholds[i] = _query_weighted_quantile(
scores_i, self.alpha_tilde_, w
)
th = torch.as_tensor(
thresholds, device=self.device, dtype=pred["y_prob"].dtype
)
if pred["y_prob"].ndim > 1:
th = th.view(-1, *([1] * (pred["y_prob"].ndim - 1)))
y_predset = pred["y_prob"] >= th
# if threshold is high, include at least argmax
empty = y_predset.sum(dim=1) == 0
if empty.any():
argmax_idx = pred["y_prob"].argmax(dim=1)
y_predset[empty, argmax_idx[empty]] = True
pred["y_predset"] = y_predset
pred.pop("embed", None)
return pred