pyhealth.calib.calibration#

Model calibration methods

class pyhealth.calib.calibration.DirichletCalibration(model, debug=False, **kwargs)[source]#

Bases: PostHocCalibrator

Dirichlet Calibration

Dirichlet calibration is similar to retraining a linear layer that maps the old logits to new logits, with regularization. It is a calibration method for multiclass classification only.
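
For intuition, here is a minimal sketch of the Dirichlet mapping and its regularizer (a sketch assuming the paper's log-softmax parameterization; dirichlet_map, odir_penalty, W, and b are illustrative names, not pyhealth internals):

import torch

def dirichlet_map(logits, W, b):
    # New logits are a learned affine map of the log-probabilities.
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs @ W.T + b

def odir_penalty(W, reg_lambda=1e-3):
    # Penalize the deviation of W from the identity matrix, so the
    # calibrated probabilities stay close to the original ones.
    eye = torch.eye(W.shape[0], device=W.device)
    return reg_lambda * ((W - eye) ** 2).sum()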

Paper:

[1] Kull, Meelis, Miquel Perello Nieto, Markus Kängsepp, Telmo Silva Filho, Hao Song, and Peter Flach. “Beyond temperature scaling: Obtaining well-calibrated multi-class probabilities with dirichlet calibration.” Advances in neural information processing systems 32 (2019).

Parameters:

model (BaseModel) – A trained base model.

Examples

>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import DirichletCalibration
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = DirichletCalibration(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.7096615988229524, 'cwECEt_adapt': 0.05336195546573208}
calibrate(cal_dataset, lr=0.01, max_iter=128, reg_lambda=0.001)[source]#

Calibrate the base model using a calibration dataset.

Parameters:
  • cal_dataset (Subset) – Calibration set.

  • lr (float, optional) – learning rate. Defaults to 0.01.

  • max_iter (int, optional) – maximum number of iterations. Defaults to 128.

  • reg_lambda (float, optional) – regularization coefficient on the deviation from the identity matrix. Defaults to 1e-3.

Returns:

None

Return type:

None

forward(**kwargs)[source]#

Forward propagation (just like the original model).

Parameters:

**kwargs

Additional arguments to the base model.

Returns:

A dictionary with all results from the base model, with the following modified:

  • y_prob: calibrated predicted probabilities.

  • loss: cross entropy loss with the new y_prob.

  • logit: the calibrated logits.

Return type:

Dict[str, torch.Tensor]

class pyhealth.calib.calibration.HistogramBinning(model, debug=False, **kwargs)[source]#

Bases: PostHocCalibrator

Histogram Binning

Histogram binning amounts to creating bins, computing the accuracy of each bin on the calibration dataset, and predicting that accuracy at test time. For multilabel/binary/multiclass classification tasks, we calibrate each class independently following [1]. Users can choose to renormalize the probability scores for multiclass tasks so they sum to 1.
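
For intuition, here is a minimal one-class sketch of histogram binning (illustrative, not the pyhealth internals; fit_bins and apply_bins are hypothetical helper names): bin the calibration scores, record each bin's empirical accuracy, and look it up at test time.

import numpy as np

def fit_bins(cal_scores, cal_labels, nbins=15):
    # Equal-width bins on [0, 1]; each bin predicts the empirical
    # frequency of the positive class among its calibration samples.
    edges = np.linspace(0.0, 1.0, nbins + 1)
    bin_ids = np.clip(np.digitize(cal_scores, edges) - 1, 0, nbins - 1)
    preds = np.empty(nbins)
    for b in range(nbins):
        mask = bin_ids == b
        # Fall back to the bin center if no calibration sample lands in the bin.
        preds[b] = cal_labels[mask].mean() if mask.any() else (edges[b] + edges[b + 1]) / 2
    return edges, preds

def apply_bins(test_scores, edges, preds):
    # Map each test score to the empirical accuracy of its bin.
    bin_ids = np.clip(np.digitize(test_scores, edges) - 1, 0, len(preds) - 1)
    return preds[bin_ids]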

Paper:

[1] Gupta, Chirag, and Aaditya Ramdas. “Top-label calibration and multiclass-to-binary reductions.” ICLR 2022.

[2] Zadrozny, Bianca, and Charles Elkan. “Learning and making decisions when costs and probabilities are both unknown.” In Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 204-213. 2001.

Parameters:

model (BaseModel) – A trained base model.

Examples

>>> from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import HistogramBinning
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = HistogramBinning(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.7189072348464207, 'cwECEt_adapt': 0.04455814993598299}
calibrate(cal_dataset, nbins=15)[source]#

Calibrate the base model using a calibration dataset.

Parameters:
  • cal_dataset (Subset) – Calibration set.

  • nbins (int, optional) – number of bins to use. Defaults to 15.

forward(normalization='sum', **kwargs)[source]#

Forward propagation (just like the original model).

Parameters:
  • normalization (str, optional) – how to normalize the calibrated probability. Defaults to ‘sum’ (and only ‘sum’ is supported for now).

  • **kwargs

    Additional arguments to the base model.

Returns:

A dictionary with all results from the base model, with the following modified:

  • y_prob: calibrated predicted probabilities.

  • loss: cross entropy loss with the new y_prob.

Return type:

Dict[str, torch.Tensor]
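
For reference, the 'sum' normalization above simply rescales each sample's per-class calibrated scores so they sum to one (a one-line sketch; probs is an illustrative (n_samples, n_classes) tensor, not a pyhealth internal):

probs = probs / probs.sum(dim=1, keepdim=True)  # each row now sums to 1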

class pyhealth.calib.calibration.KCal(model, debug=False, **kwargs)[source]#

Bases: PostHocCalibrator

Kernel-based Calibration. This is a full calibration method for multiclass classification: it tries to calibrate the predicted probabilities for all classes by using KDE classifiers estimated from the calibration set.
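
For intuition, here is a minimal sketch of such a KDE classifier (illustrative, not the pyhealth internals; kde_predict and bandwidth are hypothetical names): the calibrated probability of each class is the kernel-weighted frequency of that class over the calibration set, computed in the (possibly reprojected) embedding space.

import torch
import torch.nn.functional as F

def kde_predict(test_emb, cal_emb, cal_labels, num_classes, bandwidth=1.0):
    # Gaussian kernel weights between each test and each calibration embedding.
    d2 = torch.cdist(test_emb, cal_emb) ** 2             # (n_test, n_cal)
    w = torch.exp(-d2 / (2 * bandwidth ** 2))
    onehot = F.one_hot(cal_labels, num_classes).float()  # (n_cal, n_classes)
    probs = w @ onehot                                   # kernel mass per class
    return probs / probs.sum(dim=1, keepdim=True)        # rows sum to 1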

Paper:

Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun. “Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration for Deep Neural Networks.” ICLR 2023.

Parameters:

model (BaseModel) – A trained model.

Examples

>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import KCal
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = KCal(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Alternatively, you could re-fit the reprojection:
>>> # cal_model.calibrate(cal_dataset=val_data, train_dataset=train_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.7303689172252193, 'cwECEt_adapt': 0.03324275630220515}
fit(train_dataset, val_dataset=None, split_by_patient=False, dim=32, bs_pred=64, bs_supp=20, epoch_len=5000, epochs=10, load_best_model_at_last=False, **train_kwargs)[source]#

Fit the reprojection module. You do not need to call this function directly; it is called inside KCal.calibrate(). For training details, please refer to the paper.

Parameters:
  • train_dataset (Dataset) – The training dataset.

  • val_dataset (Dataset, optional) – The validation dataset. Defaults to None.

  • split_by_patient (bool, optional) – Whether to split the dataset by patient during training. Defaults to False.

  • dim (int, optional) – The dimension of the embedding. Defaults to 32.

  • bs_pred (int, optional) – The batch size for the prediction set. Defaults to 64.

  • bs_supp (int, optional) – The batch size for the support set. Defaults to 20.

  • epoch_len (int, optional) – The number of batches in an epoch. Defaults to 5000.

  • epochs (int, optional) – The number of epochs. Defaults to 10.

  • load_best_model_at_last (bool, optional) – Whether to load the best model (or the last model). Defaults to False.

  • **train_kwargs – Other keyword arguments for pyhealth.trainer.Trainer.train().

calibrate(cal_dataset, num_fold=20, record_id_name=None, train_dataset=None, train_split_by_patient=False, load_best_model_at_last=True, **train_kwargs)[source]#

Calibrate using a calibration dataset. If train_dataset is not None, it will be used to fit a re-projection from the base model embeddings. In either case, the calibration set will be used to construct the KDE classifier.

Parameters:
  • cal_dataset (Subset) – Calibration set.

  • record_id_name (str, optional) – the key/name of the unique index for records. Defaults to None.

  • train_dataset (Subset, optional) – Dataset to train the reprojection. Defaults to None (no training).

  • train_split_by_patient (bool, optional) – Whether to split by patient when training the embeddings; that is, whether samples from the same patient may be used together in the KDE during training. Defaults to False.

  • load_best_model_at_last (bool, optional) – Whether to load the best reprojection (based on the calibration set) rather than the last one. Defaults to True.

  • train_kwargs (dict, optional) – Additional arguments for training the reprojection, passed to KCal.fit().

forward(**kwargs)[source]#

Forward propagation (just like the original model).

Parameters:

**kwargs

Additional arguments to the base model.

Returns:

A dictionary with all results from the base model, with the following modified:

  • y_prob: calibrated predicted probabilities.

  • loss: cross entropy loss with the new y_prob.

Return type:

Dict[str, torch.Tensor]

class pyhealth.calib.calibration.TemperatureScaling(model, debug=False, **kwargs)[source]#

Bases: PostHocCalibrator

Temperature Scaling

Temperature scaling refers to scaling the logits by a “temperature” tuned on the calibration set. For binary classification tasks, this amounts to Platt scaling. For multilabel classification, users can use one temperature for all classes, or one for each. For multiclass classification, this is a confidence calibration method: it tries to calibrate the probability of the predicted class.
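
For intuition, here is a minimal sketch of fitting the temperature (illustrative, not the pyhealth internals; fit_temperature is a hypothetical name): learn a single scalar T > 0 that minimizes cross entropy on the calibration set, then divide the logits by T at test time.

import torch
import torch.nn.functional as F

def fit_temperature(cal_logits, cal_labels, lr=0.01, max_iter=50):
    # Optimize log(T) so that T = exp(log_t) stays positive.
    log_t = torch.zeros(1, requires_grad=True)
    opt = torch.optim.LBFGS([log_t], lr=lr, max_iter=max_iter)

    def closure():
        opt.zero_grad()
        loss = F.cross_entropy(cal_logits / log_t.exp(), cal_labels)
        loss.backward()
        return loss

    opt.step(closure)
    return log_t.exp().item()  # use logits / T at test time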

Paper:

[1] Guo, Chuan, Geoff Pleiss, Yu Sun, and Kilian Q. Weinberger. “On calibration of modern neural networks.” ICML 2017.

[2] Platt, John. “Probabilistic outputs for support vector machines and comparisons to regularized likelihood methods.” Advances in large margin classifiers 10, no. 3 (1999): 61-74.

Parameters:

model (BaseModel) – A trained base model.

Examples

>>> from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.calibration import TemperatureScaling
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate
>>> cal_model = TemperatureScaling(model)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> print(Trainer(model=cal_model, metrics=['cwECEt_adapt', 'accuracy']).evaluate(test_dl))
{'accuracy': 0.709843241966832, 'cwECEt_adapt': 0.051673596521491505}
calibrate(cal_dataset, lr=0.01, max_iter=50, mult_temp=False)[source]#

Calibrate the base model using a calibration dataset.

Parameters:
  • cal_dataset (Subset) – Calibration set.

  • lr (float, optional) – learning rate. Defaults to 0.01.

  • max_iter (int, optional) – maximum number of iterations. Defaults to 50.

  • mult_temp (bool, optional) – whether to use one temperature per class (only applicable when mode='multilabel'). Defaults to False.

Returns:

None

Return type:

None

forward(**kwargs)[source]#

Forward propagation (just like the original model).

Parameters:

**kwargs

Additional arguments to the base model.

Returns:

A dictionary with all results from the base model, with the following modified:

  • y_prob: calibrated predicted probabilities.

  • loss: cross entropy loss with the new y_prob.

  • logit: temperature-scaled logits.

Return type:

Dict[str, torch.Tensor]