Trainer
- class pyhealth.trainer.Trainer(model, checkpoint_path=None, metrics=None, device=None, enable_logging=True, output_path=None, exp_name=None)
Bases: object
Trainer for PyTorch models.
- Parameters
  - model (Module) – PyTorch model.
  - checkpoint_path (Optional[str]) – Path to a checkpoint to load. Default is None, in which case the model keeps its random initialization.
  - metrics (Optional[List[str]]) – Names of the metrics to calculate. Default is None, in which case the default metrics of each metrics_fn are used.
  - device (Optional[str]) – Device used for training. Default is None, which selects a GPU if one is available and the CPU otherwise.
  - enable_logging (bool) – Whether to enable logging. Default is True.
  - output_path (Optional[str]) – Path where the output is saved. Default is “./output”.
  - exp_name (Optional[str]) – Name of the experiment. Default is the current datetime.
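The defaults described above can be summarized in a small, self-contained sketch. It is not pyhealth's code: `resolve_trainer_defaults` is a hypothetical helper, the boolean `cuda_available` stands in for `torch.cuda.is_available()` so the sketch runs without PyTorch, and the timestamp format used for the default experiment name is an assumption.

```python
import os
from datetime import datetime
from typing import Optional, Tuple


def resolve_trainer_defaults(
    device: Optional[str],
    output_path: Optional[str],
    exp_name: Optional[str],
    cuda_available: bool,
    now: Optional[datetime] = None,
) -> Tuple[str, str]:
    """Hypothetical helper: resolve Trainer-style defaults.

    Returns (device, experiment_dir). `cuda_available` is a stand-in for
    torch.cuda.is_available() so this sketch needs no PyTorch install.
    """
    # device=None -> GPU if available, otherwise CPU
    if device is None:
        device = "cuda" if cuda_available else "cpu"
    # output_path=None -> "./output"
    if output_path is None:
        output_path = os.path.join(".", "output")
    # exp_name=None -> current datetime (this timestamp format is assumed)
    if exp_name is None:
        exp_name = (now or datetime.now()).strftime("%Y%m%d-%H%M%S")
    # Outputs for one run are grouped under output_path/exp_name
    return device, os.path.join(output_path, exp_name)


# Example: no GPU and no explicit names, with a fixed clock for reproducibility
dev, exp_dir = resolve_trainer_defaults(
    None, None, None, cuda_available=False, now=datetime(2024, 1, 2, 3, 4, 5)
)
print(dev, exp_dir)
```

Passing `now` explicitly only makes the example deterministic; with the default it falls back to the wall clock, which is why two runs without an `exp_name` end up in different directories.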
- train(train_dataloader, val_dataloader=None, test_dataloader=None, epochs=5, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_params=None, weight_decay=0.0, max_grad_norm=None, monitor=None, monitor_criterion='max', load_best_model_at_last=True)
Trains the model.
- Parameters
  - train_dataloader (DataLoader) – Dataloader for training.
  - val_dataloader (Optional[DataLoader]) – Dataloader for validation. Default is None.
  - test_dataloader (Optional[DataLoader]) – Dataloader for testing. Default is None.
  - epochs (int) – Number of epochs. Default is 5.
  - optimizer_class (Type[Optimizer]) – Optimizer class. Default is torch.optim.Adam.
  - optimizer_params (Optional[Dict[str, object]]) – Parameters for the optimizer. Default is {“lr”: 1e-3}.
  - weight_decay (float) – Weight decay. Default is 0.0.
  - max_grad_norm (Optional[float]) – Maximum gradient norm used for gradient clipping. Default is None (no clipping).
  - monitor (Optional[str]) – Name of the metric to monitor. Default is None.
  - monitor_criterion (str) – Whether a higher (“max”) or lower (“min”) value of the monitored metric counts as an improvement. Default is “max”.
  - load_best_model_at_last (bool) – Whether to load the best model at the end of training. Default is True.
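How `monitor`, `monitor_criterion`, and `max_grad_norm` interact can be illustrated with a plain-Python sketch. This is not the pyhealth implementation: `is_better`, `select_best_epoch`, and `clip_by_global_norm` are hypothetical helpers, and the "max"/"min" semantics of `monitor_criterion` are assumed. The clipping function mirrors the behaviour of `torch.nn.utils.clip_grad_norm_` on a flat list of floats, which is the usual way a maximum gradient norm is applied.

```python
import math
from typing import List, Sequence


def is_better(score: float, best: float, criterion: str = "max") -> bool:
    """Return True if `score` improves on `best` under the assumed criterion."""
    if criterion == "max":
        return score > best
    if criterion == "min":
        return score < best
    raise ValueError(f"monitor_criterion must be 'max' or 'min', got {criterion!r}")


def select_best_epoch(scores: Sequence[float], criterion: str = "max") -> int:
    """Index of the epoch with the best monitored score (earliest wins ties).

    With load_best_model_at_last=True, this is the epoch whose weights
    would be restored after training finishes.
    """
    if not scores:
        raise ValueError("need at least one epoch score")
    best_epoch = 0
    for epoch in range(1, len(scores)):
        if is_better(scores[epoch], scores[best_epoch], criterion):
            best_epoch = epoch
    return best_epoch


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Scale gradients so their global L2 norm is at most `max_norm`.

    Same rule as torch.nn.utils.clip_grad_norm_: shrink only when the
    norm exceeds max_norm; the small epsilon avoids division by zero.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        return [g * clip_coef for g in grads]
    return list(grads)


# Validation scores per epoch for a metric where higher is better
print(select_best_epoch([0.61, 0.78, 0.74], "max"))
# A gradient of norm 5 clipped to norm 1
print(clip_by_global_norm([3.0, 4.0], 1.0))
```

Tracking the best epoch by index keeps the sketch simple; a real trainer would instead save a checkpoint whenever `is_better` returns True and reload that checkpoint at the end.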
- inference(dataloader)
Model inference.
- Parameters
  - dataloader – Dataloader for evaluation.
- Returns
  A tuple (y_true_all, y_prob_all, loss_mean):
  - y_true_all – List of true labels.
  - y_prob_all – List of predicted probabilities.
  - loss_mean – Mean loss over batches.
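The three return values can be illustrated with a minimal sketch of how per-batch outputs are typically combined during inference. `aggregate_batches` is a hypothetical helper, not part of pyhealth, and it assumes each batch yields a `(y_true, y_prob, loss)` triple. Labels and probabilities are concatenated in batch order, and the loss is averaged over batches rather than over samples, matching "mean loss over batches".

```python
from typing import List, Sequence, Tuple


def aggregate_batches(
    batches: Sequence[Tuple[List[int], List[float], float]],
) -> Tuple[List[int], List[float], float]:
    """Hypothetical helper: combine per-batch (y_true, y_prob, loss) outputs."""
    y_true_all: List[int] = []
    y_prob_all: List[float] = []
    losses: List[float] = []
    for y_true, y_prob, loss in batches:
        # Keep predictions aligned with labels by extending in batch order
        y_true_all.extend(y_true)
        y_prob_all.extend(y_prob)
        losses.append(loss)
    # Unweighted mean over batches; NaN when there were no batches at all
    loss_mean = sum(losses) / len(losses) if losses else float("nan")
    return y_true_all, y_prob_all, loss_mean


y_true_all, y_prob_all, loss_mean = aggregate_batches(
    [([1, 0], [0.9, 0.2], 0.5), ([1], [0.7], 0.3)]
)
print(y_true_all, y_prob_all, loss_mean)
```

Because the mean is taken over batches, a smaller final batch counts as much as a full one. Weighting each batch loss by its size would give a per-sample mean instead.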