Source code for pyhealth.calib.predictionset.label

"""
LABEL: Least ambiguous set-valued classifiers with bounded error levels.

Paper:

    Sadinle, Mauricio, Jing Lei, and Larry Wasserman.
    "Least ambiguous set-valued classifiers with bounded error levels."
    Journal of the American Statistical Association 114, no. 525 (2019): 223-234.

"""

from typing import Dict, Union

import numpy as np
import torch
from torch.utils.data import Subset

from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.models import BaseModel

__all__ = ["LABEL"]


def _query_quantile(scores, alpha):
    """Return the conformal threshold for miscoverage level ``alpha``: the
    floor(alpha * (N + 1))-th smallest score, or -inf when the calibration
    set is too small for the requested level (every class is then included
    in the prediction set)."""
    scores = np.sort(scores)
    N = len(scores)
    loc = int(np.floor(alpha * (N + 1))) - 1
    return -np.inf if loc == -1 else scores[loc]

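# A quick sanity check of the quantile rule above (toy numbers, not from the
# paper): with 9 calibration scores and alpha = 0.2, loc = floor(0.2 * 10) - 1
# = 1, so the threshold is the second-smallest score:
#
#     >>> _query_quantile(np.array([.1, .2, .3, .4, .5, .6, .7, .8, .9]), 0.2)
#     0.2
#
# Classes whose predicted probability exceeds the threshold enter the set.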

class LABEL(SetPredictor):
    """LABEL: Least ambiguous set-valued classifiers with bounded error levels.

    This is a prediction-set constructor for multi-class classification
    problems. It controls either
    :math:`\\mathbb{P}\\{Y \\not \\in C(X) | Y=k\\}\\leq \\alpha_k`
    (when ``alpha`` is an array), or
    :math:`\\mathbb{P}\\{Y \\not \\in C(X)\\}\\leq \\alpha`
    (when ``alpha`` is a float). Here, :math:`C(X)` denotes the final
    prediction set. This is essentially a split conformal prediction method
    using the predicted scores.

    Paper:

        Sadinle, Mauricio, Jing Lei, and Larry Wasserman.
        "Least ambiguous set-valued classifiers with bounded error levels."
        Journal of the American Statistical Association 114, no. 525 (2019): 223-234.

    :param model: A trained base model.
    :type model: BaseModel
    :param alpha: Target mis-coverage rate(s).
    :type alpha: Union[float, np.ndarray]

    Examples:
        >>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
        >>> from pyhealth.models import SparcNet
        >>> from pyhealth.tasks import sleep_staging_isruc_fn
        >>> from pyhealth.calib.predictionset import LABEL
        >>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
        >>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
        >>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
        ...     label_key="label", mode="multiclass")
        >>> # ... Train the model here ...
        >>> # Calibrate the set classifier, with different class-specific mis-coverage rates
        >>> cal_model = LABEL(model, [0.15, 0.3, 0.15, 0.15, 0.15])
        >>> # Note that we used the test set here because ISRUCDataset has relatively few
        >>> # patients, and the calibration set should be different from the validation set
        >>> # if the latter is used to pick the checkpoint. In general, the calibration set
        >>> # should be something exchangeable with the test set. Please refer to the paper.
        >>> cal_model.calibrate(cal_dataset=test_data)
        >>> # Evaluate
        >>> from pyhealth.trainer import Trainer, get_metrics_fn
        >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
        >>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
        >>> print(get_metrics_fn(cal_model.mode)(
        ...     y_true_all, y_prob_all, metrics=['accuracy', 'miscoverage_ps'],
        ...     y_predset=extra_output['y_predset'])
        ... )
        {'accuracy': 0.709843241966832, 'miscoverage_ps': array([0.1499847 , 0.29997638, 0.14993964, 0.14994704, 0.14999252])}
    """

    def __init__(
        self, model: BaseModel, alpha: Union[float, np.ndarray], debug=False, **kwargs
    ) -> None:
        super().__init__(model, **kwargs)
        if model.mode != "multiclass":
            raise NotImplementedError(
                "LABEL only supports multiclass classification models."
            )
        self.mode = self.model.mode  # multiclass
        for param in model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.device = model.device
        self.debug = debug
        if not isinstance(alpha, float):
            alpha = np.asarray(alpha)
        self.alpha = alpha
        self.t = None

    def calibrate(self, cal_dataset: Subset):
        """Calibrate the thresholds used to construct the prediction set.

        :param cal_dataset: Calibration set.
        :type cal_dataset: Subset
        """
        cal_dataset = prepare_numpy_dataset(
            self.model, cal_dataset, ["y_prob", "y_true"], debug=self.debug
        )
        y_prob = cal_dataset["y_prob"]
        y_true = cal_dataset["y_true"]

        N, K = y_prob.shape
        if isinstance(self.alpha, float):
            # Marginal coverage: one threshold from all true-class scores.
            t = _query_quantile(y_prob[np.arange(N), y_true], self.alpha)
        else:
            # Class-conditional coverage: one threshold per class, computed
            # from the scores of the calibration samples of that class.
            t = [
                _query_quantile(y_prob[y_true == k, k], self.alpha[k])
                for k in range(K)
            ]
        self.t = torch.tensor(t, device=self.device)
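
    # Sketch of what ``calibrate`` computes in the class-conditional case
    # (hypothetical toy numbers): with alpha = [0.1, 0.2] and true-class
    # scores {class 0: [.3, .6, .9], class 1: [.2, .8]}, each threshold is
    # _query_quantile(scores_k, alpha[k]). With this few samples both come
    # out as -inf (a finite threshold requires alpha_k * (N_k + 1) >= 1), so
    # every class is included until the calibration set is large enough.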

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation (just like the original model).

        :return: A dictionary with all results from the base model, with the
            following updates:

            y_predset: a bool tensor representing the prediction for each class.
        :rtype: Dict[str, torch.Tensor]
        """
        pred = self.model(**kwargs)
        pred["y_predset"] = pred["y_prob"] > self.t
        return pred
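
# Illustration of the set construction in ``forward`` (hypothetical numbers):
# with per-class thresholds t = [0.20, 0.05, 0.30] and predicted probabilities
# y_prob = [0.50, 0.04, 0.25], ``y_prob > t`` gives [True, False, False],
# i.e. the prediction set {class 0}.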

if __name__ == "__main__":
    from pyhealth.calib.predictionset import LABEL
    from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
    from pyhealth.models import SparcNet
    from pyhealth.tasks import sleep_staging_isruc_fn
    from pyhealth.trainer import Trainer, get_metrics_fn

    sleep_ds = ISRUCDataset("/srv/local/data/trash", dev=True).set_task(
        sleep_staging_isruc_fn
    )
    train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
    model = SparcNet(
        dataset=sleep_ds, feature_keys=["signal"], label_key="label", mode="multiclass"
    )
    # ... Train the model here ...
    # Calibrate the set classifier, with different class-specific mis-coverage rates.
    cal_model = LABEL(model, [0.15, 0.3, 0.15, 0.15, 0.15])
    # Note that we use the test set here because ISRUCDataset has relatively few
    # patients, and the calibration set should be different from the validation set
    # if the latter is used to pick the checkpoint. In general, the calibration set
    # should be something exchangeable with the test set. Please refer to the paper.
    cal_model.calibrate(cal_dataset=test_data)
    # Evaluate
    test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
    y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
        test_dl, additional_outputs=["y_predset"]
    )
    print(
        get_metrics_fn(cal_model.mode)(
            y_true_all,
            y_prob_all,
            metrics=["accuracy", "miscoverage_ps"],
            y_predset=extra_output["y_predset"],
        )
    )