pyhealth.calib.predictionset
Prediction set construction methods
- class pyhealth.calib.predictionset.LABEL(model, alpha, debug=False, **kwargs)
Bases: SetPredictor
LABEL: Least ambiguous set-valued classifiers with bounded error levels.
This is a prediction-set constructor for multi-class classification problems. It controls either \(\mathbb{P}\{Y \not \in C(X) | Y=k\}\leq \alpha_k\) (when alpha is an array), or \(\mathbb{P}\{Y \not \in C(X)\}\leq \alpha\) (when alpha is a float). Here, \(C(X)\) denotes the final prediction set. This is essentially a split conformal prediction method using the predicted scores.
Paper:
Sadinle, Mauricio, Jing Lei, and Larry Wasserman. “Least ambiguous set-valued classifiers with bounded error levels.” Journal of the American Statistical Association 114, no. 525 (2019): 223-234.
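The thresholding idea behind this kind of split conformal set predictor can be illustrated with a small NumPy sketch. This is not PyHealth's implementation: the helper names (`lower_quantile`, `calibrate_thresholds`, `prediction_sets`) and the synthetic data are hypothetical, and the plain empirical quantile used here omits any finite-sample correction a careful implementation might apply.

```python
import numpy as np

def lower_quantile(scores, alpha):
    """Return an observed score with at most an alpha fraction of scores strictly below it."""
    s = np.sort(np.asarray(scores, dtype=float))
    idx = min(int(np.floor(alpha * len(s))), len(s) - 1)
    return s[idx]

def calibrate_thresholds(prob, y, alpha):
    """prob: (n, K) predicted probabilities; y: (n,) integer labels."""
    n, n_classes = prob.shape
    true_scores = prob[np.arange(n), y]  # score assigned to the true class
    if np.ndim(alpha) == 0:
        # Float alpha: one shared threshold, marginal mis-coverage target.
        return np.full(n_classes, lower_quantile(true_scores, float(alpha)))
    # Array alpha: one threshold per class, class-conditional targets.
    return np.array([lower_quantile(true_scores[y == k], a)
                     for k, a in enumerate(alpha)])

def prediction_sets(prob, thresholds):
    """Boolean (n, K) matrix: class k is in C(x) when its score reaches t_k."""
    return prob >= thresholds

# Synthetic calibration data (illustrative only).
rng = np.random.default_rng(0)
n, K = 600, 3
y = rng.integers(0, K, size=n)
logits = rng.normal(size=(n, K))
logits[np.arange(n), y] += 2.0  # make the true class more likely
prob = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

alpha = np.array([0.1, 0.2, 0.3])
t = calibrate_thresholds(prob, y, alpha)
sets = prediction_sets(prob, t)
missed = ~sets[np.arange(n), y]  # True where the label is outside the set
miscoverage = np.array([missed[y == k].mean() for k in range(K)])
print("thresholds:", t)
print("per-class mis-coverage on the calibration data:", miscoverage)
```

By construction, the per-class mis-coverage measured on the calibration data stays at or below the corresponding entry of `alpha`. Coverage on new data additionally relies on the calibration set being exchangeable with the test set.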
- Parameters:
model (BaseModel) – A trained base model.
alpha (Union[float, np.ndarray]) – Target mis-coverage rate(s).
Examples
>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.predictionset import LABEL
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate the set classifier, with different class-specific mis-coverage rates
>>> cal_model = LABEL(model, [0.15, 0.3, 0.15, 0.15, 0.15])
>>> # Note that we used the test set here because ISRUCDataset has relatively few
>>> # patients, and calibration set should be different from the validation set
>>> # if the latter is used to pick checkpoint. In general, the calibration set
>>> # should be something exchangeable with the test set. Please refer to the paper.
>>> cal_model.calibrate(cal_dataset=test_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer, get_metrics_fn
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
>>> print(get_metrics_fn(cal_model.mode)(
...     y_true_all, y_prob_all, metrics=['accuracy', 'miscoverage_ps'],
...     y_predset=extra_output['y_predset'])
... )
{'accuracy': 0.709843241966832, 'miscoverage_ps': array([0.1499847 , 0.29997638, 0.14993964, 0.14994704, 0.14999252])}
- class pyhealth.calib.predictionset.SCRIB(model, risk, loss_kwargs=None, debug=False, fill_max=True, **kwargs)
Bases: SetPredictor
SCRIB: Set-classifier with Class-specific Risk Bounds
This is a prediction-set constructor for multi-class classification problems. SCRIB tries to control class-specific risk while minimizing the ambiguity. To do this, it selects class-specific thresholds for the predictions on a calibration set.
If risk is a float (say 0.1), SCRIB controls the overall risk: \(\mathbb{P}\{Y \not \in C(X) | |C(X)| = 1\}\leq risk\). If risk is an array (say np.asarray([0.1] * 5)), SCRIB controls the class-specific risks: \(\mathbb{P}\{Y \not \in C(X) | Y=k \land |C(X)| = 1\}\leq risk_k\). Here, \(C(X)\) denotes the final prediction set.
Paper:
Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun. “SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox Models.” AAAI 2022.
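To make the risk definition above concrete, the following NumPy sketch builds prediction sets from a given vector of class-specific thresholds and then measures the class-specific risk among unambiguous (singleton) predictions. It is only an illustration under stated assumptions: the function names are hypothetical, the strict `>` comparison is an assumption, and the thresholds are fixed by hand rather than found by the search SCRIB performs on the calibration set.

```python
import numpy as np

def threshold_sets(prob, thresholds, fill_max=True):
    """Class k is in C(x) when prob[:, k] exceeds thresholds[k] (assumed rule)."""
    sets = prob > thresholds
    if fill_max:
        # Fill empty sets with the class that has the max predicted score.
        rows = np.flatnonzero(~sets.any(axis=1))
        sets[rows, prob[rows].argmax(axis=1)] = True
    return sets

def class_risks(sets, y):
    """Empirical P(Y not in C(X) | Y=k, |C(X)|=1) per class, plus the rejection rate."""
    n, n_classes = sets.shape
    singleton = sets.sum(axis=1) == 1
    wrong = ~sets[np.arange(n), y]
    risks = np.full(n_classes, np.nan)  # NaN when class k has no singleton sets
    for k in range(n_classes):
        mask = singleton & (y == k)
        if mask.any():
            risks[k] = wrong[mask].mean()
    rejection_rate = 1.0 - singleton.mean()  # share of non-singleton sets
    return risks, rejection_rate

# Tiny hand-made example (illustrative only).
prob = np.array([[0.70, 0.20, 0.10],
                 [0.40, 0.35, 0.25],
                 [0.20, 0.50, 0.30],
                 [0.30, 0.30, 0.40]])
y = np.array([0, 1, 1, 2])
thresholds = np.array([0.50, 0.45, 0.45])

risks_fill, rej_fill = class_risks(threshold_sets(prob, thresholds, fill_max=True), y)
risks_nofill, rej_nofill = class_risks(threshold_sets(prob, thresholds, fill_max=False), y)
print(risks_fill, rej_fill)
print(risks_nofill, rej_nofill)
```

In this toy example, filling empty sets with the max-score class removes all rejections but turns one previously rejected sample into an error for class 1, which shows the trade-off between ambiguity and risk that the thresholds have to balance.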
- Parameters:
model (BaseModel) – A trained model.
risk (Union[float, np.ndarray]) – Risk target(s): a float for the overall risk, or an array for class-specific risks.
loss_kwargs (dict, optional) –
Additional loss parameters (including hyperparameters). It could contain the following float/int hyperparameters:
- lk: The coefficient for the loss term associated with the risk violation penalty. The higher the lk, the larger the penalty on risk violations (likely at the cost of higher ambiguity).
- fill_max: Whether to fill the class with the max predicted score when no class exceeds the threshold. In other words, if fill_max is set, the null region is filled with the max-prediction class.
Defaults to {‘lk’: 1e4, ‘fill_max’: False}
fill_max (bool, optional) – Whether to fill the empty prediction set with the max-predicted class. Defaults to True.
Examples
>>> from pyhealth.datasets import ISRUCDataset, split_by_patient, get_dataloader
>>> from pyhealth.models import SparcNet
>>> from pyhealth.tasks import sleep_staging_isruc_fn
>>> from pyhealth.calib.predictionset import SCRIB
>>> from pyhealth.trainer import get_metrics_fn
>>> sleep_ds = ISRUCDataset("/srv/scratch1/data/ISRUC-I").set_task(sleep_staging_isruc_fn)
>>> train_data, val_data, test_data = split_by_patient(sleep_ds, [0.6, 0.2, 0.2])
>>> model = SparcNet(dataset=sleep_ds, feature_keys=["signal"],
...     label_key="label", mode="multiclass")
>>> # ... Train the model here ...
>>> # Calibrate the set classifier, with different class-specific risk targets
>>> cal_model = SCRIB(model, [0.2, 0.3, 0.1, 0.2, 0.1])
>>> # Note that we used the test set here because ISRUCDataset has relatively few
>>> # patients, and calibration set should be different from the validation set
>>> # if the latter is used to pick checkpoint. In general, the calibration set
>>> # should be something exchangeable with the test set. Please refer to the paper.
>>> cal_model.calibrate(cal_dataset=test_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(test_dl, additional_outputs=['y_predset'])
>>> print(get_metrics_fn(cal_model.mode)(
...     y_true_all, y_prob_all, metrics=['accuracy', 'error_ps', 'rejection_rate'],
...     y_predset=extra_output['y_predset'])
... )
{'accuracy': 0.709843241966832, 'rejection_rate': 0.6381305287631919, 'error_ps': array([0.32161874, 0.36654135, 0.11461734, 0.23728814, 0.14993925])}
- class pyhealth.calib.predictionset.FavMac(model, value_weights=1.0, cost_weights=1.0, target_cost=1.0, delta=None, debug=False, **kwargs)
Bases: SetPredictor
Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)
This is a prediction-set constructor for multi-label classification problems. FavMac controls the cost (risk) of the prediction set while keeping its value high.
Value and cost functions are functions of the form \(V(S;Y)\) or \(C(S;Y)\), with \(S\) being the prediction set and \(Y\) being the label. For example, a classical cost function would be “number of false positives”. Denote the target_cost as \(c\). If delta=None, FavMac controls the expected cost in the following sense:
\(\mathbb{E}[C(S_{N+1};Y_{N+1})] \leq c\).
Otherwise, FavMac controls the violation probability in the following sense:
\(\mathbb{P}\{C(S_{N+1};Y_{N+1})>c\}\leq \delta\).
Right now, this FavMac implementation only supports additive value and cost functions (unlike the implementation associated with [1]). That is, the value function is specified by the weights value_weights and the cost function is specified by cost_weights. With \(k\) denoting classes, the cost function is then computed as
\(C(S;Y,w) = \sum_{k} (1-Y_k)S_k w_k\).
Similarly, the value function is computed as
\(V(S;Y,w) = \sum_{k} Y_k S_k w_k\).
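The additive cost and value formulas above translate directly into a few lines of NumPy. The sketch below is illustrative only: `additive_cost` and `additive_value` are hypothetical helper names rather than part of the PyHealth API, and the sets and labels are made-up binary indicator matrices. It also computes the empirical counterparts of the two control targets (mean cost, and the fraction of samples whose cost exceeds \(c\)).

```python
import numpy as np

def additive_cost(S, Y, w=1.0):
    """C(S; Y, w) = sum_k (1 - Y_k) * S_k * w_k: weighted false positives."""
    S, Y, w = np.asarray(S, float), np.asarray(Y, float), np.asarray(w, float)
    return ((1.0 - Y) * S * w).sum(axis=-1)

def additive_value(S, Y, w=1.0):
    """V(S; Y, w) = sum_k Y_k * S_k * w_k: weighted true positives."""
    S, Y, w = np.asarray(S, float), np.asarray(Y, float), np.asarray(w, float)
    return (Y * S * w).sum(axis=-1)

# Two samples, four labels; rows are binary indicators (illustrative only).
S = np.array([[1, 1, 0, 1],   # predicted sets
              [0, 1, 1, 0]])
Y = np.array([[1, 0, 0, 1],   # true labels
              [0, 1, 0, 0]])

costs = additive_cost(S, Y)              # unit weights: number of false positives
values = additive_value(S, Y)            # unit weights: number of true positives
weighted_costs = additive_cost(S, Y, w=[1, 2, 3, 4])

c = 1.0
mean_cost = costs.mean()                 # quantity bounded under expectation control
violation_rate = (costs > c).mean()      # quantity bounded under violation control
print(costs, values, weighted_costs, mean_cost, violation_rate)
```

With unit weights, the cost is simply the count of false positives in each prediction set, which matches the interpretation of target_cost given in the parameter description.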
Papers:
[1] Lin, Zhen, Shubhendu Trivedi, Cao Xiao, and Jimeng Sun. “Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac).” ICML 2023.
[2] Fisch, Adam, Tal Schuster, Tommi Jaakkola, and Regina Barzilay. “Conformal prediction sets with limited false positives.” ICML 2022.
- Parameters:
model (BaseModel) – A trained model.
value_weights (Union[float, np.ndarray]) – weights for the value function. See description above. Defaults to 1.
cost_weights (Union[float, np.ndarray]) – weights for the cost function. See description above. Defaults to 1.
target_cost (float) – Target cost. When cost_weights is set to 1, this is essentially the number of false positives. Defaults to 1.
delta (float) – Violation target (in violation control). Defaults to None (which means expectation control instead of violation control).
Examples
>>> from pyhealth.calib.predictionset import FavMac
>>> from pyhealth.datasets import (MIMIC3Dataset, get_dataloader,split_by_patient)
>>> from pyhealth.models import Transformer
>>> from pyhealth.tasks import drug_recommendation_mimic3_fn
>>> from pyhealth.trainer import get_metrics_fn
>>> base_dataset = MIMIC3Dataset(
...     root="/srv/scratch1/data/physionet.org/files/mimiciii/1.4",
...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
...     code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
...     refresh_cache=False)
>>> sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
>>> train_data, val_data, test_data = split_by_patient(sample_dataset, [0.6, 0.2, 0.2])
>>> model = Transformer(dataset=sample_dataset, feature_keys=["conditions", "procedures"],
...     label_key="drugs", mode="multilabel")
>>> # ... Train the model here ...
>>> # Try to control false positive to <=3
>>> cal_model = FavMac(model, target_cost=3, delta=None)
>>> cal_model.calibrate(cal_dataset=val_data)
>>> # Evaluate
>>> from pyhealth.trainer import Trainer
>>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
>>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
...     test_dl, additional_outputs=["y_predset"])
>>> print(get_metrics_fn(cal_model.mode)(
...     y_true_all, y_prob_all, metrics=['tp', 'fp'],
...     y_predset=extra_output["y_predset"])) # We get FP~=3
{'tp': 0.5049893086243763, 'fp': 2.8442622950819674}