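"""Source code for ``pyhealth.calib.predictionset.favmac``.

Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)

Implementation based on Lin, Trivedi, Xiao and Sun, "Fast Online Value-Maximizing
Prediction Sets with Conformal Cost Control (FavMac)", ICML 2023.
"""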


from importlib import reload
import time
from typing import Dict, Union

import numpy as np
import pandas as pd
import torch

from pyhealth.calib.base_classes import SetPredictor
from pyhealth.calib.predictionset.favmac.core import FavMac_GreedyRatio
from pyhealth.calib.utils import prepare_numpy_dataset
from pyhealth.models import BaseModel

# Only the FavMac set predictor is exported by ``from ... import *``.
__all__ = [
    "FavMac",
]


class AdditiveSetFunction:
    """A set function that is additive over classes: ``f(S) = sum_k S_k * w_k``.

    Depending on ``mode`` the same weights are used as a utility (value),
    a cost, or a cost proxy:

    * ``'util'``:  ``sum_k S_k * Y_k * w_k`` (or ``S_k * pred_k * w_k`` when
      predicted probabilities are given instead of labels).
    * ``'cost'``:  ``sum_k S_k * (1 - Y_k) * w_k``.
    * ``'proxy'``: ``sum_k S_k * (1 - pred_k) * w_k``.

    Args:
        values: per-class weights ``w`` (a scalar applies the same weight to
            every class).
        mode: one of ``None``, ``'util'``, ``'cost'`` or ``'proxy'``.
        name: free-form label for the function.
    """

    def __init__(self, values: Union[float, np.ndarray, int], mode=None, name='unknown') -> None:
        self.name = name
        self.values = values
        assert mode is None or mode in {'util', 'cost', 'proxy'}
        self.mode = mode

    def is_additive(self):
        """Return True: this set function decomposes over classes."""
        return True

    def __call__(self, S: np.ndarray, Y: np.ndarray = None, pred: np.ndarray = None, sample=100, target_cost=None) -> float:
        """Evaluate the function on set indicator ``S`` according to ``self.mode``.

        ``'cost'`` mode requires labels ``Y`` (and no ``pred``), ``'proxy'``
        mode requires ``pred`` (and no ``Y``), and ``'util'`` mode accepts at
        most one of them.
        """
        if self.mode == 'cost':
            assert pred is None
            return self.cost_call(S, Y)
        if self.mode == 'proxy':
            assert Y is None
            return self.proxy_call(S, pred, target_cost=target_cost)
        assert self.mode == 'util'
        return self.util_call(S, Y, pred, sample=sample)

    def naive_call(self, S: np.ndarray) -> float:
        """Return ``sum(S * values)``."""
        return np.sum(S * self.values)

    def util_call(self, S: np.ndarray, Y: np.ndarray = None, pred: np.ndarray = None, sample=1000) -> float:
        """Utility of ``S`` given labels ``Y`` or probabilities ``pred`` (not both).

        ``sample`` is accepted for interface compatibility and is unused.
        """
        assert Y is None or pred is None
        if pred is not None:
            # With an additive utility, weighting by the probabilities gives
            # the expected utility directly (linearity of expectation).
            return self.naive_call(S * pred)
        if Y is not None:
            return self.naive_call(S * Y)
        return self.naive_call(S)

    def cost_call(self, S: np.ndarray, Y: np.ndarray) -> float:
        """Weighted count of false positives: ``sum(S * (1 - Y) * values)``."""
        return self.naive_call(S * (1 - Y))

    def proxy_call(self, S: np.ndarray, pred: np.ndarray, target_cost: float = None) -> float:
        """Expected cost proxy: ``sum(S * (1 - pred) * values)``; ``target_cost`` is unused."""
        return self.naive_call(S * (1 - pred))

    def greedy_maximize(self, S: np.ndarray, pred: np.ndarray = None, d_proxy: np.ndarray = None, prev_util_and_proxy=None):
        """Pick the next class to add to ``S`` greedily.

        The score of class ``k`` is its marginal utility (``values[k]``,
        times ``pred[k]`` if given) divided by its marginal proxy cost
        ``d_proxy[k]`` (clipped below at 1e-8). Only classes not yet in ``S``
        are considered, and NaN scores are ignored.

        Returns:
            ``(k, score_k)`` for the best candidate, or ``None`` when every
            class is already in ``S`` (or all candidate scores are NaN).
            ``prev_util_and_proxy`` is accepted for interface compatibility
            and is unused.
        """
        assert self.mode == 'util', "This is only used for util function"
        S = np.asarray(S)
        candidates = np.flatnonzero(S == 0)
        if candidates.size == 0:
            return None
        d_util = self.values
        if pred is not None:
            d_util = d_util * pred

        objective = d_util / (1 if d_proxy is None else d_proxy.clip(1e-8))
        # Broadcast so a scalar weight with no pred/d_proxy can still be indexed.
        objective = np.broadcast_to(np.asarray(objective, dtype=float), S.shape)
        # Restrict to unselected classes explicitly; masking with (1 - S) would
        # score selected classes as 0 and could re-pick them when the
        # remaining scores are negative.
        scores = pd.Series(objective[candidates], index=candidates).dropna()
        if scores.empty:
            return None
        k = scores.idxmax()
        return k, objective[k]

    def greedy_maximize_seq(self, pred: np.ndarray = None, d_proxy: np.ndarray = None):
        """Build the whole greedy sequence of nested prediction sets.

        Classes are ranked by ``values * pred / d_proxy`` (descending). With an
        additive cost the proxy marginal ``d_proxy`` does not depend on the
        current set, so one sort yields the entire greedy path.

        Returns:
            ``(Ss, ks)`` where ``ks`` is the class order and ``Ss`` is a list of
            ``len(ks) + 1`` integer indicator arrays: ``Ss[0]`` is the empty set
            and ``Ss[i]`` contains exactly the first ``i`` classes of ``ks``.
        """
        assert self.mode == 'util', "This is only used for util function"
        d_util = self.values
        if pred is not None:
            d_util = d_util * pred

        objective = d_util / (1 if d_proxy is None else d_proxy.clip(1e-8))

        assert np.isnan(objective).sum() == 0
        ks = np.argsort(-objective)
        Ss = [np.zeros(len(objective), dtype=int)]
        for k in ks:
            # Copy before adding k so every prefix set is kept, not aliased.
            nxt = Ss[-1].copy()
            nxt[k] = 1
            Ss.append(nxt)
        return Ss, ks

class FavMac(SetPredictor):
    """Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)

    This is a prediction-set constructor for multi-label classification problems.
    FavMac controls the cost/risk while realizing high value on the prediction set.

    Value and cost functions are functions in the form of :math:`V(S;Y)` or :math:`C(S;Y)`,
    with S being the prediction set and Y being the label.
    For example, a classical cost function would be "number of false positives".
    Denote the ``target_cost`` as :math:`c`. If ``delta=None``, FavMac controls the
    expected cost in the following sense:

    :math:`\\mathbb{E}[C(S_{N+1};Y_{N+1})] \\leq c`.

    Otherwise, FavMac controls the violation probability in the following sense:

    :math:`\\mathbb{P}\\{C(S_{N+1};Y_{N+1})>c\\}\\leq \\delta`.

    Right now, this FavMac implementation only supports additive value and cost
    functions (unlike the implementation associated with [1]).
    That is, the value function is specified by the weights ``value_weights`` and
    the cost function is specified by ``cost_weights``.
    With :math:`k` denoting classes, the cost function is then computed as

    :math:`C(S;Y,w) = \\sum_{k} (1-Y_k)S_k w_k`

    Similarly, the value function is computed as

    :math:`V(S;Y,w) = \\sum_{k} Y_k S_k w_k`.

    Papers:

        [1] Lin, Zhen, Shubhendu Trivedi, Cao Xiao, and Jimeng Sun.
        "Fast Online Value-Maximizing Prediction Sets with Conformal Cost Control (FavMac)."
        ICML 2023.

        [2] Fisch, Adam, Tal Schuster, Tommi Jaakkola, and Regina Barzilay.
        "Conformal prediction sets with limited false positives."
        ICML 2022.

    Args:
        model (BaseModel): A trained model.
        value_weights (Union[float, np.ndarray]): weights for the value function.
            See description above. Defaults to 1.
        cost_weights (Union[float, np.ndarray]): weights for the cost function.
            See description above. Defaults to 1.
        target_cost (float): Target cost. When cost_weights is set to 1, this is
            essentially the number of false positives. Defaults to 1.
        delta (float): Violation target (in violation control). Defaults to None
            (which means expectation control instead of violation control).

    Examples:
        >>> from pyhealth.calib.predictionset import FavMac
        >>> from pyhealth.datasets import (MIMIC3Dataset, get_dataloader,split_by_patient)
        >>> from pyhealth.models import Transformer
        >>> from pyhealth.tasks import drug_recommendation_mimic3_fn
        >>> from pyhealth.trainer import get_metrics_fn
        >>> base_dataset = MIMIC3Dataset(
        ...     root="/srv/scratch1/data/",
        ...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        ...     code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        ...     refresh_cache=False)
        >>> sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
        >>> train_data, val_data, test_data = split_by_patient(sample_dataset, [0.6, 0.2, 0.2])
        >>> model = Transformer(dataset=sample_dataset, feature_keys=["conditions", "procedures"],
        ...             label_key="drugs", mode="multilabel")
        >>> # ... Train the model here ...
        >>> # Try to control false positive to <=3
        >>> cal_model = FavMac(model, target_cost=3, delta=None)
        >>> cal_model.calibrate(cal_dataset=val_data)
        >>> # Evaluate
        >>> from pyhealth.trainer import Trainer
        >>> test_dl = get_dataloader(test_data, batch_size=32, shuffle=False)
        >>> y_true_all, y_prob_all, _, extra_output = Trainer(model=cal_model).inference(
        ...             test_dl, additional_outputs=["y_predset"])
        >>> print(get_metrics_fn(cal_model.mode)(
        ...     y_true_all, y_prob_all, metrics=['tp', 'fp'],
        ...     y_predset=extra_output["y_predset"])) # We get FP~=3
        {'tp': 0.5049893086243763, 'fp': 2.8442622950819674}
    """

    def __init__(
        self,
        model: BaseModel,
        value_weights: Union[float, np.ndarray] = 1.,
        cost_weights: Union[float, np.ndarray] = 1.,
        target_cost: float = 1.,
        delta: float = None,
        debug=False,
        **kwargs,
    ) -> None:
        super().__init__(model, **kwargs)
        if model.mode != "multilabel":
            raise NotImplementedError()
        self.mode = self.model.mode  # multilabel
        # The base model is frozen; only the set-construction step is calibrated.
        for param in model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.device = model.device
        self.debug = debug

        self._cost_weights = cost_weights
        self._value_weights = value_weights
        # Keep target_cost and delta separate; a chained assignment here would
        # overwrite the target cost with delta.
        self.target_cost = target_cost
        self.delta = delta

    def calibrate(self, cal_dataset):
        """Calibrate the cost-control procedure.

        :param cal_dataset: Calibration set.
        :type cal_dataset: Subset
        """
        _cal_data = prepare_numpy_dataset(
            self.model, cal_dataset, ["logit", "y_true"], debug=self.debug
        )
        # C_max is the largest achievable additive cost (every class predicted,
        # none present). Costs and the target are rescaled so that C_max == 1.
        if isinstance(self._cost_weights, np.ndarray):
            C_max = self._cost_weights.sum()
        else:
            C_max = _cal_data["logit"].shape[1] * self._cost_weights
        self._favmac = FavMac_GreedyRatio(
            cost_fn=AdditiveSetFunction(self._cost_weights / C_max, mode='cost'),
            util_fn=AdditiveSetFunction(self._value_weights, mode='util'),
            proxy_fn=AdditiveSetFunction(self._cost_weights / C_max, mode='proxy'),
            target_cost=self.target_cost / C_max,
            delta=self.delta,
            C_max=1.,
        )
        self._favmac.init_calibrate(_cal_data["logit"], _cal_data["y_true"])

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation (just like the original model).

        Requires :meth:`calibrate` to have been called first.

        :return: A dictionary with all results from the base model, with the following updates:

                y_predset: a bool tensor representing the prediction for each class.
        :rtype: Dict[str, torch.Tensor]
        """
        ret = self.model(**kwargs)
        _logit = ret["logit"].cpu().numpy()
        # The calibrated procedure is applied one sample at a time; element 0
        # of its return value is taken as the prediction set.
        y_predset = np.asarray([self._favmac(logit_row)[0] for logit_row in _logit])
        ret["y_predset"] = torch.tensor(y_predset)
        return ret
if __name__ == "__main__":
    # Example: control the expected number of false positives with FavMac.
    from pyhealth.calib.predictionset import FavMac
    from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient
    from pyhealth.models import Transformer
    from pyhealth.tasks import drug_recommendation_mimic3_fn
    from pyhealth.trainer import Trainer, get_metrics_fn

    # MIMIC-III with NDC drug codes mapped to level-3 ATC codes.
    mimic3 = MIMIC3Dataset(
        root="/srv/scratch1/data/",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        refresh_cache=False,
    )
    samples = mimic3.set_task(drug_recommendation_mimic3_fn)
    train_split, val_split, test_split = split_by_patient(samples, [0.6, 0.2, 0.2])

    transformer = Transformer(
        dataset=samples,
        feature_keys=["conditions", "procedures"],
        label_key="drugs",
        mode="multilabel",
    )
    # ... Train the model here ...

    # Calibrate the prediction sets on the validation split.
    favmac_model = FavMac(transformer, cost_weights=1., target_cost=3, delta=None)
    favmac_model.calibrate(cal_dataset=val_split)

    # Evaluate on the held-out test split.
    test_loader = get_dataloader(test_split, batch_size=32, shuffle=False)
    inference_out = Trainer(model=favmac_model).inference(
        test_loader, additional_outputs=["y_predset"]
    )
    y_true_all, y_prob_all, _, extra_output = inference_out
    metric_fn = get_metrics_fn(favmac_model.mode)
    scores = metric_fn(
        y_true_all,
        y_prob_all,
        metrics=['tp', 'fp'],
        y_predset=extra_output["y_predset"],
    )
    print(scores)