"""
Cluster-Based Conformal Prediction.
This module implements conformal prediction with cluster-specific calibration
thresholds using K-means clustering on patient embeddings. The method groups
similar patients into clusters and computes separate calibration thresholds
for each cluster, enabling cluster-aware prediction sets.
This serves as a baseline approach for future personalized/dynamic conformal
prediction methods that use patient similarity for calibration set construction.
"""
from typing import Dict, Optional, Union
import numpy as np
import torch
from sklearn.cluster import KMeans
from torch.utils.data import IterableDataset
from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.predictionset.base_conformal import _query_quantile
from pyhealth.calib.utils import extract_embeddings, prepare_numpy_dataset
from pyhealth.models import BaseModel
__all__ = ["ClusterLabel"]
class ClusterLabel(SetPredictor):
    """Cluster-based conformal prediction for multiclass classification.

    This method uses K-means clustering on patient embeddings to group similar
    patients into clusters. Each cluster gets its own calibration threshold,
    computed from the non-conformity scores (``1 - p(true class)``) of the
    calibration samples in that cluster. At inference time, test samples are
    assigned to their nearest cluster and use the cluster-specific threshold.

    This approach is simpler than KDE-based methods and serves as a baseline
    for more advanced personalized conformal prediction approaches.

    Args:
        model: A trained base model that supports embedding extraction
            (must support `embed=True` in forward pass).
        alpha: Target miscoverage rate(s). Can be:
            - scalar: marginal coverage P(Y not in C(X)) <= alpha
            - array: class-conditional P(Y not in C(X) | Y=k) <= alpha[k]
        n_clusters: Number of K-means clusters. Default is 5.
        random_state: Random seed for K-means clustering. Default is 42.
        debug: Whether to use debug mode (forwarded to
            ``prepare_numpy_dataset`` and enables printing the thresholds).

    Raises:
        NotImplementedError: If ``model.mode`` is not ``"multiclass"``.
        ValueError: If ``n_clusters`` is not a positive integer.

    Examples:
        >>> from pyhealth.datasets import TUEVDataset, split_by_sample_conformal
        >>> from pyhealth.datasets import get_dataloader
        >>> from pyhealth.models import ContraWR
        >>> from pyhealth.tasks import EEGEventsTUEV
        >>> from pyhealth.calib.predictionset.cluster import ClusterLabel
        >>> from pyhealth.calib.utils import extract_embeddings
        >>> from pyhealth.trainer import Trainer, get_metrics_fn
        >>>
        >>> # Prepare data
        >>> dataset = TUEVDataset(root="path/to/tuev")
        >>> sample_dataset = dataset.set_task(EEGEventsTUEV())
        >>> train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        ...     sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42
        ... )
        >>>
        >>> # Train model
        >>> model = ContraWR(dataset=sample_dataset)
        >>> # ... training code ...
        >>>
        >>> # Extract embeddings for clustering
        >>> train_embeddings = extract_embeddings(model, train_ds, batch_size=32)
        >>> cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32)
        >>>
        >>> # Create and calibrate cluster-based predictor
        >>> cluster_predictor = ClusterLabel(model=model, alpha=0.1, n_clusters=5)
        >>> cluster_predictor.calibrate(
        ...     cal_dataset=cal_ds,
        ...     train_embeddings=train_embeddings,
        ...     cal_embeddings=cal_embeddings,
        ... )
        >>>
        >>> # Evaluate
        >>> test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
        >>> y_true, y_prob, _, extra = Trainer(model=cluster_predictor).inference(
        ...     test_loader, additional_outputs=["y_predset"]
        ... )
        >>> metrics = get_metrics_fn(cluster_predictor.mode)(
        ...     y_true, y_prob, metrics=["accuracy", "miscoverage_ps"],
        ...     y_predset=extra["y_predset"]
        ... )
    """

    def __init__(
        self,
        model: BaseModel,
        alpha: Union[float, np.ndarray],
        n_clusters: int = 5,
        random_state: int = 42,
        debug: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(model, **kwargs)
        if model.mode != "multiclass":
            raise NotImplementedError(
                "ClusterLabel only supports multiclass classification"
            )
        self.mode = self.model.mode

        # Freeze model parameters: the wrapped model is only used for inference.
        for param in model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.device = model.device
        self.debug = debug

        # Store alpha. The rest of the class distinguishes the marginal and
        # class-conditional modes via ``isinstance(self.alpha, float)``, so any
        # scalar (Python int, NumPy scalar, 0-d array) is normalised to a plain
        # float; otherwise it would become a 0-d array on which ``len()`` fails
        # during calibration.
        if not isinstance(alpha, float):
            alpha = np.asarray(alpha)
            if alpha.ndim == 0:
                alpha = float(alpha)
        self.alpha = alpha

        # Store clustering parameters (NumPy integer types are accepted too).
        if not isinstance(n_clusters, (int, np.integer)) or n_clusters <= 0:
            raise ValueError(
                f"n_clusters must be a positive integer, got {n_clusters!r}"
            )
        self.n_clusters = int(n_clusters)
        self.random_state = random_state

        # Will be set during calibration
        self.kmeans_model = None
        self.cluster_thresholds = None  # Dict mapping cluster_id -> threshold(s)

    def calibrate(
        self,
        cal_dataset: IterableDataset,
        train_embeddings: Optional[np.ndarray] = None,
        cal_embeddings: Optional[np.ndarray] = None,
        batch_size: int = 32,
    ) -> None:
        """Calibrate cluster-specific thresholds.

        This method:
            1. Combines train and calibration embeddings for clustering
            2. Fits K-means on the combined embeddings
            3. Assigns calibration samples to clusters
            4. Computes cluster-specific calibration thresholds

        Args:
            cal_dataset: Calibration set.
            train_embeddings: Pre-computed training embeddings of shape
                (n_train, embedding_dim). Required: although the parameter
                defaults to ``None``, a ``ValueError`` is raised when it is
                omitted (use ``extract_embeddings()`` on the training set).
            cal_embeddings: Optional pre-computed calibration embeddings
                of shape (n_cal, embedding_dim). If not provided, they are
                extracted from ``cal_dataset`` with the wrapped model.
            batch_size: Batch size for embedding extraction when
                cal_embeddings is not provided. Default is 32.

        Raises:
            ValueError: If ``train_embeddings`` is missing, if a
                class-conditional ``alpha`` does not have one entry per class,
                or if the number of calibration embeddings differs from the
                number of calibration predictions.
        """
        # Fail fast, before any model inference is run.
        if train_embeddings is None:
            raise ValueError(
                "train_embeddings must be provided. "
                "Extract embeddings from training set using extract_embeddings()."
            )
        train_embeddings = np.asarray(train_embeddings)

        # Get predictions and true labels from calibration set
        cal_dataset_dict = prepare_numpy_dataset(
            self.model,
            cal_dataset,
            ["y_prob", "y_true"],
            debug=self.debug,
        )
        y_prob = cal_dataset_dict["y_prob"]
        y_true = cal_dataset_dict["y_true"]
        N, K = y_prob.shape

        # Validate class-conditional alpha once, before the clustering work.
        if not isinstance(self.alpha, float) and len(self.alpha) != K:
            raise ValueError(
                f"alpha must have length {K} for class-conditional "
                f"coverage, got {len(self.alpha)}"
            )

        # Extract embeddings if not provided
        if cal_embeddings is None:
            print("Extracting embeddings from calibration set...")
            cal_embeddings = extract_embeddings(
                self.model, cal_dataset, batch_size=batch_size, device=self.device
            )
        cal_embeddings = np.asarray(cal_embeddings)

        # Cluster labels are used as a boolean mask over the N scores below,
        # so embeddings and predictions must be aligned one-to-one.
        if len(cal_embeddings) != N:
            raise ValueError(
                f"cal_embeddings has {len(cal_embeddings)} rows but the "
                f"calibration set produced {N} predictions; they must be "
                "aligned one-to-one (note that debug mode may process fewer "
                "samples)."
            )

        # Combine embeddings for clustering
        print(
            f"Combining embeddings: train={train_embeddings.shape}, "
            f"cal={cal_embeddings.shape}"
        )
        all_embeddings = np.concatenate([train_embeddings, cal_embeddings], axis=0)
        print(f"Total embeddings for clustering: {all_embeddings.shape}")

        # Fit K-means on combined embeddings
        print(f"Fitting K-means with {self.n_clusters} clusters...")
        self.kmeans_model = KMeans(
            n_clusters=self.n_clusters,
            random_state=self.random_state,
            n_init=10,
        )
        self.kmeans_model.fit(all_embeddings)

        # Assign calibration samples to clusters. The calibration rows were
        # appended after the training rows, so their labels are the tail.
        cal_start_idx = len(train_embeddings)
        cal_cluster_labels = self.kmeans_model.labels_[cal_start_idx:]
        # minlength keeps trailing empty clusters visible in the summary.
        cluster_sizes = np.bincount(cal_cluster_labels, minlength=self.n_clusters)
        print(f"Cluster assignments: {cluster_sizes}")

        # Compute non-conformity scores (higher = less conforming)
        nonconformity_scores = 1.0 - y_prob[np.arange(N), y_true]

        # Build the full mapping locally and publish it only once complete, so
        # a failure part-way through cannot leave a partially filled dict that
        # forward() would mistake for a finished calibration.
        thresholds = {}
        for cluster_id in range(self.n_clusters):
            cluster_mask = cal_cluster_labels == cluster_id
            thresholds[cluster_id] = self._cluster_threshold(
                cluster_id,
                nonconformity_scores[cluster_mask],
                y_true[cluster_mask],
                K,
            )
        self.cluster_thresholds = thresholds

        if self.debug:
            print(f"Cluster thresholds: {self.cluster_thresholds}")

    def _cluster_threshold(
        self,
        cluster_id: int,
        cluster_scores: np.ndarray,
        cluster_labels: np.ndarray,
        n_classes: int,
    ):
        """Compute the non-conformity threshold(s) for a single cluster.

        Args:
            cluster_id: Index of the cluster (used in warning messages only).
            cluster_scores: Non-conformity scores of the calibration samples
                assigned to this cluster.
            cluster_labels: True labels of those same samples.
            n_classes: Number of classes K.

        Returns:
            A scalar threshold when ``self.alpha`` is a float (marginal
            coverage), otherwise an array with one threshold per class.
            ``+inf`` is used wherever no calibration data is available, which
            makes the corresponding classes always included.
        """
        marginal = isinstance(self.alpha, float)

        if len(cluster_scores) == 0:
            print(
                f"Warning: No calibration samples in cluster {cluster_id}, "
                "using +inf NC threshold (include all classes)"
            )
            return np.inf if marginal else np.full(n_classes, np.inf)

        # NOTE(review): the scores are non-conformity scores (1 - p_true), so
        # the coverage guarantee relies on _query_quantile returning the
        # appropriate upper conformal quantile for them - confirm against its
        # implementation in base_conformal.
        if marginal:
            # Marginal coverage: single threshold per cluster
            return _query_quantile(cluster_scores, self.alpha)

        # Class-conditional: one threshold per class per cluster
        per_class = []
        for k in range(n_classes):
            class_scores = cluster_scores[cluster_labels == k]
            if len(class_scores) > 0:
                t_k = _query_quantile(class_scores, self.alpha[k])
            else:
                # No examples for this class in cluster: include always
                print(
                    f"Warning: No calibration examples for class {k} "
                    f"in cluster {cluster_id}, using +inf threshold"
                )
                t_k = np.inf
            per_class.append(t_k)
        return np.array(per_class)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation with cluster-specific prediction set construction.

        Args:
            **kwargs: Passed through to the wrapped model, with ``embed=True``
                added.

        Returns:
            Dictionary with all results from base model (minus the internal
            ``embed`` entry), plus:
                - y_predset: Boolean tensor indicating which classes
                    are in the prediction set

        Raises:
            RuntimeError: If ``calibrate()`` has not been called.
            ValueError: If the wrapped model does not return embeddings.
        """
        if self.kmeans_model is None or self.cluster_thresholds is None:
            raise RuntimeError(
                "Model must be calibrated before inference. "
                "Call calibrate() first."
            )

        # Single forward pass with embed=True to get both predictions and
        # embeddings (avoids double compute)
        pred = self.model(**{**kwargs, "embed": True})
        if "embed" not in pred:
            raise ValueError(
                f"Model {type(self.model).__name__} does not return "
                "embeddings. Make sure the model supports the "
                "embed=True flag in its forward() method."
            )
        y_prob = pred["y_prob"]

        # Ensure embeddings are 2D (batch_size, embedding_dim)
        sample_embedding = np.atleast_2d(pred["embed"].detach().cpu().numpy())
        # Use the same floating dtype as the fitted centroids so that
        # predict() receives data consistent with what fit() was given.
        sample_embedding = sample_embedding.astype(
            self.kmeans_model.cluster_centers_.dtype, copy=False
        )

        # Predict cluster for each sample in the batch
        cluster_ids = self.kmeans_model.predict(sample_embedding)

        # Get cluster-specific threshold for each sample. The tensor is placed
        # on y_prob's own device (not the device cached at construction time),
        # so the comparison below still works if the model was moved later.
        cluster_thresholds = torch.as_tensor(
            np.array([self.cluster_thresholds[cid] for cid in cluster_ids]),
            device=y_prob.device,
            dtype=y_prob.dtype,
        )

        # Broadcast thresholds to match y_prob shape (batch_size, n_classes).
        # Marginal: thresholds are (batch_size,) -> view to (batch_size, 1, ...).
        # Class-conditional: thresholds are already (batch_size, K), no view.
        if y_prob.ndim > 1 and cluster_thresholds.ndim == 1:
            view_shape = (cluster_thresholds.shape[0],) + (1,) * (y_prob.ndim - 1)
            cluster_thresholds = cluster_thresholds.view(view_shape)

        # Include class y if its NC score (1 - p(y)) <= NC threshold
        pred["y_predset"] = (1.0 - y_prob) <= cluster_thresholds
        pred.pop("embed", None)  # do not expose internal embedding to caller
        return pred
if __name__ == "__main__":
    # Demonstration of cluster-based conformal prediction.
    #
    # ClusterLabel and extract_embeddings are already available at module
    # level, so they are not imported again here; re-importing ClusterLabel
    # from this module's package path would load the module a second time.
    from pyhealth.datasets import (
        TUEVDataset,
        get_dataloader,
        split_by_sample_conformal,
    )
    from pyhealth.models import ContraWR
    from pyhealth.tasks import EEGEventsTUEV
    from pyhealth.trainer import Trainer, get_metrics_fn

    # Setup data and model
    dataset = TUEVDataset(root="downloads/tuev/v2.0.1/edf", subset="both")
    sample_dataset = dataset.set_task(EEGEventsTUEV())
    train_ds, val_ds, cal_ds, test_ds = split_by_sample_conformal(
        sample_dataset, ratios=[0.6, 0.1, 0.15, 0.15], seed=42
    )
    model = ContraWR(dataset=sample_dataset)
    # ... Train the model here ...

    # Extract embeddings
    train_embeddings = extract_embeddings(model, train_ds, batch_size=32)
    cal_embeddings = extract_embeddings(model, cal_ds, batch_size=32)

    # Create and calibrate cluster-based predictor
    cluster_predictor = ClusterLabel(model=model, alpha=0.1, n_clusters=5)
    cluster_predictor.calibrate(
        cal_dataset=cal_ds,
        train_embeddings=train_embeddings,
        cal_embeddings=cal_embeddings,
    )

    # Evaluate
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    y_true, y_prob, _, extra = Trainer(model=cluster_predictor).inference(
        test_loader, additional_outputs=["y_predset"]
    )
    metrics = get_metrics_fn(cluster_predictor.mode)(
        y_true,
        y_prob,
        metrics=["accuracy", "miscoverage_ps"],
        y_predset=extra["y_predset"],
    )
    print(f"Results: {metrics}")