# Source code for pyhealth.calib.calibration.kcal

"""KCal: Kernel-based Calibration

Implementation based on https://github.com/zlin7/KCal.

Paper:

    Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun.
    "Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration
    for Deep Neural Networks." ICLR 2023.


"""
from typing import Dict

import torch
from torch.utils.data import DataLoader, Subset

from pyhealth.calib.base_classes import PostHocCalibrator
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.trainer import Trainer

from .bw import fit_bandwidth
from .embed_data import _EmbedData
from .kde import KDE_classification, KDECrossEntropyLoss, RBFKernelMean

__all__ = ["KCal"]


class ProjectionWrap(torch.nn.Module):
    """Base class for reprojections.

    Subclasses implement :meth:`embed`. :meth:`_forward` applies it to the
    embeddings in a batch and evaluates ``self.criterion`` on the result.
    """

    def __init__(self) -> None:
        super().__init__()
        self.criterion = KDECrossEntropyLoss()
        self.mode = "multiclass"

    def embed(self, x):
        """The actual projection"""
        raise NotImplementedError()

    def _forward(self, data, target=None, device=None):
        """Project the embeddings in ``data`` and evaluate the criterion.

        The caller's ``data`` dict is never modified: the projected tensors
        are written to a shallow copy, so the same batch can be passed in
        more than once.

        Args:
            data (dict): Must hold ``"supp_embed"`` and ``"supp_target"``.
                When ``target`` is given it must also hold ``"pred_embed"``,
                and may hold ``"weights"``.
            target (optional): Labels for ``data["pred_embed"]``. When None,
                ``data["supp_target"]`` is used as the target and the
                criterion is called with ``eval_only=True``; this is only
                allowed when the module is not in training mode.
            device (optional): Device to move the tensors to. Defaults to the
                device of ``self.fc.weight`` if the module has an ``fc``
                layer, and to the device of ``data["supp_embed"]`` otherwise.

        Returns:
            dict: ``"loss"`` and ``"y_prob"`` as reported by the criterion,
            and ``"y_true"`` (the target that was used).
        """
        if device is None:
            # Not every projection owns an ``fc`` layer (e.g. an identity map),
            # so fall back to wherever the support embeddings already live.
            fc = getattr(self, "fc", None)
            device = fc.weight.device if fc is not None else data["supp_embed"].device

        # Shallow copy: the tensors are shared, but the key assignments below
        # stay local to this call.
        batch = dict(data)
        batch["supp_embed"] = self.embed(batch["supp_embed"].to(device))
        batch["supp_target"] = batch["supp_target"].to(device)
        if target is None:
            # no supp vs pred - LOO prediction (for eval)
            assert "pred_embed" not in batch
            batch["pred_embed"] = None
            target = batch["supp_target"]
            assert not self.training
        else:
            # used for train
            batch["pred_embed"] = self.embed(batch["pred_embed"].to(device))
            if "weights" in batch and isinstance(batch["weights"], torch.Tensor):
                batch["weights"] = batch["weights"].to(device)
        loss = self.criterion(
            batch, target.to(device), eval_only=batch["pred_embed"] is None
        )
        return {
            "loss": loss["loss"],
            "y_prob": loss["extra_output"]["prediction"],
            "y_true": target,
        }


class Identity(ProjectionWrap):
    """Reprojection that leaves the embeddings untouched (no reprojection)."""

    def embed(self, x):
        """Return ``x`` unchanged."""
        return x

    def forward(self, data, target=None):
        """Forward operations"""
        # This module reads no layer for a device here, so the batch is kept
        # on the device of its support embeddings.
        support_device = data["supp_embed"].device
        return self._forward(data, target, support_device)


class SkipELU(ProjectionWrap):
    """Default reprojection: two linear layers joined by a skip connection.

    Args:
        input_features: Size of the incoming embedding.
        output_features: Size of the projected embedding.
    """

    def __init__(self, input_features, output_features):
        super().__init__()
        nn = torch.nn
        # Construction order is kept as-is: it fixes both the order in which
        # the layers are registered and the order of their random init.
        self.bn = nn.BatchNorm1d(input_features)
        self.mid = nn.Linear(input_features, output_features)
        self.bn2 = nn.BatchNorm1d(output_features)
        self.fc = nn.Linear(output_features, output_features, bias=False)
        self.act = nn.ELU()

    def embed(self, x):
        """Project ``x``: ``h = mid(bn(x))``, then return ``fc(act(h)) + h``."""
        hidden = self.mid(self.bn(x))
        return self.fc(self.act(hidden)) + hidden

    def forward(self, data, target=None):
        """Forward operations"""
        module_device = self.fc.weight.device
        return self._forward(data, target, module_device)


def _embed_dataset(model, dataset, record_id_name=None, debug=False, batch_size=32):
    """Collect labels, embeddings and patient ids of ``dataset`` via ``model``.

    Delegates to ``prepare_numpy_dataset``, requesting the ``"y_true"`` and
    ``"embed"`` outputs with ``embed=True`` forwarded to the model.

    Args:
        model: The base model used to produce the embeddings.
        dataset: The dataset to embed.
        record_id_name (optional): Name of an extra data key to carry along
            as the record index. Defaults to None (no index is collected).
        debug (bool, optional): Forwarded to ``prepare_numpy_dataset``.
        batch_size (int, optional): Forwarded to ``prepare_numpy_dataset``.

    Returns:
        dict: ``"labels"``, ``"indices"`` (None when ``record_id_name`` is
        None or missing from the result), ``"embed"`` and ``"group"`` (the
        patient ids).
    """
    data_keys = ["patient_id"]
    if record_id_name is not None:
        data_keys.append(record_id_name)

    collected = prepare_numpy_dataset(
        model,
        dataset,
        ["y_true", "embed"],
        incl_data_keys=data_keys,
        forward_kwargs={"embed": True},
        debug=debug,
        batch_size=batch_size,
    )

    embedded = {}
    embedded["labels"] = collected["y_true"]
    embedded["indices"] = collected.get(record_id_name, None)
    embedded["embed"] = collected["embed"]
    embedded["group"] = collected["patient_id"]
    return embedded


class KCal(PostHocCalibrator):
    """Kernel-based Calibration.

    This is a *full* calibration method for *multiclass* classification.
    It tries to calibrate the predicted probabilities for all classes, by
    using KDE classifiers estimated from the calibration set.

    Paper:

        Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun.
        "Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration
        for Deep Neural Networks." ICLR 2023.

    Args:
        model (BaseModel): A trained model.

    Examples:
        >>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
        >>> from pyhealth.models import SparcNet
        >>> from pyhealth.tasks import sleep_staging_isruc_fn
        >>> from pyhealth.calib.calibration import KCal
        >>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
        >>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
        >>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
        ...     label_key="label", mode="multiclass")
        >>> # ... Train the model here ...
        >>> # Calibrate
        >>> cal_model = KCal(model)
        >>> cal_model.calibrate(cal_dataset=val_data)
        >>> # Alternatively, you could re-fit the reprojection:
        >>> # cal_model.calibrate(cal_dataset=val_data, train_dataset=train_data)
        >>> # Evaluate
        >>> from pyhealth.trainer import Trainer
        >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
        >>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
        {'accuracy': 0.7303689172252193, 'cwECEt_adapt': 0.03324275630220515}
    """

    def __init__(self, model: torch.nn.Module, debug=False, **kwargs) -> None:
        super().__init__(model, **kwargs)
        if model.mode != "multiclass":
            raise NotImplementedError("KCal only supports multiclass models.")
        self.mode = self.model.mode  # multiclass
        self.model.eval()

        self.device = model.device
        self.debug = debug

        # Until `fit` trains a reprojection, the raw embeddings are used as-is.
        self.proj = Identity()
        self.kern = RBFKernelMean()
        self.record_id_name = None
        # Filled by `calibrate` with "X" (projected embeddings) and "Y"
        # (one-hot labels) of the calibration set.
        self.cal_data = {}
        self.num_classes = None

    def fit(
        self,
        train_dataset,
        val_dataset=None,
        split_by_patient=False,
        dim=32,
        bs_pred=64,
        bs_supp=20,
        epoch_len=5000,
        epochs=10,
        load_best_model_at_last=False,
        **train_kwargs
    ):
        """Fit the reprojection module.

        You don't need to call this function - it is called in
        :func:`KCal.calibrate`. For training details, please refer to the
        paper.

        Args:
            train_dataset (Dataset): The training dataset.
            val_dataset (Dataset, optional): The validation dataset.
                Defaults to None.
            split_by_patient (bool, optional): Whether to split the dataset
                by patient during training. Defaults to False.
            dim (int, optional): The dimension of the embedding.
                Defaults to 32.
            bs_pred (int, optional): The batch size for the prediction set.
                Defaults to 64.
            bs_supp (int, optional): The batch size for the support set.
                Defaults to 20.
            epoch_len (int, optional): The number of batches in an epoch.
                Defaults to 5000.
            epochs (int, optional): The number of epochs. Defaults to 10.
            load_best_model_at_last (bool, optional): Whether to load the
                best model (or the last model). Defaults to False.
            **train_kwargs: Other keyword arguments for
                :func:`pyhealth.trainer.Trainer.train`.
        """
        _train_data = _embed_dataset(
            self.model, train_dataset, self.record_id_name, self.debug
        )
        # `max` over the label array yields a NumPy scalar; keep a plain int
        # so it can be handed to torch APIs that expect one.
        self.num_classes = int(max(_train_data["labels"])) + 1

        if not split_by_patient:
            # Allow using other samples from the same patient to make the prediction
            _train_data.pop("group")
        _train_data = _EmbedData(
            bs_pred=bs_pred, bs_supp=bs_supp, epoch_len=epoch_len, **_train_data
        )
        train_loader = DataLoader(
            _train_data, batch_size=1, collate_fn=_EmbedData._collate_func
        )

        val_loader = None
        if val_dataset is not None:
            _val_data = _embed_dataset(
                self.model, val_dataset, self.record_id_name, self.debug
            )
            _val_data = _EmbedData(epoch_len=1, **_val_data)
            val_loader = DataLoader(
                _val_data, batch_size=1, collate_fn=_EmbedData._collate_func
            )

        self.proj = SkipELU(len(_train_data.embed[0]), dim).to(self.device)
        trainer = Trainer(model=self.proj)
        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            epochs=epochs,
            monitor="loss",
            monitor_criterion="min",
            load_best_model_at_last=load_best_model_at_last,
            **train_kwargs
        )
        self.proj.eval()

    def calibrate(
        self,
        cal_dataset: Subset,
        num_fold=20,
        record_id_name=None,
        train_dataset: Subset = None,
        train_split_by_patient=False,
        load_best_model_at_last=True,
        **train_kwargs
    ):
        """Calibrate using a calibration dataset.

        If ``train_dataset`` is not None, it will be used to fit a
        re-projection from the base model embeddings. In either case, the
        calibration set will be used to construct the KDE classifier.

        Args:
            cal_dataset (Subset): Calibration set.
            num_fold (int, optional): Forwarded to ``fit_bandwidth`` as
                ``num_fold`` when choosing the kernel bandwidth (presumably
                the number of cross-validation folds - confirm against
                ``fit_bandwidth``). Defaults to 20.
            record_id_name (str, optional): the key/name of the unique index
                for records. Defaults to None.
            train_dataset (Subset, optional): Dataset to train the
                reprojection. Defaults to None (no training).
            train_split_by_patient (bool, optional): Whether to split by
                patient when training the embeddings. That is, do we use
                samples from the same patient in KDE during *training*.
                Defaults to False.
            load_best_model_at_last (bool, optional): Whether to load the
                best reprojection basing on the calibration set.
                Defaults to True.
            train_kwargs (dict, optional): Additional arguments for training
                the reprojection. Passed to :func:`KCal.fit`
        """
        self.record_id_name = record_id_name
        if train_dataset is not None:
            self.fit(
                train_dataset,
                val_dataset=cal_dataset,
                split_by_patient=train_split_by_patient,
                load_best_model_at_last=load_best_model_at_last,
                **train_kwargs
            )
        else:
            print(
                "No `train_dataset` - using the raw embeddings from the base classifier."
            )

        _cal_data = _embed_dataset(
            self.model, cal_dataset, self.record_id_name, self.debug
        )
        # Computed once and kept as a plain int (see `fit`).
        cal_num_classes = int(max(_cal_data["labels"])) + 1
        if self.num_classes is None:
            self.num_classes = cal_num_classes
        assert (
            self.num_classes == cal_num_classes
        ), "Train/Calibration data seem to have different classes"

        self.cal_data["Y"] = torch.tensor(
            _cal_data["labels"], dtype=torch.long, device=self.device
        )
        self.cal_data["Y"] = torch.nn.functional.one_hot(
            self.cal_data["Y"], self.num_classes
        ).float()
        with torch.no_grad():
            self.cal_data["X"] = self.proj.embed(
                torch.tensor(_cal_data["embed"], dtype=torch.float, device=self.device)
            )

        # Choose bandwidth
        self.kern.set_bandwidth(
            fit_bandwidth(group=_cal_data["group"], num_fold=num_fold, **self.cal_data)
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation (just like the original model).

        :param **kwargs: Additional arguments to the base model.

        :return: A dictionary with all results from the base model, with the
            following modified:

            ``y_prob``: calibrated predicted probabilities.
            ``loss``: Cross entropy loss with the new y_prob.
        :rtype: Dict[str, torch.Tensor]
        """
        ret = self.model(embed=True, **kwargs)
        # The raw embedding is consumed here and is not part of the output.
        X_pred = self.proj.embed(ret.pop("embed"))
        ret["y_prob"] = KDE_classification(
            kern=self.kern, X_pred=X_pred, **self.cal_data
        )
        ret["loss"] = self.proj.criterion.log_loss(ret["y_prob"], ret["y_true"])
        return ret
if __name__ == "__main__":
    # Small end-to-end demo: calibrate an (untrained) SparcNet on ISRUC and
    # report the calibration error and accuracy on the test split.
    from pyhealth.calib.calibration import KCal
    from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
    from pyhealth.models import SparcNet
    from pyhealth.tasks import sleep_staging_isruc_fn
    from pyhealth.trainer import Trainer

    isruc = ISRUCDataset(root="/srv/local/data/trash/", dev=True)
    sleep_ds = isruc.set_task(sleep_staging_isruc_fn)
    train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])

    model = SparcNet(
        dataset=sleep_ds,
        feature_keys=["signal"],
        label_key="label",
        mode="multiclass",
    )
    # ... Train the model here ...

    # Calibrate
    cal_model = KCal(model)
    cal_model.calibrate(cal_dataset=val_data)

    # Evaluate
    test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
    evaluator = Trainer(model=cal_model, metrics=["cwECEt_adapt", "accuracy"])
    scores = evaluator.evaluate(test_dl)
    print(scores)