"""
Author: Kobe Guo
NetID: kobeg2
Paper: CaliForest: Calibrated Random Forests for Healthcare Prediction
Link: https://joyceho.github.io/assets/pdf/paper/park-chil20.pdf
Description:
Implementation of CaliForest, a calibrated random forest model that applies
post-hoc calibration (isotonic or logistic) to improve probability estimates
for healthcare prediction tasks.
"""
from __future__ import annotations
from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
[docs]class CaliForest(BaseModel):
"""CaliForest model for calibrated probability prediction.
This model wraps a RandomForestClassifier and applies a post-hoc
calibration step using out-of-bag (OOB) predictions and prediction
variance to improve probability estimates.
Important:
CaliForest is fit once on the full training set using fit(train_loader).
After fitting, forward() should be used only for inference/evaluation.
This implementation currently supports binary classification only.
The overall procedure is:
1. train a random forest classifier,
2. compute OOB probabilities for each training sample,
3. estimate prediction uncertainty using variance across tree outputs,
4. fit a calibration model using uncertainty-weighted samples.
Args:
dataset: the dataset used to initialize feature and label schemas.
n_estimators: number of trees in the random forest. Default is 100.
max_depth: maximum depth of each tree. Default is None.
calibration: calibration method. Supported values are ``"isotonic"``
and ``"logistic"``. Default is ``"isotonic"``.
random_state: random seed for reproducibility. Default is 42.
**kwargs: additional compatibility arguments.
Example:
model = CaliForest(dataset=dataset, n_estimators=10)
model.fit(train_loader)
ret = model(**batch)
print(ret["y_prob"].shape)
"""
def __init__(
self,
dataset: SampleDataset,
n_estimators: int = 100,
max_depth: Optional[int] = None,
calibration: str = "isotonic",
random_state: int = 42,
**kwargs,
):
super(CaliForest, self).__init__(dataset)
assert len(self.label_keys) == 1, "Only one label key is supported"
self.label_key = self.label_keys[0]
self.n_estimators = n_estimators
self.max_depth = max_depth
self.calibration = calibration
self.random_state = random_state
if self.calibration not in {"isotonic", "logistic"}:
raise ValueError(f"Unsupported calibration: {self.calibration}")
self.rf = RandomForestClassifier(
n_estimators=self.n_estimators,
max_depth=self.max_depth,
bootstrap=True,
oob_score=True,
random_state=self.random_state,
)
self.calibrator = None
self.is_fitted = False
def _build_feature_matrix(self, **kwargs) -> np.ndarray:
"""Convert PyHealth batch into NumPy feature matrix."""
features: List[np.ndarray] = []
for key in self.feature_keys:
x = kwargs[key]
if isinstance(x, torch.Tensor):
arr = x.detach().cpu().numpy()
else:
arr = np.asarray(x)
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
elif arr.ndim > 2:
arr = arr.reshape(arr.shape[0], -1)
features.append(arr.astype(np.float32))
return np.concatenate(features, axis=1)
def _build_labels(self, **kwargs) -> np.ndarray:
y = kwargs[self.label_key]
if isinstance(y, torch.Tensor):
y = y.detach().cpu().numpy()
else:
y = np.asarray(y)
return y.reshape(-1)
[docs] def fit(self, train_loader):
"""Fit CaliForest on the full training dataloader"""
X_list = []
y_list = []
for batch in train_loader:
X_list.append(self._build_feature_matrix(**batch))
y_list.append(self._build_labels(**batch))
X = np.concatenate(X_list, axis=0)
y = np.concatenate(y_list, axis=0)
self.fit_model(features=X, labels=y)
return self
[docs] def fit_model(self, **kwargs) -> None:
"""Fit RF + calibration model."""
if "features" in kwargs and "labels" in kwargs:
X = kwargs["features"]
y = kwargs["labels"]
else:
X = self._build_feature_matrix(**kwargs)
y = self._build_labels(**kwargs)
unique_labels = np.unique(y)
if set(unique_labels.tolist()) != {0, 1}:
raise ValueError(
"CaliForest currently supports binary classification only. "
f"Got labels: {unique_labels.tolist()}"
)
self.rf.fit(X, y)
if not hasattr(self.rf, "oob_decision_function_"):
raise RuntimeError("OOB predictions not available.")
oob_probs = self.rf.oob_decision_function_[:, 1]
tree_probs = np.stack(
[t.predict_proba(X)[:, 1] for t in self.rf.estimators_],
axis=0,
)
variances = np.var(tree_probs, axis=0)
# CaliForest uses inverse tree-level variance so more stable
# predictions have greater influence during calibrator fitting.
weights = 1.0 / (variances + 1e-6)
if self.calibration == "isotonic":
calibrator = IsotonicRegression(out_of_bounds="clip")
calibrator.fit(oob_probs, y, sample_weight=weights)
self.calibrator = calibrator
else:
calibrator = LogisticRegression()
calibrator.fit(
oob_probs.reshape(-1, 1),
y,
sample_weight=weights,
)
self.calibrator = calibrator
self.is_fitted = True
[docs] def predict_proba_numpy(self, **kwargs) -> np.ndarray:
"""Predict calibrated probabilities."""
if not self.is_fitted:
raise RuntimeError("Model must be fitted first.")
X = self._build_feature_matrix(**kwargs)
rf_probs = self.rf.predict_proba(X)[:, 1]
if self.calibration == "isotonic":
calibrated = self.calibrator.predict(rf_probs)
else:
calibrated = self.calibrator.predict_proba(
rf_probs.reshape(-1, 1)
)[:, 1]
return calibrated.reshape(-1, 1)
[docs] def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""PyHealth forward pass."""
if not self.is_fitted:
raise RuntimeError(
"CaliForest must be fitted before inference. "
"Call model.fit(train_loader) first."
)
y_prob_np = self.predict_proba_numpy(**kwargs)
y_prob = torch.tensor(
y_prob_np, dtype=torch.float32, device=self.device
)
eps = 1e-6
logits = torch.log(
torch.clamp(y_prob, eps, 1 - eps) /
torch.clamp(1 - y_prob, eps, 1 - eps)
)
y_true = kwargs[self.label_key].to(self.device)
loss = self.get_loss_function()(logits, y_true)
return {
"loss": loss,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
}