"""
Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)
Implementation based on https://github.com/zlin7/FavMac
"""
from importlib import reload
import time
from typing import Dict, Union
import numpy as np
import pandas as pd
import torch
from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.predictionset.favmac.core import FavMac_GreedyRatio
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.models import BaseModel
__all__ = ["FavMac"]
INTEGER_SAFE_DELTA = 0.1
class AdditiveSetFunction:
    """An additive set function f(S) = sum_k S_k * w_k over binary set vectors.

    Depending on ``mode``, the same weights are interpreted as a utility
    (``'util'``), a cost (``'cost'``), or a cost proxy (``'proxy'``) function.
    ``S``, ``Y`` and ``pred`` are all arrays over the K classes: ``S`` is the
    binary prediction-set indicator, ``Y`` the binary label vector, and
    ``pred`` the per-class probability estimate.
    """

    def __init__(self, values: Union[float, np.ndarray, int], mode=None, name='unknown') -> None:
        """
        Args:
            values: per-class weights w_k (scalar broadcasts to all classes).
            mode: one of {'util', 'cost', 'proxy'} or None.
            name: optional identifier for debugging/display.
        """
        self.name = name
        self.values = values
        assert mode is None or mode in {'util', 'cost', 'proxy'}
        self.mode = mode

    def is_additive(self):
        """This implementation is additive by construction."""
        return True

    def __call__(self, S: np.ndarray, Y: np.ndarray = None, pred: np.ndarray = None,
                 sample=100, target_cost=None) -> float:
        """Dispatch to the cost/proxy/util evaluation according to ``self.mode``."""
        if self.mode == 'cost':
            # True cost needs the label, never the prediction.
            assert pred is None
            return self.cost_call(S, Y)
        if self.mode == 'proxy':
            # Proxy cost is label-free: it only uses the model prediction.
            assert Y is None
            return self.proxy_call(S, pred, target_cost=target_cost)
        assert self.mode == 'util'
        return self.util_call(S, Y, pred, sample=sample)

    def naive_call(self, S: np.ndarray) -> float:
        """Plain weighted sum: sum_k S_k * w_k."""
        return np.sum(S * self.values)

    def util_call(self, S: np.ndarray, Y: np.ndarray = None, pred: np.ndarray = None,
                  sample=1000) -> float:
        """Utility V(S;Y) = sum_k S_k Y_k w_k, or its expectation under ``pred``.

        Exactly one of ``Y`` (realized utility) or ``pred`` (expected utility)
        may be given; with neither, this degenerates to the unweighted set value.
        """
        assert Y is None or pred is None
        if pred is not None:
            # E[V(S;Y)] = sum_k S_k p_k w_k -- valid because of additivity.
            return self.naive_call(S * pred)
        if Y is not None:
            return self.naive_call(S * Y)
        return self.naive_call(S)

    def cost_call(self, S: np.ndarray, Y: np.ndarray) -> float:
        """Cost C(S;Y) = sum_k S_k (1-Y_k) w_k (e.g. weighted false positives)."""
        return self.naive_call(S * (1 - Y))

    def proxy_call(self, S: np.ndarray, pred: np.ndarray, target_cost: float = None) -> float:
        """Expected cost under ``pred``: sum_k S_k (1-p_k) w_k.

        ``target_cost`` is accepted for interface compatibility but unused in
        the additive case.
        """
        return self.naive_call(S * (1 - pred))

    def greedy_maximize(self, S: np.ndarray, pred: np.ndarray = None,
                        d_proxy: np.ndarray = None, prev_util_and_proxy=None):
        """Pick the single best class to add to ``S`` by utility/proxy-cost ratio.

        Returns ``(k, objective[k])`` for the chosen class, or ``None`` when
        ``S`` already contains every class. ``prev_util_and_proxy`` is accepted
        for interface compatibility (it would hold (prev_util, prev_proxy) for
        non-additive functions) but is not needed here.
        """
        assert self.mode == 'util', "This is only used for util function"
        if (1 - S).sum() == 0:
            return None
        d_util = self.values
        if pred is not None:
            d_util = d_util * pred
        # Clip the proxy-cost increments away from 0 to avoid division blow-up.
        objective = d_util / (1 if d_proxy is None else d_proxy.clip(1e-8))
        # Restrict to classes not yet in S; dropna guards against NaN ratios.
        k = pd.Series((1 - S) * objective).dropna().idxmax()
        return k, objective[k]

    def greedy_maximize_seq(self, pred: np.ndarray = None, d_proxy: np.ndarray = None):
        """Precompute the full nested sequence of greedy sets.

        Because both utility and cost are additive, the marginal gain of each
        class is fixed, so the greedy order is a single sort by the
        utility/proxy-cost ratio. Returns ``(Ss, ks)`` where ``Ss[i]`` is the
        binary set after ``i`` additions (``Ss[0]`` is empty) and ``ks`` is the
        class order.
        """
        assert self.mode == 'util', "This is only used for util function"
        d_util = self.values
        if pred is not None:
            d_util = d_util * pred
        objective = d_util / (1 if d_proxy is None else d_proxy.clip(1e-8))
        assert np.isnan(objective).sum() == 0
        ks = np.argsort(-objective)
        Ss = [np.zeros(len(objective), dtype=int)]
        for k in ks:
            Ss.append(Ss[-1].copy())
            Ss[-1][k] = 1
        return Ss, ks
class FavMac(SetPredictor):
    """Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)

    This is a prediction-set constructor for multi-label classification problems.
    FavMac could control the cost/risk while realizing high value on the prediction set.

    Value and cost functions are functions in the form of :math:`V(S;Y)` or :math:`C(S;Y)`,
    with S being the prediction set and Y being the label.
    For example, a classical cost function would be "number of false positives".
    Denote the ``target_cost`` as :math:`c`,
    if ``delta=None``, FavMac controls the expected cost in the following sense:

    :math:`\\mathbb{E}[C(S_{N+1};Y_{N+1})] \\leq c`.

    Otherwise, FavMac controls the violation probability in the following sense:

    :math:`\\mathbb{P}\\{C(S_{N+1};Y_{N+1})>c\\}\\leq delta`.

    Right now, this FavMac implementation only supports additive value and cost functions
    (unlike the implementation associated with [1]).
    That is, the value function is specified by the weights ``value_weights`` and the cost
    function is specified by ``cost_weights``.
    With :math:`k` denoting classes, the cost function is then computed as

    :math:`C(S;Y,w) = \\sum_{k} (1-Y_k)S_k w_k`

    Similarly, the value function is computed as

    :math:`V(S;Y,w) = \\sum_{k} Y_k S_k w_k`.

    Papers:

        [1] Lin, Zhen, Shubhendu Trivedi, Cao Xiao, and Jimeng Sun.
        "Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)."
        ICML 2023.

        [2] Fisch, Adam, Tal Schuster, Tommi Jaakkola, and Regina Barzilay.
        "Conformal prediction sets with limited false positives."
        ICML 2022.

    Args:
        model (BaseModel): A trained model.
        value_weights (Union[float, np.ndarray]):
            weights for the value function. See description above.
            Defaults to 1.
        cost_weights (Union[float, np.ndarray]):
            weights for the cost function. See description above.
            Defaults to 1.
        target_cost (float): Target cost.
            When cost_weights is set to 1, this is essentially the number of false positives.
            Defaults to 1.
        delta (float): Violation target (in violation control).
            Defaults to None (which means expectation control instead of violation control).

    Examples:
        >>> from pyhealth.calib.predictionset import FavMac
        >>> from pyhealth.datasets import (MIMIC3Dataset, get_dataloader,split_by_patient)
        >>> from pyhealth.models import Transformer
        >>> from pyhealth.tasks import drug_recommendation_mimic3_fn
        >>> from pyhealth.trainer import get_metrics_fn
        >>> base_dataset = MIMIC3Dataset(
        ...     root="/srv/scratch1/data/physionet.org/files/mimiciii/1.4",
        ...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        ...     code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        ...     refresh_cache=False)
        >>> sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
        >>> train_data, val_data, test_data = split_by_patient(sample_dataset, [0.6, 0.2, 0.2])
        >>> model = Transformer(dataset=sample_dataset, feature_keys=["conditions", "procedures"],
        ...     label_key="drugs", mode="multilabel")
        >>> # ... Train the model here ...
        >>> # Try to control false positive to <=3
        >>> cal_model = FavMac(model, target_cost=3, delta=None)
        >>> cal_model.calibrate(cal_dataset=val_data)
        >>> # Evaluate
        >>> from pyhealth.trainer import Trainer
        >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
        >>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
        ... test_dl, additional_outputs=["y_predset"])
        >>> print(get_metrics_fn(cal_model.mode)(
        ... y_true_all, y_prob_all, metrics=['tp', 'fp'],
        ... y_predset=extra_output["y_predset"])) # We get FP~=3
        {'tp': 0.5049893086243763, 'fp': 2.8442622950819674}
    """

    def __init__(
        self,
        model: BaseModel,
        value_weights: Union[float, np.ndarray] = 1.,
        cost_weights: Union[float, np.ndarray] = 1.,
        target_cost: float = 1., delta: float = None,
        debug=False,
        **kwargs,
    ) -> None:
        super().__init__(model, **kwargs)
        # FavMac is defined for multi-label prediction sets only.
        if model.mode != "multilabel":
            raise NotImplementedError()
        self.mode = self.model.mode  # multilabel
        # Freeze the base model: calibration never updates its weights.
        for param in model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.device = model.device
        self.debug = debug

        self._cost_weights = cost_weights
        self._value_weights = value_weights
        self.target_cost = target_cost
        self.delta = delta

    def calibrate(self, cal_dataset):
        """Calibrate the cost-control procedure.

        :param cal_dataset: Calibration set.
        :type cal_dataset: Subset
        """
        _cal_data = prepare_numpy_dataset(
            self.model, cal_dataset, ["logit", "y_true"], debug=self.debug
        )

        # Normalize weights so the maximum achievable cost C_max is 1; the
        # target cost is rescaled accordingly before being handed to FavMac.
        if isinstance(self._cost_weights, np.ndarray):
            C_max = self._cost_weights.sum()
        else:
            C_max = _cal_data["logit"].shape[1] * self._cost_weights
        self._favmac = FavMac_GreedyRatio(
            cost_fn=AdditiveSetFunction(self._cost_weights / C_max, mode='cost'),
            util_fn=AdditiveSetFunction(self._value_weights, mode='util'),
            # With additive costs, the natural proxy is the expected cost.
            proxy_fn=AdditiveSetFunction(self._cost_weights / C_max, mode='proxy'),
            target_cost=self.target_cost / C_max, delta=self.delta, C_max=1.,
        )
        self._favmac.init_calibrate(_cal_data["logit"], _cal_data["y_true"])

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation (just like the original model).

        :return: A dictionary with all results from the base model, with the following updates:

            y_predset: a bool tensor representing the prediction for each class.
        :rtype: Dict[str, torch.Tensor]
        """
        ret = self.model(**kwargs)
        _logit = ret["logit"].cpu().numpy()
        # FavMac constructs one prediction set per sample from the logits.
        y_predset = np.asarray([self._favmac(_)[0] for _ in _logit])
        ret["y_predset"] = torch.tensor(y_predset)
        return ret
if __name__ == "__main__":
    # Smoke-test / usage demo: train-free pipeline mirroring the class docstring.
    from pyhealth.calib.predictionset import FavMac
    from pyhealth.datasets import (MIMIC3Dataset, get_dataloader,
                                   split_by_patient)
    from pyhealth.models import Transformer
    from pyhealth.tasks import drug_recommendation_mimic3_fn
    from pyhealth.trainer import get_metrics_fn

    base_dataset = MIMIC3Dataset(
        root="/srv/scratch1/data/physionet.org/files/mimiciii/1.4",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        refresh_cache=False,
    )
    sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
    train_data, val_data, test_data = split_by_patient(sample_dataset, [0.6, 0.2, 0.2])
    model = Transformer(dataset=sample_dataset, feature_keys=["conditions", "procedures"],
                        label_key="drugs", mode="multilabel")
    # ... Train the model here ...

    # calibrate the prediction sets with FavMac
    cal_model = FavMac(model, cost_weights=1., target_cost=3, delta=None)
    cal_model.calibrate(cal_dataset=val_data)

    # Evaluate
    from pyhealth.trainer import Trainer
    test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
    y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
        test_dl, additional_outputs=["y_predset"]
    )
    print(
        get_metrics_fn(cal_model.mode)(
            y_true_all, y_prob_all, metrics=['tp', 'fp'],
            y_predset=extra_output["y_predset"],
        )
    )