Source code for pyhealth.calib.predictionset.scrib

"""
SCRIB: Set-classifier with Class-specific Risk Bounds

Implementation based on https://github.com/zlin7/scrib

"""

import time
from typing import Dict, Union

import numpy as np
import pandas as pd
import torch

from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.models import BaseModel

from . import quicksearch as qs

OVERALL_LOSSFUNC = "overall"
CLASSPECIFIC_LOSSFUNC = "classspec"

__all__ = ["SCRIB"]


class _CoordDescent:
    def __init__(
        self,
        model_output,
        labels,
        rks,
        loss_func=OVERALL_LOSSFUNC,
        loss_kwargs=None,
        restart_n=1000,
        restart_range=0.1,
        init_range=None,
        verbose=False,
    ):
        self.N, self.K = model_output.shape

        # quantities useful for loss eval
        self.loss_name = loss_func
        if loss_kwargs is None:
            loss_kwargs = {}
        self.loss_kwargs = loss_kwargs
        if self.loss_name == OVERALL_LOSSFUNC:
            assert isinstance(rks, float)
        elif rks is not None:
            rks = np.asarray(rks)
        self.idx2rnk = np.asarray(
            pd.DataFrame(model_output).rank(ascending=True), np.int32
        )
        if np.min(self.idx2rnk) == 1:
            self.idx2rnk -= 1
        self.rnk2idx = np.asarray(np.argsort(model_output, axis=0), np.int32)
        if len(labels.shape) == 2:
            # one-hot -> class indices
            labels = np.argmax(labels, 1)
        self.labels = np.asarray(labels, np.int32)
        self.max_classes = np.argmax(model_output, 1)
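        # Note on the precomputed structures (descriptive comment, assuming no
        # tied scores within a class): idx2rnk[i, k] is the 0-based rank of
        # sample i's score for class k, and rnk2idx[r, k] is the index of the
        # sample holding rank r for class k, so idx2rnk[rnk2idx[r, k], k] == r.
        # The coordinate-descent search operates on these ranks and only maps
        # back to score thresholds at the end (see _p2t).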
        self.rks = rks

        self.model_output = model_output
        self.restart_n = restart_n
        self.restart_range = restart_range
        self.init_range = init_range or (int(np.ceil(self.N / 2)), self.N - 1)
        self.verbose = verbose

    def _search(self, ps):
        _search_fn = {
            CLASSPECIFIC_LOSSFUNC: qs.coord_desc_classspecific,
            OVERALL_LOSSFUNC: qs.coord_desc_overall,
        }[self.loss_name]
        return _search_fn(
            self.idx2rnk,
            self.rnk2idx,
            self.labels,
            self.max_classes,
            ps,
            self.rks,
            **self.loss_kwargs,
        )

    def _loss_eval(self, ps):
        _loss_fn = {
            CLASSPECIFIC_LOSSFUNC: qs.loss_classspecific,
            OVERALL_LOSSFUNC: qs.loss_overall,
        }[self.loss_name]
        return _loss_fn(
            self.idx2rnk,
            self.rnk2idx,
            self.labels,
            self.max_classes,
            ps,
            self.rks,
            **self.loss_kwargs,
        )

    def _p2t(self, p):
        # Translate ranks to thresholds
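        # p[k] is a rank in [0, N); the threshold for class k is the score of
        # the sample holding that rank, so larger ranks correspond to stricter
        # (higher) thresholds and hence smaller prediction sets.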
        return [self.model_output[self.rnk2idx[p[k], k], k] for k in range(self.K)]

    def _sample_new_loc(self, old_p, restart_range=0.1):
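        # Perturb each class's rank by up to +/- restart_range * N and clamp
        # to [0, N - 1], i.e., sample a random point in a local neighborhood
        # of the current best rank vector.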
        diff = np.random.uniform(-restart_range, restart_range, self.K)
        new_p = old_p.copy()
        for k in range(self.K):
            new_p[k] = max(min(int(new_p[k] + diff[k] * self.N), self.N - 1), 0)
        return new_p

    def search_once(self, seed=7):
        def print_(s):
            if self.verbose:
                print(s)

        np.random.seed(seed)
        best_ps = np.random.randint(*self.init_range, self.K)

        st = time.time()
        best_loss, best_ps, _ = self._search(best_ps)
        ed1 = time.time()
        if self.restart_n > 0:
            keep_going = True
            while keep_going:
                keep_going = False
                curr_restart_best_loss, curr_restart_best_ps = np.inf, None

                for _ in range(self.restart_n):
                    # Restart in neighborhood
                    new_ps_ = self._sample_new_loc(best_ps, self.restart_range)
                    loss_ = self._loss_eval(new_ps_)
                    if loss_ < best_loss:
                        print_(
                            "Neighborhood has a better loc with "
                            f"loss={loss_} < {best_loss}"
                        )
                        best_loss, best_ps, _ = self._search(new_ps_)

                        keep_going = True
                        break
                    elif loss_ < curr_restart_best_loss:
                        curr_restart_best_loss, curr_restart_best_ps = loss_, new_ps_
                if not keep_going:
                    print_(
                        f"Tried {curr_restart_best_ps} vs {best_ps}, "
                        f"loss:{curr_restart_best_loss} > {best_loss}"
                    )
        ed2 = time.time()
        print_(f"{ed1-st:.3f} + {ed2-ed1:.3f} seconds")
        return self._p2t(best_ps), best_loss

    @classmethod
    def search(
        cls,
        prob: np.ndarray,
        label: np.ndarray,
        rks: Union[float, np.ndarray],
        loss_func,
        B: int = 10,
        **kwargs,
    ):
        # `label` may be one-hot or class indices; one-hot labels are
        # converted to class indices in __init__.
        best_loss, best_ts = np.inf, None
        searcher = cls(prob, label, rks, loss_func=loss_func, **kwargs)
        for seed in range(B):
            np.random.seed(seed)
            ts, _l = searcher.search_once(seed + 1)
            print(f"{seed}: loss={_l}")
            if _l < best_loss:
                best_loss, best_ts = _l, ts
        return best_ts, best_loss
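
# A minimal usage sketch for the searcher above (illustrative only: the
# synthetic `prob` and `label` arrays are assumptions, not part of the
# library API). `SCRIB.calibrate` below issues essentially this call on the
# calibration set's predicted probabilities:
#
#     rng = np.random.default_rng(0)
#     prob = rng.dirichlet(np.ones(5), size=1000)    # (N, K) predicted scores
#     label = rng.integers(0, 5, size=1000)          # class indices
#     thresholds, loss = _CoordDescent.search(
#         prob, label, rks=0.1, loss_func=OVERALL_LOSSFUNC, B=3,
#         loss_kwargs={"lk": 1e4, "fill_max": False},
#     )
#     # `thresholds` holds one score threshold per class.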


class SCRIB(SetPredictor):
    """SCRIB: Set-classifier with Class-specific Risk Bounds

    This is a prediction-set constructor for multi-class classification
    problems. SCRIB tries to control class-specific risk while minimizing the
    ambiguity. To do this, it selects class-specific thresholds for the
    predictions, on a calibration set.

    If ``risk`` is a float (say 0.1), SCRIB controls the overall risk:
    :math:`\\mathbb{P}\\{Y \\not \\in C(X) | |C(X)| = 1\\}\\leq risk`.
    If ``risk`` is an array (say ``np.asarray([0.1] * 5)``), SCRIB controls
    the class-specific risks:
    :math:`\\mathbb{P}\\{Y \\not \\in C(X) | Y=k \\land |C(X)| = 1\\}\\leq risk_k`.
    Here, :math:`C(X)` denotes the final prediction set.

    Paper:

        Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun.
        "SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox
        Models." AAAI 2022.

    Args:
        model (BaseModel): A trained model.
        risk (Union[float, np.ndarray]): Risk targets.
        loss_kwargs (dict, optional): Additional loss parameters (including
            hyperparameters). It could contain the following float/int
            hyperparameters:

                lk: The coefficient for the loss term associated with risk
                    violation penalty. The higher the lk, the more penalty on
                    risk violation (and likely the higher the ambiguity).
                fill_max: Whether to fill the class with max predicted score
                    when no class exceeds the threshold. In other words, if
                    fill_max, the null region will be filled with the
                    max-prediction class.

            Defaults to {'lk': 1e4, 'fill_max': False}.
        fill_max (bool, optional): Whether to fill the empty prediction set
            with the max-predicted class. Defaults to True.

    Examples:
        >>> from pyhealth.data import ISRUCDataset, split_by_patient, get_dataloader
        >>> from pyhealth.models import SparcNet
        >>> from pyhealth.tasks import sleep_staging_isruc_fn
        >>> from pyhealth.calib.predictionset import SCRIB
        >>> from pyhealth.trainer import get_metrics_fn
        >>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
        >>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
        >>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
        ...     label_key="label", mode="multiclass")
        >>> # ... Train the model here ...
        >>> # Calibrate the set classifier, with different class-specific risk targets
        >>> cal_model = SCRIB(model, [0.2, 0.3, 0.1, 0.2, 0.1])
        >>> # Note that we used the test set here because ISRUCDataset has relatively few
        >>> # patients, and the calibration set should be different from the validation set
        >>> # if the latter is used to pick the checkpoint. In general, the calibration set
        >>> # should be something exchangeable with the test set. Please refer to the paper.
        >>> cal_model.calibrate(cal_dataset=test_data)
        >>> # Evaluate
        >>> from pyhealth.trainer import Trainer
        >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
        >>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
        >>> print(get_metrics_fn(cal_model.mode)(
        ...     y_true_all, y_prob_all, metrics=['accuracy', 'error_ps', 'rejection_rate'],
        ...     y_predset=extra_output['y_predset'])
        ... )
        {'accuracy': 0.709843241966832, 'rejection_rate': 0.6381305287631919,
        'error_ps': array([0.32161874, 0.36654135, 0.11461734, 0.23728814, 0.14993925])}
    """

    def __init__(
        self,
        model: BaseModel,
        risk: Union[float, np.ndarray],
        loss_kwargs: dict = None,
        debug=False,
        fill_max=True,
        **kwargs,
    ) -> None:
        super().__init__(model, **kwargs)
        if model.mode != "multiclass":
            raise NotImplementedError()
        self.mode = self.model.mode  # multiclass
        for param in model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.device = model.device
        self.debug = debug

        if isinstance(risk, float):
            self.loss_name = OVERALL_LOSSFUNC
        else:
            risk = np.asarray(risk)
            self.loss_name = CLASSPECIFIC_LOSSFUNC
        self.risk = risk
        if loss_kwargs is None:
            loss_kwargs = {"lk": 1e4, "fill_max": fill_max}
        self.loss_kwargs = loss_kwargs
        self.t = None

    def calibrate(self, cal_dataset):
        """Calibrate/Search for the thresholds used to construct the prediction set.

        :param cal_dataset: Calibration set.
        :type cal_dataset: Subset
        """
        cal_dataset = prepare_numpy_dataset(
            self.model, cal_dataset, ["y_prob", "y_true"], debug=self.debug
        )
        if self.loss_name == CLASSPECIFIC_LOSSFUNC:
            assert len(self.risk) == cal_dataset["y_prob"].shape[1]

        best_ts, _ = _CoordDescent.search(
            cal_dataset["y_prob"],
            cal_dataset["y_true"],
            self.risk,
            self.loss_name,
            loss_kwargs=self.loss_kwargs,
            verbose=self.debug,
        )
        self.t = torch.nn.Parameter(torch.tensor(best_ts, device=self.device))

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation (just like the original model).

        :return: A dictionary with all results from the base model, with the
            following updates:

                y_predset: a bool tensor representing the prediction for each class.
        :rtype: Dict[str, torch.Tensor]
        """
        ret = self.model(**kwargs)
        ret["y_predset"] = ret["y_prob"] > self.t
        return ret
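
    # Note (descriptive comment, not part of the original code): `self.t`
    # stores one threshold per class, so `y_prob > self.t` broadcasts across
    # the batch and yields a (batch, num_classes) boolean mask; rows where no
    # score exceeds its class threshold produce an empty prediction set.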


if __name__ == "__main__":
    from pyhealth.calib.predictionset import SCRIB
    from pyhealth.datasets import (ISRUCDataset, get_dataloader,
                                   split_by_patient)
    from pyhealth.models import SparcNet
    from pyhealth.tasks import sleep_staging_isruc_fn
    from pyhealth.trainer import get_metrics_fn

    sleep_ds = ISRUCDataset("/srv/local/data/trash", dev=True).set_task(
        sleep_staging_isruc_fn
    )
    train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
    model = SparcNet(
        dataset=sleep_ds,
        feature_keys=["signal"],
        label_key="label",
        mode="multiclass",
    )
    # ... Train the model here ...

    # Calibrate the set classifier, with different class-specific risk targets
    cal_model = SCRIB(model, [0.2, 0.3, 0.1, 0.2, 0.1])
    # Note that I used the test set here because ISRUCDataset has relatively few
    # patients, and the calibration set should be different from the validation set
    # if the latter is used to pick the checkpoint. In general, the calibration set
    # should be something exchangeable with the test set. Please refer to the paper.
    cal_model.calibrate(cal_dataset=test_data)

    # Evaluate
    from pyhealth.trainer import Trainer

    test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
    y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
        test_dl, additional_outputs=["y_predset"]
    )
    print(
        get_metrics_fn(cal_model.mode)(
            y_true_all,
            y_prob_all,
            metrics=["accuracy", "error_ps", "rejection_rate"],
            y_predset=extra_output["y_predset"],
        )
    )