pyhealth.calib.predictionset#

Prediction set construction methods

class pyhealth.calib.predictionset.LABEL(model, alpha, debug=False, **kwargs)[source]#

Bases: SetPredictor

LABEL: Least ambiguous set-valued classifiers with bounded error levels.

This is a prediction-set constructor for multi-class classification problems. It controls either \(\mathbb{P}\{Y \not \in C(X) | Y=k\}\leq \alpha_k\) (when alpha is an array), or \(\mathbb{P}\{Y \not \in C(X)\}\leq \alpha\) (when alpha is a float). Here, \(C(X)\) denotes the final prediction set. This is essentially a split conformal prediction method using the predicted scores.
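
At its core, LABEL keeps every class whose predicted score clears a threshold chosen on the calibration set. Below is a minimal NumPy sketch of that idea for the float-alpha case (an illustrative, hypothetical helper, not the pyhealth implementation; it also ignores the finite-sample correction used in split conformal prediction):

import numpy as np

def label_threshold(cal_scores, cal_labels, alpha):
    # cal_scores: (n_samples, n_classes) predicted probabilities on the calibration set
    # the threshold is (roughly) the alpha-quantile of the scores assigned to the true class
    true_class_scores = cal_scores[np.arange(len(cal_labels)), cal_labels]
    return float(np.quantile(true_class_scores, alpha))

def label_predict_set(scores, threshold):
    # include every class whose predicted score reaches the threshold
    return scores >= threshold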

Paper:

Sadinle, Mauricio, Jing Lei, and Larry Wasserman. “Least ambiguous set-valued classifiers with bounded error levels.” Journal of the American Statistical Association 114, no. 525 (2019): 223-234.

Parameters:
  • model (BaseModel) – A trained base model.

  • alpha (Union[float, np.ndarray]) – Target mis-coverage rate(s).

Examples

>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.predictionset import LABEL
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate the set classifier, with different class-specific mis-coverage rates
>>> cal_model = LABEL(model, [0.15, 0.3, 0.15, 0.15, 0.15])
>>> # Note that we used the test set here because ISRUCDataset has relatively few
>>> # patients, and calibration set should be different from the validation set
>>> # if the latter is used to pick checkpoint. In general, the calibration set
>>> # should be something exchangeable with the test set. Please refer to the paper.
>>> cal_model.calibrate(cal_dataset=test_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer, get_metrics_fn
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
>>> print(get_metrics_fn(cal_model.mode)(
... y_true_all, y_prob_all, metrics=['accuracy', 'miscoverage_ps'],
... y_predset=extra_output['y_predset'])
... )
{'accuracy': 0.709843241966832, 'miscoverage_ps': array([0.1499847 , 0.29997638, 0.14993964, 0.14994704, 0.14999252])}
calibrate(cal_dataset)[source]#

Calibrate the thresholds used to construct the prediction set.

Parameters:

cal_dataset (Subset) – Calibration set.

forward(**kwargs)[source]#

Forward propagation (just like the original model).

Returns:

A dictionary with all results from the base model, with the following updates:

y_predset: a boolean tensor indicating, for each class, whether it is included in the prediction set.

Return type:

Dict[str, torch.Tensor]
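
To inspect the resulting sets, the boolean mask can be turned into class indices. A short post-processing sketch, reusing extra_output from the example above (treating y_predset as convertible to a NumPy array is an assumption about the returned tensor type):

>>> import numpy as np
>>> y_predset = np.asarray(extra_output['y_predset'])       # (n_samples, n_classes), boolean
>>> pred_sets = [np.flatnonzero(row) for row in y_predset]  # class indices included per sample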

class pyhealth.calib.predictionset.SCRIB(model, risk, loss_kwargs=None, debug=False, fill_max=True, **kwargs)[source]#

Bases: SetPredictor

SCRIB: Set-classifier with Class-specific Risk Bounds

This is a prediction-set constructor for multi-class classification problems. SCRIB tries to control class-specific risk while minimizing the ambiguity. To do this, it selects class-specific thresholds for the predictions on a calibration set.

If risk is a float (say 0.1), SCRIB controls the overall risk: \(\mathbb{P}\{Y \not \in C(X) | |C(X)| = 1\}\leq risk\). If risk is an array (say np.asarray([0.1] * 5)), SCRIB controls the class-specific risks: \(\mathbb{P}\{Y \not \in C(X) | Y=k \land |C(X)| = 1\}\leq risk_k\). Here, \(C(X)\) denotes the final prediction set.

Paper:

Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun. “SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox Models.” AAAI 2022.

Parameters:
  • model (BaseModel) – A trained model.

  • risk (Union[float, np.ndarray]) – risk targets.

  • loss_kwargs (dict, optional) –

    Additional loss parameters (including hyperparameters). It may contain the following float/int hyperparameters:

    lk: The coefficient for the loss term associated with the risk-violation penalty. The higher lk, the heavier the penalty on risk violation (and likely the higher the ambiguity).

    fill_max: Whether to fill in the class with the maximum predicted score when no class exceeds its threshold. In other words, if fill_max, the null region is filled with the max-prediction class.

    Defaults to {‘lk’: 1e4, ‘fill_max’: False}.

  • fill_max (bool, optional) – Whether to fill the empty prediction set with the max-predicted class. Defaults to True.
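
Putting the class-specific thresholds and fill_max together, the set construction can be sketched as follows (an illustrative, hypothetical helper, not the pyhealth internals):

import numpy as np

def scrib_predict_set(scores, thresholds, fill_max=True):
    # scores: (n_samples, n_classes) predicted probabilities
    # thresholds: (n_classes,) class-specific cutoffs searched on the calibration set
    predset = scores >= thresholds
    if fill_max:
        empty = ~predset.any(axis=1)                          # samples in the "null region"
        predset[empty, scores[empty].argmax(axis=1)] = True   # fall back to the top class
    return predset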

Examples

>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.predictionset import SCRIB
>>> from pyhealth.trainer import get_metrics_fn
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate the set classifier, with different class-specific risk targets
>>> cal_model = SCRIB(model, [0.2, 0.3, 0.1, 0.2, 0.1])
>>> # Note that we used the test set here because ISRUCDataset has relatively few
>>> # patients, and calibration set should be different from the validation set
>>> # if the latter is used to pick checkpoint. In general, the calibration set
>>> # should be something exchangeable with the test set. Please refer to the paper.
>>> cal_model.calibrate(cal_dataset=test_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
>>> print(get_metrics_fn(cal_model.mode)(
... y_true_all, y_prob_all, metrics=['accuracy', 'error_ps', 'rejection_rate'],
... y_predset=extra_output['y_predset'])
... )
{'accuracy': 0.709843241966832, 'rejection_rate': 0.6381305287631919,
'error_ps': array([0.32161874, 0.36654135, 0.11461734, 0.23728814, 0.14993925])}
calibrate(cal_dataset)[source]#

Calibrate/Search for the thresholds used to construct the prediction set.

Parameters:

cal_dataset (Subset) – Calibration set.

forward(**kwargs)[source]#

Forward propagation (just like the original model).

Returns:

A dictionary with all results from the base model, with the following updates:

y_predset: a boolean tensor indicating, for each class, whether it is included in the prediction set.

Return type:

Dict[str, torch.Tensor]

class pyhealth.calib.predictionset.FavMac(model, value_weights=1.0, cost_weights=1.0, target_cost=1.0, delta=None, debug=False, **kwargs)[source]#

Bases: SetPredictor

Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)

This is a prediction-set constructor for multi-label classification problems. FavMac can control the cost/risk while realizing high value on the prediction set.

Value and cost functions are functions of the form \(V(S;Y)\) or \(C(S;Y)\), with \(S\) being the prediction set and \(Y\) being the label. For example, a classical cost function is the number of false positives. Denoting the target_cost as \(c\): if delta=None, FavMac controls the expected cost in the following sense:

\(\mathbb{E}[C(S_{N+1};Y_{N+1})] \leq c\).

Otherwise, FavMac controls the violation probability in the following sense:

\(\mathbb{P}\{C(S_{N+1};Y_{N+1})>c\}\leq delta\).

Right now, this FavMac implementation only supports additive value and cost functions (unlike the implementation associated with [1]). That is, the value function is specified by the weights value_weights and the cost function is specified by cost_weights. With \(k\) denoting classes, the cost function is then computed as

\(C(S;Y,w) = \sum_{k} (1-Y_k)S_k w_k\)

Similarly, the value function is computed as

\(V(S;Y,w) = \sum_{k} Y_k S_k w_k\).
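
In other words, with unit weights the cost counts false positives and the value counts true positives in the prediction set. A minimal sketch of these additive functions (illustrative only, not the pyhealth internals; inputs are assumed to be NumPy arrays of shape (n_classes,) or (n_samples, n_classes)):

def additive_cost(predset, y, w=1.0):
    # C(S; Y, w) = sum_k (1 - Y_k) * S_k * w_k, i.e. weighted false positives
    return ((1 - y) * predset * w).sum(axis=-1)

def additive_value(predset, y, w=1.0):
    # V(S; Y, w) = sum_k Y_k * S_k * w_k, i.e. weighted true positives
    return (y * predset * w).sum(axis=-1)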

Papers:

[1] Lin, Zhen, Shubhendu Trivedi, Cao Xiao, and Jimeng Sun. “Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac).” ICML 2023.

[2] Fisch, Adam, Tal Schuster, Tommi Jaakkola, and Regina Barzilay. “Conformal prediction sets with limited false positives.” ICML 2022.

Parameters:
  • model (BaseModel) – A trained model.

  • value_weights (Union[float, np.ndarray]) – weights for the value function. See description above. Defaults to 1.

  • cost_weights (Union[float, np.ndarray]) – weights for the cost function. See description above. Defaults to 1.

  • target_cost (float) – Target cost. When cost_weights is set to 1, this is essentially the number of false positives. Defaults to 1.

  • delta (float) – Violation target (in violation control). Defaults to None (which means expectation control instead of violation control).
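
The two control modes correspond directly to delta, as sketched below (model denotes a trained base model, as in the example that follows):

>>> from pyhealth.calib.predictionset import FavMac
>>> # Expectation control: expected cost (here, the number of false positives) at most 3
>>> cal_model = FavMac(model, target_cost=3, delta=None)
>>> # Violation control: cost exceeds 3 with probability at most 0.1
>>> cal_model = FavMac(model, target_cost=3, delta=0.1)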

Examples

>>> from pyhealth.calib.predictionset import FavMac
>>> from pyhealth.datasets import (MIMIC3Dataset, get_dataloader,split_by_patient)
>>> from pyhealth.models import Transformer
>>> from pyhealth.tasks import drug_recommendation_mimic3_fn
>>> from pyhealth.trainer import get_metrics_fn
>>> base_dataset = MIMIC3Dataset(
...     root="/srv/scratch1/data/physionet.org/files/mimiciii/1.4",
...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
...     code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
...     refresh_cache=False)
>>> sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
>>> train_data, val_data, test_data = split_by_patient(sample_dataset, [0.6, 0.2, 0.2])
>>> model = Transformer(dataset=sample_dataset, feature_keys=["conditions", "procedures"],
...             label_key="drugs", mode="multilabel")
>>> # ... Train the model here ...
>>> # Try to control false positive to <=3
>>> cal_model = FavMac(model, target_cost=3, delta=None)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
... test_dl, additional_outputs=["y_predset"])
>>> print(get_metrics_fn(cal_model.mode)(
...     y_true_all, y_prob_all, metrics=['tp', 'fp'],
...     y_predset=extra_output["y_predset"])) # We get FP~=3
{'tp': 0.5049893086243763, 'fp': 2.8442622950819674}
calibrate(cal_dataset)[source]#

Calibrate the cost-control procedure.

Parameters:

cal_dataset (Subset) – Calibration set.

forward(**kwargs)[source]#

Forward propagation (just like the original model).

Returns:

A dictionary with all results from the base model, with the following updates:

y_predset: a boolean tensor indicating, for each class, whether it is included in the prediction set.

Return type:

Dict[str, torch.Tensor]