Trainer
Trainer handles the PyTorch training loop for you.
Rather than writing your own epoch loop, backward pass, optimizer step,
and metric evaluation, you hand the Trainer your model and data loaders and
let it manage the details, including early stopping when validation
performance plateaus and automatic reloading of the best checkpoint at the end.
A Typical Training Run
Here is what a full training setup looks like. The data loaders come from
get_dataloader() in pyhealth.datasets, which knows how to work with
PyHealth’s LitData caching format:
from pyhealth.trainer import Trainer
from pyhealth.datasets import get_dataloader
# model, train_ds, val_ds and test_ds are assumed to be defined already
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
trainer = Trainer(
    model=model,
    metrics=["roc_auc_macro", "pr_auc_macro", "f1_macro"],
    device="cuda",
)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    epochs=50,
    monitor="roc_auc_macro",
    monitor_criterion="max",
    patience=10,
)
scores = trainer.evaluate(test_loader)
# e.g. {'roc_auc_macro': 0.85, 'pr_auc_macro': 0.79, 'f1_macro': 0.72, 'loss': 0.31}
Setting Up the Trainer
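Trainer(model, checkpoint_path=None, metrics=None, device=None, enable_logging=True, output_path=None, exp_name=None)
model: your instantiated PyHealth model.
checkpoint_path: path to a saved checkpoint to load into the model; by default (None) no checkpoint is loaded and the model keeps its random initialization.
metrics: the metric names you want computed at validation and test time (e.g. ["roc_auc_macro", "f1_macro"]). See Metrics for the full list of supported strings. If omitted, the default metrics of the relevant metrics function are used.
device: "cuda" or "cpu"; by default (None) the Trainer uses a GPU if one is available and otherwise falls back to the CPU.
enable_logging: when True (the default), the Trainer creates an experiment folder under output_path containing a log.txt and model checkpoints.
output_path / exp_name: where the output folder is created and what it is called. They default to "./output" and the current date and time, respectively.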
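For orientation, here is a sketch of how an experiment folder path could be assembled from output_path and exp_name using the documented defaults ("./output" and the current date and time). The helper make_exp_dir is hypothetical and the timestamp format is an assumption; the Trainer's own folder naming may differ.

```python
import os
from datetime import datetime
from typing import Optional


def make_exp_dir(output_path: Optional[str] = None,
                 exp_name: Optional[str] = None,
                 create: bool = False) -> str:
    """Hypothetical helper: build the path of an experiment folder.

    Mirrors the documented defaults (output_path="./output",
    exp_name=current datetime). Nothing is written to disk unless
    create=True.
    """
    if output_path is None:
        output_path = "./output"
    if exp_name is None:
        # Assumed timestamp format; the real Trainer may format it differently.
        exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_dir = os.path.join(output_path, exp_name)
    if create:
        os.makedirs(exp_dir, exist_ok=True)
    return exp_dir


print(make_exp_dir(exp_name="mortality_run"))  # ./output/mortality_run (on POSIX)
```

Passing an explicit exp_name makes the folder name predictable across runs, which is convenient when results are collected by a script afterwards.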
Controlling the Training Loop
trainer.train() accepts these key arguments beyond the data loaders:
epochs: the maximum number of training epochs (default 5).
optimizer_class / optimizer_params: which optimizer to use and how to configure it. Defaults to torch.optim.Adam with a learning rate of 1e-3 (optimizer_params={"lr": 1e-3}).
weight_decay: L2 regularisation strength. Default 0.0.
max_grad_norm: if set, clips gradients to this norm before each optimizer step, which can help stabilise training on noisy medical data. Default None (no clipping).
monitor / monitor_criterion: the metric to watch on the validation set (e.g. "roc_auc_macro") and whether higher is better ("max", the default) or lower is better ("min"). The Trainer saves a checkpoint whenever this metric improves.
patience: how many epochs without improvement in the monitored metric to wait before stopping early. The default, None, disables early stopping.
load_best_model_at_last: when True (the default), the Trainer restores the best checkpoint at the end of training rather than keeping the weights from the final epoch.
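The interplay of monitor_criterion, patience and load_best_model_at_last can be replayed on a list of per-epoch validation scores. This is a simplified, hypothetical sketch rather than the Trainer's source: it assumes that only a strict improvement counts, that training stops once the count of epochs without improvement reaches patience, and the scores are made-up numbers.

```python
from typing import List, Optional, Tuple


def simulate_early_stopping(
    val_scores: List[float],
    monitor_criterion: str = "max",
    patience: Optional[int] = None,
) -> Tuple[int, int]:
    """Replay per-epoch validation scores (hypothetical helper).

    Returns (best_epoch, last_epoch_run), both 0-based.
    """
    if monitor_criterion not in ("max", "min"):
        raise ValueError("monitor_criterion must be 'max' or 'min'")
    if patience is not None and patience < 1:
        raise ValueError("patience must be a positive integer or None")
    best = float("-inf") if monitor_criterion == "max" else float("inf")
    best_epoch, last_epoch = -1, -1
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        last_epoch = epoch
        improved = score > best if monitor_criterion == "max" else score < best
        if improved:
            # This is where a checkpoint of the current weights would be saved.
            best, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if patience is not None and epochs_without_improvement >= patience:
            break  # no improvement for `patience` consecutive epochs
    return best_epoch, last_epoch


val_scores = [0.70, 0.74, 0.76, 0.75, 0.76, 0.73]  # made-up ROC-AUC values
print(simulate_early_stopping(val_scores, "max", patience=2))  # (2, 4)
```

With load_best_model_at_last=True, a run like this would end with the epoch-2 weights restored, even though training continued until epoch 4.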
Getting the Test Scores
trainer.train() prints test scores to the console when a
test_dataloader is provided, but it does not return them as a Python
object. To capture results for downstream use, call evaluate() separately:
import json
scores = trainer.evaluate(test_loader)
# scores is a plain dict, e.g. {'roc_auc_macro': 0.85, 'loss': 0.31}
with open("results.json", "w") as f:
    json.dump(scores, f, indent=2)
API Reference
- class pyhealth.trainer.Trainer(model, checkpoint_path=None, metrics=None, device=None, enable_logging=True, output_path=None, exp_name=None)
Bases: object
Trainer for PyTorch models.
- Parameters:
model (Module) – PyTorch model.
checkpoint_path (Optional[str]) – Path to the checkpoint. Default is None, which means the model will be randomly initialized.
metrics (Optional[List[str]]) – List of metric names to be calculated. Default is None, which means the default metrics in each metrics_fn will be used.
device (Optional[str]) – Device to be used for training. Default is None, which means the device will be GPU if available, otherwise CPU.
enable_logging (bool) – Whether to enable logging. Default is True.
output_path (Optional[str]) – Path to save the output. Default is “./output”.
exp_name (Optional[str]) – Name of the experiment. Default is the current datetime.
- train(train_dataloader, val_dataloader=None, test_dataloader=None, epochs=5, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_params=None, steps_per_epoch=None, evaluation_steps=1, weight_decay=0.0, max_grad_norm=None, monitor=None, monitor_criterion='max', load_best_model_at_last=True, patience=None)
Trains the model.
- Parameters:
train_dataloader (DataLoader) – Dataloader for training.
val_dataloader (Optional[DataLoader]) – Dataloader for validation. Default is None.
test_dataloader (Optional[DataLoader]) – Dataloader for testing. Default is None.
epochs (int) – Number of epochs. Default is 5.
optimizer_class (Type[Optimizer]) – Optimizer class. Default is torch.optim.Adam.
optimizer_params (Optional[Dict[str, object]]) – Parameters for the optimizer. Default is {“lr”: 1e-3}.
steps_per_epoch (int) – Number of steps per epoch. Default is None.
weight_decay (float) – Weight decay. Default is 0.0.
max_grad_norm (float) – Maximum gradient norm. Default is None.
monitor (Optional[str]) – Metric name to monitor. Default is None.
monitor_criterion (str) – Whether the monitored metric should be maximized (“max”) or minimized (“min”). Default is “max”.
load_best_model_at_last (bool) – Whether to load the best model at the end of training. Default is True.
patience – Number of epochs to wait for improvement before early stopping. Default is None, which means no early stopping.
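max_grad_norm is documented only as the maximum gradient norm. A common way to enforce such a limit is global-norm clipping, as done by torch.nn.utils.clip_grad_norm_: compute one L2 norm over all gradients together and, if it exceeds the limit, scale every gradient by the same factor. Assuming the Trainer clips this way (an assumption, not stated here), the arithmetic looks like this on plain Python lists standing in for gradient tensors; clip_by_global_norm is a hypothetical helper.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float,
                        eps: float = 1e-6) -> List[List[float]]:
    """Hypothetical helper: scale all gradients by one common factor so
    their combined L2 norm is at most max_norm (global-norm clipping)."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = max_norm / (total_norm + eps)  # eps avoids division by zero
    if scale >= 1.0:
        return [list(t) for t in grads]  # already within the limit
    return [[g * scale for g in t] for t in grads]


grads = [[3.0, 0.0], [0.0, 4.0]]  # two toy "parameters"; global norm is 5
clipped = clip_by_global_norm(grads, max_norm=1.0)
print(round(math.sqrt(sum(g * g for t in clipped for g in t)), 4))  # 1.0
```

Because every gradient is multiplied by the same factor, the update keeps its direction and only the step length shrinks.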
- inference(dataloader, additional_outputs=None, return_patient_ids=False)
Model inference.
- Parameters:
dataloader – Dataloader for evaluation.
additional_outputs – List of additional outputs to collect. Defaults to None ([]).
return_patient_ids – Whether to also return the patient ids. Defaults to False.
- Returns:
y_true_all: List of true labels.
y_prob_all: List of predicted probabilities.
loss_mean: Mean loss over batches.
additional_outputs (only if requested): Dict of additional results.
patient_ids (only if requested): List of patient ids in the same order as y_true_all/y_prob_all.
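Because patient_ids is returned in the same order as y_true_all and y_prob_all, per-patient records can be assembled by zipping the lists. The sketch below uses small made-up lists in place of a real inference call, and the helper per_patient_records is hypothetical. It assumes a binary task with one probability per sample; real outputs may be array-shaped.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# With a real Trainer, following the return order listed above:
#   y_true, y_prob, loss_mean, patient_ids = trainer.inference(
#       test_loader, return_patient_ids=True)
# Here the inputs are made-up lists so the sketch runs on its own.
y_true = [1, 0, 0, 1]
y_prob = [0.91, 0.12, 0.40, 0.67]
patient_ids = ["p1", "p2", "p1", "p3"]


def per_patient_records(
    patient_ids: Sequence[str],
    y_true: Sequence[int],
    y_prob: Sequence[float],
) -> Dict[str, List[Tuple[int, float]]]:
    """Hypothetical helper: group (label, probability) pairs by patient.

    Relies on the three sequences being index-aligned.
    """
    if not (len(patient_ids) == len(y_true) == len(y_prob)):
        raise ValueError("inputs must have the same length")
    records: Dict[str, List[Tuple[int, float]]] = defaultdict(list)
    for pid, label, prob in zip(patient_ids, y_true, y_prob):
        records[pid].append((label, prob))
    return dict(records)


print(per_patient_records(patient_ids, y_true, y_prob))
# {'p1': [(1, 0.91), (0, 0.4)], 'p2': [(0, 0.12)], 'p3': [(1, 0.67)]}
```

Grouping by patient is useful because one patient can contribute several samples (visits) to a split.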