pyhealth.calib.predictionset

Prediction set construction methods

class pyhealth.calib.predictionset.LABEL(model, alpha, debug=False, **kwargs)

Bases: SetPredictor

LABEL: Least ambiguous set-valued classifiers with bounded error levels.

This is a prediction-set constructor for multi-class classification problems. It controls either \(\mathbb{P}\{Y \not \in C(X) | Y=k\}\leq \alpha_k\) (when alpha is an array), or \(\mathbb{P}\{Y \not \in C(X)\}\leq \alpha\) (when alpha is a float). Here, \(C(X)\) denotes the final prediction set. This is essentially a split conformal prediction method using the predicted scores.
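
The thresholding idea behind this guarantee can be sketched in NumPy. This is a minimal illustration under stated assumptions, not PyHealth's implementation: the helper names lower_quantile, label_thresholds and predict_sets are hypothetical, the conformity score is assumed to be the predicted probability of the true class, and the (n+1)-corrected lower quantile is one standard split-conformal choice.

```python
import numpy as np


def lower_quantile(scores, a):
    """Finite-sample lower quantile: the floor((n + 1) * a)-th smallest score.

    Returns -inf when that index is 0, so the class is then always included.
    Assumes 0 < a < 1.
    """
    scores = np.sort(np.asarray(scores, dtype=float))
    j = int(np.floor((len(scores) + 1) * a))
    return -np.inf if j < 1 else scores[j - 1]


def label_thresholds(cal_probs, cal_labels, alpha):
    """Hypothetical helper: per-class thresholds t_k in the spirit of LABEL.

    cal_probs: (n, K) predicted probabilities on the calibration set.
    cal_labels: (n,) integer labels in [0, K).
    alpha: a float (marginal guarantee) or a length-K array (class-specific).
    """
    cal_probs = np.asarray(cal_probs, dtype=float)
    cal_labels = np.asarray(cal_labels, dtype=int)
    n_classes = cal_probs.shape[1]
    # Assumed conformity score: predicted probability of the true class.
    true_scores = cal_probs[np.arange(len(cal_labels)), cal_labels]
    if np.ndim(alpha) == 0:
        # One shared threshold computed from all calibration points.
        return np.full(n_classes, lower_quantile(true_scores, float(alpha)))
    alpha = np.asarray(alpha, dtype=float)
    # One threshold per class, using only calibration points of that class.
    return np.array([lower_quantile(true_scores[cal_labels == k], alpha[k])
                     for k in range(n_classes)])


def predict_sets(probs, thresholds):
    """Boolean (n, K) mask: class k is in C(x) iff p_k(x) >= t_k."""
    return np.asarray(probs, dtype=float) >= np.asarray(thresholds, dtype=float)


cal_probs = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]
cal_labels = [0, 0, 1, 0, 1]
t = label_thresholds(cal_probs, cal_labels, 0.2)
print(t)  # [0.6 0.6]
print(predict_sets([[0.5, 0.5], [0.7, 0.3]], t))
```

With an array alpha, each class gets its own threshold computed only from calibration points of that class, which is what makes the guarantee class-conditional; with a float, a single shared threshold yields only the marginal guarantee.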

Paper:

Sadinle, Mauricio, Jing Lei, and Larry Wasserman. “Least ambiguous set-valued classifiers with bounded error levels.” Journal of the American Statistical Association 114, no. 525 (2019): 223-234.

Parameters
  • model (BaseModel) – A trained base model.

  • alpha (Union[float, np.ndarray]) – Target mis-coverage rate(s).

Examples

>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.predictionset import LABEL
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate the set classifier, with different class-specific mis-coverage rates
>>> cal_model = LABEL(model, [0.15, 0.3, 0.15, 0.15, 0.15])
>>> # Note that I used the test set here because ISRUCDataset has relatively few
>>> # patients, and the calibration set should be different from the validation set
>>> # if the latter is used to pick the checkpoint. In general, the calibration set
>>> # should be exchangeable with the test set. Please refer to the paper.
>>> cal_model.calibrate(cal_dataset=test_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer, get_metrics_fn
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
>>> print(get_metrics_fn(cal_model.mode)(
... y_true_all, y_prob_all, metrics=['accuracy', 'miscoverage_ps'],
... y_predset=extra_output['y_predset'])
... )
{'accuracy': 0.709843241966832, 'miscoverage_ps': array([0.1499847 , 0.29997638, 0.14993964, 0.14994704, 0.14999252])}
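
To read the miscoverage_ps output above, the following minimal NumPy sketch computes a per-class miscoverage rate from a boolean y_predset mask. It is an assumed re-implementation for illustration only (per_class_miscoverage is a hypothetical name); PyHealth's own metric may differ in details such as how classes that never occur are handled.

```python
import numpy as np


def per_class_miscoverage(y_true, y_predset):
    """Hypothetical helper: for each class k, the fraction of samples with
    label k whose prediction set does not contain k (NaN if k never occurs)."""
    y_true = np.asarray(y_true, dtype=int)
    y_predset = np.asarray(y_predset, dtype=bool)
    # covered[i] is True when sample i's true label is inside its set.
    covered = y_predset[np.arange(len(y_true)), y_true]
    out = np.full(y_predset.shape[1], np.nan)
    for k in range(y_predset.shape[1]):
        mask = y_true == k
        if mask.any():
            out[k] = 1.0 - covered[mask].mean()
    return out


y_true = [0, 0, 1, 1]
y_predset = [[True, False], [False, True], [False, True], [True, True]]
print(per_class_miscoverage(y_true, y_predset))  # [0.5 0. ]
```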

calibrate(cal_dataset)

Calibrate the thresholds used to construct the prediction set.

Parameters

cal_dataset (Subset) – Calibration set.

forward(**kwargs)

Forward propagation (just like the original model).

Returns

A dictionary with all results from the base model, with the following updates:

y_predset: a boolean tensor indicating, for each sample and each class, whether that class is included in the prediction set.

Return type

Dict[str, torch.Tensor]

class pyhealth.calib.predictionset.SCRIB(model, risk, loss_kwargs=None, debug=False, fill_max=True, **kwargs)

Bases: SetPredictor

SCRIB: Set-classifier with Class-specific Risk Bounds

This is a prediction-set constructor for multi-class classification problems. SCRIB tries to control class-specific risk while minimizing ambiguity. To do this, it searches for class-specific thresholds on a calibration set.

If risk is a float (say 0.1), SCRIB controls the overall risk: \(\mathbb{P}\{Y \not \in C(X) \mid |C(X)| = 1\}\leq \mathrm{risk}\). If risk is an array (say np.asarray([0.1] * 5)), SCRIB controls the class-specific risks: \(\mathbb{P}\{Y \not \in C(X) \mid Y=k \land |C(X)| = 1\}\leq \mathrm{risk}_k\). Here, \(C(X)\) denotes the final prediction set.
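
Both risk definitions condition on \(|C(X)| = 1\), so only singleton prediction sets contribute to the error. A minimal NumPy sketch of their empirical estimates on a labeled set is given below. It is our own illustration, not PyHealth's error_ps or rejection_rate metrics; the name singleton_risks is hypothetical, and counting every non-singleton set (empty or ambiguous) as a rejection is an assumption.

```python
import numpy as np


def singleton_risks(y_true, y_predset):
    """Hypothetical helper: empirical SCRIB-style risks on a labeled set.

    Only samples whose prediction set is a singleton (|C(X)| = 1) enter the
    risk estimates; every other sample counts toward the rejection rate.
    Returns (overall_risk, per_class_risk, rejection_rate).
    """
    y_true = np.asarray(y_true, dtype=int)
    y_predset = np.asarray(y_predset, dtype=bool)
    n_classes = y_predset.shape[1]
    singleton = y_predset.sum(axis=1) == 1
    # miss[i] is True when the true label is not in sample i's set.
    miss = ~y_predset[np.arange(len(y_true)), y_true]
    overall = miss[singleton].mean() if singleton.any() else np.nan
    per_class = np.full(n_classes, np.nan)
    for k in range(n_classes):
        mask = singleton & (y_true == k)  # conditions on Y = k and |C(X)| = 1
        if mask.any():
            per_class[k] = miss[mask].mean()
    rejection_rate = 1.0 - singleton.mean()
    return overall, per_class, rejection_rate


y_true = [0, 1, 1, 0]
y_predset = [[True, False], [True, False], [True, True], [False, True]]
overall, per_class, rejection_rate = singleton_risks(y_true, y_predset)
print(overall, per_class, rejection_rate)
```

A calibration procedure would then pick thresholds so that per_class stays below the risk targets while keeping rejection_rate (the ambiguity) as low as possible.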

Paper:

Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun. “SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox Models.” AAAI 2022.

Parameters
  • model (BaseModel) – A trained model.

  • risk (Union[float, np.ndarray]) – Target risk level(s).

  • loss_kwargs (dict, optional) –

    Additional loss parameters (including hyperparameters). It may contain the following float/int hyperparameters:

      • lk: The coefficient of the loss term that penalizes risk violation. The higher lk is, the heavier the penalty on risk violation (likely leading to higher ambiguity).

      • fill_max: Whether to fill the prediction set with the max-predicted class when no class exceeds its threshold. In other words, if fill_max is set, the null region is filled with the max-prediction class.

    Defaults to {'lk': 1e4, 'fill_max': False}.

  • fill_max (bool, optional) – Whether to fill the empty prediction set with the max-predicted class. Defaults to True.
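
The effect of fill_max can be sketched with NumPy, assuming one threshold per class and a strict "exceeds the threshold" comparison. The helper threshold_sets is hypothetical and only illustrates the behavior described above; it is not the PyHealth code path.

```python
import numpy as np


def threshold_sets(probs, thresholds, fill_max=True):
    """Hypothetical helper: boolean prediction sets from class-specific thresholds.

    Class k enters the set when its score exceeds t_k. With fill_max=True, a
    sample whose set would be empty gets the class with the highest score.
    """
    probs = np.asarray(probs, dtype=float)
    sets = probs > np.asarray(thresholds, dtype=float)
    if fill_max:
        empty = ~sets.any(axis=1)
        # Put the argmax class into each otherwise-empty set.
        sets[empty, probs[empty].argmax(axis=1)] = True
    return sets


probs = [[0.5, 0.3, 0.2], [0.1, 0.1, 0.8]]
thresholds = [0.6, 0.6, 0.6]
print(threshold_sets(probs, thresholds, fill_max=True).astype(int))
print(threshold_sets(probs, thresholds, fill_max=False).astype(int))
```

With fill_max=True no sample ever receives an empty set; with fill_max=False the first sample above keeps an empty set, since none of its scores exceeds 0.6.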

Examples

>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.predictionset import SCRIB
>>> from pyhealth.trainer import get_metrics_fn
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate the set classifier, with different class-specific risk targets
>>> cal_model = SCRIB(model, [0.2, 0.3, 0.1, 0.2, 0.1])
>>> # Note that I used the test set here because ISRUCDataset has relatively few
>>> # patients, and the calibration set should be different from the validation set
>>> # if the latter is used to pick the checkpoint. In general, the calibration set
>>> # should be exchangeable with the test set. Please refer to the paper.
>>> cal_model.calibrate(cal_dataset=test_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
>>> print(get_metrics_fn(cal_model.mode)(
... y_true_all, y_prob_all, metrics=['accuracy', 'error_ps', 'rejection_rate'],
... y_predset=extra_output['y_predset'])
... )
{'accuracy': 0.709843241966832, 'rejection_rate': 0.6381305287631919,
'error_ps': array([0.32161874, 0.36654135, 0.11461734, 0.23728814, 0.14993925])}

calibrate(cal_dataset)

Calibrate (search for) the class-specific thresholds used to construct the prediction set.

Parameters

cal_dataset (Subset) – Calibration set.
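
The threshold search can be illustrated with a deliberately simplified coordinate-descent sketch. Assumptions, all ours: the loss is ambiguity (fraction of non-singleton sets) plus lk times a hinge penalty on per-class risk violations, which is one reading of the lk description above and not necessarily the exact SCRIB objective; thresholds are searched on a fixed grid; and search_thresholds is a hypothetical name. The algorithm in the SCRIB paper and in PyHealth may differ.

```python
import numpy as np


def search_thresholds(probs, labels, risk_targets, lk=1e4, n_rounds=3, grid=None):
    """Hypothetical, simplified coordinate-descent search for class thresholds.

    Assumed loss: ambiguity + lk * sum_k max(0, risk_k - target_k), where
    ambiguity is the fraction of non-singleton sets and risk_k is the error
    rate among singleton sets of samples with label k.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=int)
    targets = np.asarray(risk_targets, dtype=float)
    n, n_classes = probs.shape
    grid = np.linspace(0.0, 1.0, 101) if grid is None else np.asarray(grid, dtype=float)

    def loss(t):
        sets = probs > t  # class k is in the set when p_k exceeds t_k
        singleton = sets.sum(axis=1) == 1
        miss = ~sets[np.arange(n), labels]
        penalty = 0.0
        for k in range(n_classes):
            mask = singleton & (labels == k)
            if mask.any():
                penalty += max(0.0, miss[mask].mean() - targets[k])
        return (1.0 - singleton.mean()) + lk * penalty

    t = np.full(n_classes, 0.5)  # arbitrary starting point
    for _ in range(n_rounds):
        for k in range(n_classes):
            # Try every grid value for t_k while holding the other thresholds fixed.
            best_val, best_loss = t[k], loss(t)
            for cand in grid:
                t[k] = cand
                cur = loss(t)
                if cur < best_loss:
                    best_val, best_loss = cand, cur
            t[k] = best_val
    return t


probs = [[0.3, 0.7], [0.2, 0.8]]
labels = [0, 1]
t = search_thresholds(probs, labels, risk_targets=[0.0, 0.0])
print(t, np.asarray(probs) > t)
```

Because lk is large, any risk violation dominates the loss, so the search prefers accepting some ambiguity (non-singleton sets) over a wrong singleton prediction, which matches the trade-off described for lk.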

forward(**kwargs)

Forward propagation (just like the original model).

Returns

A dictionary with all results from the base model, with the following updates:

y_predset: a boolean tensor indicating, for each sample and each class, whether that class is included in the prediction set.

Return type

Dict[str, torch.Tensor]