Source code for pyhealth.calib.calibration.dircal

"""
Dirichlet Calibration.

Implementation based on https://github.com/dirichletcal/dirichlet_python
"""

from typing import Dict

import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Subset

from pyhealth.calib.base_classes import PostHocCalibrator
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.models import BaseModel


def _get_identity_weights(n_classes, method="Full"):
    raw_weights = None
    if (method is None) or (method == "Full"):
        raw_weights = torch.zeros((n_classes, n_classes + 1)) + torch.hstack(
            [torch.eye(n_classes), torch.zeros((n_classes, 1))]
        )
    else:
        raise NotImplementedError
    return raw_weights.ravel()


def _softmax(X):
    """Compute the softmax of matrix X in a numerically stable way."""
    shiftx = X - torch.max(X, axis=1).values.reshape(-1, 1)
    exps = torch.exp(shiftx)
    return exps / torch.sum(exps, axis=1).reshape(-1, 1)


def _get_weights(params, k, ref_row, method):
    """Reshapes the given params (weights) into the full matrix including 0"""

    if method in ["Full", None]:
        raw_weights = params.reshape(-1, k + 1)
    if ref_row:
        weights = raw_weights - torch.repeat_interleave(
            raw_weights[-1, :].reshape(1, -1), k, dim=0
        )
    else:
        weights = raw_weights
    return weights


[docs]class DirichletCalibration(PostHocCalibrator): """Dirichlet Calibration Dirichlet calibration is similar to retraining a linear layer mapping from the old logits to the new logits with regularizations. This is a calibration method for *multiclass* classification only. Paper: [1] Kull, Meelis, Miquel Perello Nieto, Markus Kängsepp, Telmo Silva Filho, Hao Song, and Peter Flach. "Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with dirichlet calibration." Advances in neural information processing systems 32 (2019). :param model: A trained base model. :type model: BaseModel Examples: >>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader >>> from pyhealth.models import SparcNet >>> from pyhealth.tasks import sleep_staging_isruc_fn >>> from pyhealth.calib.calibration import DirichletCalibration >>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn) >>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2]) >>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"], ... label_key="label", mode="multiclass") >>> # ... Train the model here ... >>> # Calibrate >>> cal_model = DirichletCalibration(model) >>> cal_model.calibrate(cal_dataset=val_data) >>> # Evaluate >>> from pyhealth.trainer import Trainer >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False) >>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl)) {'accuracy': 0.7096615988229524, 'cwECEt_adapt': 0.05336195546573208} """ def __init__(self, model: BaseModel, debug=False, **kwargs) -> None: super().__init__(model, **kwargs) self.mode = self.model.mode for param in model.parameters(): param.requires_grad = False self.model.eval() self.device = model.device self.debug = debug self.num_classes = None def _forward(self, logits): logits = torch.hstack( [logits, torch.zeros((logits.shape[0], 1), device=self.device)] ) weights = _get_weights( self.weights, k=self.num_classes, ref_row=True, method="Full" ) new_logits = torch.matmul(logits, weights.permute(1, 0)) return new_logits, weights
[docs] def calibrate(self, cal_dataset: Subset, lr=0.01, max_iter=128, reg_lambda=1e-3): """Calibrate the base model using a calibration dataset. :param cal_dataset: Calibration set. :type cal_dataset: Subset :param lr: learning rate, defaults to 0.01 :type lr: float, optional :param max_iter: maximum iterations, defaults to 128 :type max_iter: int, optional :param reg_lambda: regularization coefficient on the deviation from identity matrix. defaults to 1e-3 :type reg_lambda: float, optional :return: None :rtype: None """ _cal_data = prepare_numpy_dataset( self.model, cal_dataset, ["y_true", "logit"], debug=self.debug ) if self.num_classes is None: self.num_classes = _cal_data["logit"].shape[1] self.weights = torch.nn.Parameter( torch.tensor(_get_identity_weights(self.num_classes), device=self.device), requires_grad=True, ) optimizer = optim.LBFGS([self.weights], lr=lr, max_iter=max_iter) logits = torch.tensor(_cal_data["logit"], dtype=torch.float, device=self.device) label = torch.tensor( F.one_hot(torch.tensor(_cal_data["y_true"]), num_classes=self.num_classes), dtype=torch.long if self.model.mode == "multiclass" else torch.float32, device=self.device, ) reg_mu = None reg_format = "identity" def _eval(): optimizer.zero_grad() new_logits, weights = self._forward(logits) probs = _softmax(new_logits) loss = torch.mean(-torch.log(torch.sum(label * probs, dim=1))) if reg_mu is None: if reg_format == "identity": reg = torch.hstack( [ torch.eye(self.num_classes), torch.zeros((self.num_classes, 1)), ] ) else: reg = torch.zeros((self.num_classes, self.num_classes + 1)) loss = loss + reg_lambda * torch.sum( (weights - reg.to(self.device)) ** 2 ) else: weights_hat = weights - torch.hstack( [ weights[:, :-1] * torch.eye(self.num_classes, device=self.device), torch.zeros((self.num_classes, 1), device=self.device), ] ) loss = ( loss + reg_lambda * torch.sum(weights_hat[:, :-1] ** 2) + reg_mu * torch.sum(weights_hat[:, -1] ** 2) ) loss.backward() return loss self.train() optimizer.step(_eval) self.eval()
[docs] def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Forward propagation (just like the original model). :param **kwargs: Additional arguments to the base model. :return: A dictionary with all results from the base model, with the following modified: ``y_prob``: calibrated predicted probabilities. ``loss``: Cross entropy loss with the new y_prob. ``logit``: temperature-scaled logits. :rtype: Dict[str, torch.Tensor] """ ret = self.model(**kwargs) ret["logit"] = self._forward(ret["logit"])[0] ret["y_prob"] = self.model.prepare_y_prob(ret["logit"]) criterion = self.model.get_loss_function() ret["loss"] = criterion(ret["logit"], ret["y_true"]) return ret
if __name__ == "__main__": from pyhealth.calib.calibration import DirichletCalibration from pyhealth.datasets import (ISRUCDataset, get_dataloader, split_by_patient) from pyhealth.models import SparcNet from pyhealth.tasks import sleep_staging_isruc_fn sleep_ds = ISRUCDataset( root="/srv/local/data/trash/", dev=True, ).set_task(sleep_staging_isruc_fn) train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2]) model = SparcNet( dataset=sleep_ds, feature_keys=["signal"], label_key="label", mode="multiclass" ) # ... Train the model here ... # Calibrate cal_model = DirichletCalibration(model) cal_model.calibrate(cal_dataset=val_data) # Evaluate from pyhealth.trainer import Trainer test_dl = get_dataloader(test_data, batch_size=32, shuffle=False) print( Trainer(model=cal_model, metrics=["cwECEt_adapt", "accuracy"]).evaluate(test_dl) )