pyhealth.calib.calibration
Model calibration methods
- class pyhealth.calib.calibration.DirichletCalibration(model, debug=False, **kwargs)
Bases: PostHocCalibrator
Dirichlet Calibration
Dirichlet calibration amounts to fitting a regularized linear layer that maps the base model's logits to new, calibrated logits. This calibration method supports multiclass classification only.
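To make the "linear layer on logits" idea concrete, here is a minimal pure-Python sketch. It is not PyHealth's implementation: the helper names (`linear_recalibrate`, `identity_penalty`) are hypothetical, and the penalty shown (squared deviation of the weight matrix from the identity, plus squared bias) is only an assumed form of the regularization, which this page does not spell out.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def linear_recalibrate(logits, W, b):
    """Map old logits to new logits with W @ logits + b, then apply softmax."""
    new_logits = [
        sum(w_ij * z_j for w_ij, z_j in zip(row, logits)) + b_i
        for row, b_i in zip(W, b)
    ]
    return new_logits, softmax(new_logits)

def identity_penalty(W, b, reg_lambda=0.001):
    """Assumed regularizer: squared deviation of W from identity plus squared bias."""
    deviation = sum(
        (w_ij - (1.0 if i == j else 0.0)) ** 2
        for i, row in enumerate(W)
        for j, w_ij in enumerate(row)
    )
    return reg_lambda * (deviation + sum(b_i ** 2 for b_i in b))

logits = [2.0, 0.5, -1.0]
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
zero_bias = [0.0, 0.0, 0.0]

# With W = I and b = 0 the map leaves the base model's probabilities unchanged.
_, probs_identity = linear_recalibrate(logits, identity, zero_bias)

# A fitted map (these weights are made up) can soften over-confident predictions.
shrink = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.5]]
_, probs_shrunk = linear_recalibrate(logits, shrink, zero_bias)
```

In practice the weights and bias would be fitted on the calibration set by minimizing cross entropy plus the penalty; the sketch only shows the forward map and one possible penalty term.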
Paper:
[1] Kull, Meelis, Miquel Perello Nieto, Markus Kängsepp, Telmo Silva Filho, Hao Song, and Peter Flach. “Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with dirichlet calibration.” Advances in neural information processing systems 32 (2019).
- Parameters:
model (BaseModel) – A trained base model.
Examples
>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import DirichletCalibration
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = DirichletCalibration(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.7096615988229524, 'cwECEt_adapt': 0.05336195546573208}
- calibrate(cal_dataset, lr=0.01, max_iter=128, reg_lambda=0.001)
Calibrate the base model using a calibration dataset.
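- Parameters:
cal_dataset (Subset) – Calibration set.
lr (float, optional) – learning rate, defaults to 0.01
max_iter (int, optional) – maximum number of iterations, defaults to 128
reg_lambda (float, optional) – regularization coefficient, defaults to 0.001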
- Returns:
None
- Return type:
None
- forward(**kwargs)
Forward propagation (just like the original model).
- Parameters:
**kwargs – Additional arguments to the base model.
- Returns:
A dictionary with all results from the base model, with the following modified:
y_prob: calibrated predicted probabilities.
loss: Cross entropy loss with the new y_prob.
logit: calibrated logits.
- Return type:
Dict[str, torch.Tensor]
- class pyhealth.calib.calibration.HistogramBinning(model, debug=False, **kwargs)
Bases: PostHocCalibrator
Histogram Binning
Histogram binning creates bins over the predicted probabilities, computes the accuracy within each bin on the calibration dataset, and outputs that per-bin accuracy as the calibrated probability at test time. For multilabel, binary, and multiclass classification tasks, each class is calibrated independently following [1]. For multiclass tasks, users can choose to renormalize the calibrated probabilities so that they sum to 1.
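The binning step for a single class can be sketched in a few lines of pure Python. This is an illustration only: the function names (`fit_bins`, `predict`) are hypothetical, and equal-width bins on [0, 1] with a midpoint fallback for empty bins are assumptions, since this page does not say which binning scheme the library uses.

```python
def fit_bins(scores, labels, nbins=15):
    """Return the empirical positive rate of each equal-width bin on [0, 1]."""
    counts = [0] * nbins
    positives = [0] * nbins
    for score, label in zip(scores, labels):
        idx = min(int(score * nbins), nbins - 1)  # a score of 1.0 goes to the last bin
        counts[idx] += 1
        positives[idx] += label
    # Empty bins fall back to the bin midpoint (an arbitrary choice for this sketch).
    return [
        positives[i] / counts[i] if counts[i] else (i + 0.5) / nbins
        for i in range(nbins)
    ]

def predict(score, bin_rates):
    """Replace a raw score with the positive rate of the bin it falls into."""
    nbins = len(bin_rates)
    return bin_rates[min(int(score * nbins), nbins - 1)]

# Toy calibration set for one class: raw scores and 0/1 labels.
cal_scores = [0.05, 0.15, 0.85, 0.95, 0.92, 0.88]
cal_labels = [0, 0, 1, 1, 0, 1]
rates = fit_bins(cal_scores, cal_labels, nbins=2)

# At test time a raw score of 0.9 is mapped to its bin's positive rate.
calibrated = predict(0.9, rates)
```

For multiclass or multilabel tasks the same procedure would be repeated once per class, treating that class's score as a binary problem.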
Paper:
[1] Gupta, Chirag, and Aaditya Ramdas. “Top-label calibration and multiclass-to-binary reductions.” ICLR 2022.
[2] Zadrozny, Bianca, and Charles Elkan. “Learning and making decisions when costs and probabilities are both unknown.” In Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 204-213. 2001.
- Parameters:
model (BaseModel) – A trained base model.
Examples
>>> from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import HistogramBinning
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = HistogramBinning(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.7189072348464207, 'cwECEt_adapt': 0.04455814993598299}
- calibrate(cal_dataset, nbins=15)
Calibrate the base model using a calibration dataset.
- Parameters:
cal_dataset (Subset) – Calibration set.
nbins (int, optional) – number of bins to use, defaults to 15
- forward(normalization='sum', **kwargs)
Forward propagation (just like the original model).
- Parameters:
normalization (str, optional) – how to normalize the calibrated probability. Defaults to ‘sum’ (and only ‘sum’ is supported for now).
**kwargs – Additional arguments to the base model.
- Returns:
A dictionary with all results from the base model, with the following modified:
y_prob: calibrated predicted probabilities.
loss: Cross entropy loss with the new y_prob.
- Return type:
Dict[str, torch.Tensor]
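Because HistogramBinning calibrates each class independently, the per-class outputs need not sum to 1 for a multiclass task. The ‘sum’ normalization presumably divides each calibrated score by their total; the sketch below shows that reading. The function name `normalize_sum` and the uniform fallback for an all-zero vector are assumptions, not taken from the library.

```python
def normalize_sum(per_class_probs):
    """Rescale independently calibrated class scores so they sum to 1."""
    total = sum(per_class_probs)
    if total == 0:
        # Assumed fallback: return a uniform distribution when every score is zero.
        return [1.0 / len(per_class_probs)] * len(per_class_probs)
    return [p / total for p in per_class_probs]

# Per-class calibrated scores from independent binning (made-up values).
calibrated = [0.5, 0.25, 0.5]
normalized = normalize_sum(calibrated)
```

The ranking of classes is unchanged by this rescaling, so the predicted label stays the same; only the probability values move.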
- class pyhealth.calib.calibration.KCal(model, debug=False, **kwargs)
Bases: PostHocCalibrator
Kernel-based Calibration. This is a full calibration method for multiclass classification. It tries to calibrate the predicted probabilities for all classes, by using KDE classifiers estimated from the calibration set.
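As a rough illustration of a KDE classifier built from a calibration set, the pure-Python sketch below scores a query embedding by the kernel mass of each class. It is not the KCal implementation: the Gaussian kernel, the fixed `bandwidth`, and the function names are assumptions, and the learned reprojection of embeddings described in the paper is omitted entirely.

```python
import math

def gaussian_kernel(x, y, bandwidth):
    """Unnormalized Gaussian kernel between two embedding vectors."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))

def kde_class_probs(query, cal_embeddings, cal_labels, num_classes, bandwidth=1.0):
    """Class probabilities proportional to the summed kernel mass of each class."""
    mass = [0.0] * num_classes
    for emb, label in zip(cal_embeddings, cal_labels):
        mass[label] += gaussian_kernel(query, emb, bandwidth)
    total = sum(mass)
    if total == 0.0:
        # Assumed fallback when the query is far from every calibration point.
        return [1.0 / num_classes] * num_classes
    return [m / total for m in mass]

# Toy calibration set: 2-D embeddings with integer class labels.
cal_embeddings = [[0.0, 0.0], [0.1, -0.1], [2.0, 2.0], [2.1, 1.9], [-2.0, 2.0]]
cal_labels = [0, 0, 1, 1, 2]

# The query sits next to the class-0 points, so class 0 receives most of the mass.
probs = kde_class_probs([0.05, 0.0], cal_embeddings, cal_labels, num_classes=3)
```

Every class receives a probability here, which is what distinguishes a full calibration method from one that only adjusts the confidence of the top prediction.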
Paper:
Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun. “Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration for Deep Neural Networks.” ICLR 2023.
- Parameters:
model (BaseModel) – A trained model.
Examples
>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import KCal
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = KCal(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Alternatively, you could re-fit the reprojection:
>>> # cal_model.calibrate(cal_dataset=val_data, train_dataset=train_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.7303689172252193, 'cwECEt_adapt': 0.03324275630220515}
- fit(train_dataset, val_dataset=None, split_by_patient=False, dim=32, bs_pred=64, bs_supp=20, epoch_len=5000, epochs=10, load_best_model_at_last=False, **train_kwargs)
Fit the reprojection module. You don’t need to call this function; it is called in KCal.calibrate(). For training details, please refer to the paper.
- Parameters:
train_dataset (Dataset) – The training dataset.
val_dataset (Dataset, optional) – The validation dataset. Defaults to None.
split_by_patient (bool, optional) – Whether to split the dataset by patient during training. Defaults to False.
dim (int, optional) – The dimension of the embedding. Defaults to 32.
bs_pred (int, optional) – The batch size for the prediction set. Defaults to 64.
bs_supp (int, optional) – The batch size for the support set. Defaults to 20.
epoch_len (int, optional) – The number of batches in an epoch. Defaults to 5000.
epochs (int, optional) – The number of epochs. Defaults to 10.
load_best_model_at_last (bool, optional) – Whether to load the best model (or the last model). Defaults to False.
**train_kwargs – Other keyword arguments for pyhealth.trainer.Trainer.train().
- calibrate(cal_dataset, num_fold=20, record_id_name=None, train_dataset=None, train_split_by_patient=False, load_best_model_at_last=True, **train_kwargs)
Calibrate using a calibration dataset. If train_dataset is not None, it will be used to fit a re-projection from the base model embeddings. In either case, the calibration set will be used to construct the KDE classifier.
- Parameters:
cal_dataset (Subset) – Calibration set.
record_id_name (str, optional) – the key/name of the unique index for records. Defaults to None.
train_dataset (Subset, optional) – Dataset to train the reprojection. Defaults to None (no training).
train_split_by_patient (bool, optional) – Whether to split by patient when training the embeddings, that is, whether samples from the same patient are used in the KDE during training. Defaults to False.
load_best_model_at_last (bool, optional) – Whether to load the best reprojection based on the calibration set. Defaults to True.
train_kwargs (dict, optional) – Additional arguments for training the reprojection. Passed to KCal.fit().
- forward(**kwargs)
Forward propagation (just like the original model).
- Parameters:
**kwargs – Additional arguments to the base model.
- Returns:
A dictionary with all results from the base model, with the following modified:
y_prob: calibrated predicted probabilities.
loss: Cross entropy loss with the new y_prob.
- Return type:
Dict[str, torch.Tensor]
- class pyhealth.calib.calibration.TemperatureScaling(model, debug=False, **kwargs)
Bases: PostHocCalibrator
Temperature Scaling
Temperature scaling refers to scaling the logits by a “temperature” tuned on the calibration set. For binary classification tasks, this amounts to Platt scaling. For multilabel classification, users can use one temperature for all classes, or one for each class. For multiclass classification, this is a confidence calibration method: it tries to calibrate the predicted probability of the predicted class.
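For the multiclass case, the idea can be sketched in pure Python as dividing the logits by a scalar before the softmax and choosing that scalar to minimize the negative log-likelihood on the calibration set. This is a sketch under assumptions, not the library's code: the function names are hypothetical, and a small grid search stands in for whatever optimizer `calibrate` actually uses (its `lr` and `max_iter` arguments suggest an iterative one).

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def scale(logits, temperature):
    """Divide the logits by the temperature, then apply softmax."""
    return softmax([z / temperature for z in logits])

def nll(logit_rows, labels, temperature):
    """Mean negative log-likelihood of the true labels at a given temperature."""
    losses = [
        -math.log(scale(row, temperature)[label])
        for row, label in zip(logit_rows, labels)
    ]
    return sum(losses) / len(losses)

def fit_temperature(logit_rows, labels, grid):
    """Pick the temperature on the grid with the lowest calibration-set NLL."""
    return min(grid, key=lambda t: nll(logit_rows, labels, t))

# Toy calibration set: confident logits, with the third prediction being wrong.
cal_logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [4.0, 0.0, 0.0], [0.0, 0.0, 4.0]]
cal_labels = [0, 1, 1, 2]

best_t = fit_temperature(cal_logits, cal_labels, grid=[0.5, 1.0, 2.0, 4.0])
before = scale(cal_logits[0], 1.0)
after = scale(cal_logits[0], best_t)
```

Dividing by a positive temperature never changes which class has the largest probability, so accuracy is unaffected; only the confidence attached to the prediction changes.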
Paper:
[1] Guo, Chuan, Geoff Pleiss, Yu Sun, and Kilian Q. Weinberger. “On calibration of modern neural networks.” ICML 2017.
[2] Platt, John. “Probabilistic outputs for support vector machines and comparisons to regularized likelihood methods.” Advances in large margin classifiers 10, no. 3 (1999): 61-74.
- Parameters:
model (BaseModel) – A trained base model.
Examples
>>> from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import TemperatureScaling
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = TemperatureScaling(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.709843241966832, 'cwECEt_adapt': 0.051673596521491505}
- calibrate(cal_dataset, lr=0.01, max_iter=50, mult_temp=False)
Calibrate the base model using a calibration dataset.
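The `mult_temp` argument is not documented on this page. Assuming it switches between the two multilabel options mentioned in the class description (one temperature shared by all classes, or one per class), the difference can be sketched as follows. The function name `scale_multilabel` and all numeric values are hypothetical.

```python
import math

def sigmoid(x):
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-x))

def scale_multilabel(logits, temperatures):
    """Divide each class logit by its temperature before the sigmoid.

    Pass a single number to share one temperature across classes,
    or a list with one temperature per class.
    """
    if isinstance(temperatures, (int, float)):
        temperatures = [temperatures] * len(logits)
    return [sigmoid(z / t) for z, t in zip(logits, temperatures)]

logits = [3.0, -2.0, 0.5]

# One shared temperature: every class logit is divided by 2.0.
shared = scale_multilabel(logits, 2.0)

# One temperature per class: each logit gets its own divisor.
per_class = scale_multilabel(logits, [2.0, 1.0, 0.5])
```

Per-class temperatures add flexibility at the cost of fitting more parameters on the same calibration set.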
- forward(**kwargs)
Forward propagation (just like the original model).
- Parameters:
**kwargs – Additional arguments to the base model.
- Returns:
A dictionary with all results from the base model, with the following modified:
y_prob: calibrated predicted probabilities.
loss: Cross entropy loss with the new y_prob.
logit: temperature-scaled logits.
- Return type:
Dict[str, torch.Tensor]