class pyhealth.trainer.Trainer(model, checkpoint_path=None, metrics=None, device=None, enable_logging=True, output_path=None, exp_name=None)

Bases: object

Trainer for PyTorch models.

Parameters:

  • model (Module) – PyTorch model.

  • checkpoint_path (Optional[str]) – Path to a checkpoint from which to load the model weights. Default is None, which means the model is randomly initialized.

  • metrics (Optional[List[str]]) – List of metric names to compute. Default is None, which means the default metrics of the corresponding metrics_fn are used.

  • device (Optional[str]) – Device to be used for training. Default is None, which means the device will be GPU if available, otherwise CPU.

  • enable_logging (bool) – Whether to enable logging. Default is True.

  • output_path (Optional[str]) – Path to save the output. Default is “./output”.

  • exp_name (Optional[str]) – Name of the experiment. Default is the current datetime.
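The fallback rules for device, output_path, and exp_name can be sketched without PyTorch. This is a minimal sketch under stated assumptions: resolve_trainer_defaults is a hypothetical helper, not part of PyHealth; it takes GPU availability as a boolean instead of querying the hardware; and the "cuda"/"cpu" device strings and the datetime format used for exp_name are assumptions.

```python
from datetime import datetime
from typing import Optional


def resolve_trainer_defaults(
    device: Optional[str] = None,
    output_path: Optional[str] = None,
    exp_name: Optional[str] = None,
    gpu_available: bool = False,
) -> dict:
    """Hypothetical helper mirroring the documented constructor defaults."""
    if device is None:
        # Documented rule: GPU if available, otherwise CPU.
        # The "cuda"/"cpu" strings follow PyTorch naming (assumption).
        device = "cuda" if gpu_available else "cpu"
    if output_path is None:
        # Documented default output location.
        output_path = "./output"
    if exp_name is None:
        # Documented default is the current datetime; this format is assumed.
        exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    return {"device": device, "output_path": output_path, "exp_name": exp_name}


print(resolve_trainer_defaults(gpu_available=True)["device"])  # cuda
print(resolve_trainer_defaults(device="cpu", exp_name="run1"))
```

Explicit arguments always win; the fallbacks only apply to arguments left as None.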

train(train_dataloader, val_dataloader=None, test_dataloader=None, epochs=5, optimizer_class=torch.optim.Adam, optimizer_params=None, steps_per_epoch=None, evaluation_steps=1, weight_decay=0.0, max_grad_norm=None, monitor=None, monitor_criterion='max', load_best_model_at_last=True)

Trains the model.

Parameters:

  • train_dataloader (DataLoader) – Dataloader for training.

  • val_dataloader (Optional[DataLoader]) – Dataloader for validation. Default is None.

  • test_dataloader (Optional[DataLoader]) – Dataloader for testing. Default is None.

  • epochs (int) – Number of epochs. Default is 5.

  • optimizer_class (Type[Optimizer]) – Optimizer class. Default is torch.optim.Adam.

  • optimizer_params (Optional[Dict[str, object]]) – Keyword arguments passed to the optimizer constructor. Default is None, which means {“lr”: 1e-3} is used.

  • steps_per_epoch (Optional[int]) – Number of steps per epoch. Default is None.

  • weight_decay (float) – Weight decay. Default is 0.0.

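  • max_grad_norm (Optional[float]) – Maximum gradient norm used for gradient clipping. Default is None, which means no clipping.

  • monitor (Optional[str]) – Name of the metric to monitor when selecting the best model. Default is None.

  • monitor_criterion (str) – Whether higher (“max”) or lower (“min”) values of the monitored metric are better. Default is “max”.

  • load_best_model_at_last (bool) – Whether to load the best model (as selected by monitor) at the end of training. Default is True.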
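How monitor, monitor_criterion, and load_best_model_at_last interact can be illustrated with a small, framework-free sketch. It assumes one validation score dictionary per epoch and that monitor_criterion is either "max" or "min". The functions is_better and select_best_epoch are hypothetical names, and this is not PyHealth's implementation.

```python
from typing import Dict, List, Optional, Tuple


def is_better(best: float, score: float, criterion: str) -> bool:
    """Return True if `score` improves on `best` under the given criterion."""
    if criterion == "max":
        return score > best
    if criterion == "min":
        return score < best
    raise ValueError(f"monitor_criterion must be 'max' or 'min', got {criterion!r}")


def select_best_epoch(
    val_scores: List[Dict[str, float]],
    monitor: Optional[str],
    monitor_criterion: str = "max",
) -> Tuple[Optional[int], Optional[float]]:
    """Hypothetical sketch: pick the epoch whose monitored metric is best.

    Returns (None, None) when no metric is monitored, so there is no
    "best" model to reload at the end of training.
    """
    if monitor is None:
        return None, None
    best_epoch: Optional[int] = None
    # Start from the worst possible value for the chosen direction.
    best_score = float("-inf") if monitor_criterion == "max" else float("inf")
    for epoch, scores in enumerate(val_scores):
        score = scores[monitor]
        if is_better(best_score, score, monitor_criterion):
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


# Made-up validation scores for three epochs, for illustration only.
history = [
    {"pr_auc": 0.61, "loss": 0.52},
    {"pr_auc": 0.66, "loss": 0.49},
    {"pr_auc": 0.64, "loss": 0.47},
]
print(select_best_epoch(history, "pr_auc", "max"))  # (1, 0.66)
print(select_best_epoch(history, "loss", "min"))  # (2, 0.47)
```

Under these assumptions, the epoch returned is the one whose weights load_best_model_at_last would restore; ties keep the earlier epoch because the comparison is strict.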

inference(dataloader, additional_outputs=None, return_patient_ids=False)

Runs model inference.

Parameters:

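  • dataloader – Dataloader for evaluation.

  • additional_outputs – List of additional outputs to collect. Default is None, which is treated as an empty list.

  • return_patient_ids – Whether to also return the patient ids. Default is False.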


List of true labels. y_prob_all: List of predicted probabilities. loss_mean: Mean loss over batches. additional_outputs (only if requested): Dict of additional results. patient_ids (only if requested): List of patient ids in the same order as y_true_all/y_prob_all.

Return type:


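The way per-batch results could be combined into the inference return values is sketched below with plain lists instead of tensors. The batch layout (dicts with "y_true", "y_prob", "loss", and "patient_id" keys) and the function name collect_inference_outputs are hypothetical; the sketch skips the model forward pass and omits additional_outputs for brevity.

```python
from typing import Any, Dict, List


def collect_inference_outputs(
    batches: List[Dict[str, Any]], return_patient_ids: bool = False
) -> tuple:
    """Hypothetical sketch: concatenate per-batch predictions in order."""
    y_true_all: List[Any] = []
    y_prob_all: List[float] = []
    losses: List[float] = []
    patient_ids: List[str] = []
    for batch in batches:
        y_true_all.extend(batch["y_true"])
        y_prob_all.extend(batch["y_prob"])
        losses.append(batch["loss"])
        patient_ids.extend(batch["patient_id"])
    # Mean loss over batches (not over samples).
    loss_mean = sum(losses) / len(losses)
    if return_patient_ids:
        # patient_ids stays aligned with y_true_all and y_prob_all.
        return y_true_all, y_prob_all, loss_mean, patient_ids
    return y_true_all, y_prob_all, loss_mean


# Illustrative batches only.
batches = [
    {"y_true": [1, 0], "y_prob": [0.9, 0.2], "loss": 0.25, "patient_id": ["p1", "p2"]},
    {"y_true": [1], "y_prob": [0.4], "loss": 0.75, "patient_id": ["p3"]},
]
print(collect_inference_outputs(batches))  # ([1, 0, 1], [0.9, 0.2, 0.4], 0.5)
```

Because the mean is taken over batches, a smaller final batch weighs as much as a full one; a per-sample mean would weight by batch size instead.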

Evaluates the model.


Parameters:

  • dataloader – Dataloader for evaluation.


Returns:

A dictionary of scores, keyed by metric name.

Return type:


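Turning labels and predicted probabilities into a dictionary of scores might look like the sketch below, which also shows how a list of metric names (like the constructor's metrics argument) selects which entries appear. The function score_binary, the metric names, the default ["accuracy"], and the 0.5 decision threshold are assumptions for illustration, not PyHealth's metric functions.

```python
from typing import Dict, List, Optional


def score_binary(
    y_true: List[int],
    y_prob: List[float],
    metrics: Optional[List[str]] = None,
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Hypothetical sketch: compute a {metric name: score} dict for binary labels."""
    if metrics is None:
        # Stand-in for "the default metrics" of a metrics function (assumed).
        metrics = ["accuracy"]
    # Threshold probabilities into hard predictions.
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    correct = sum(1 for t, p in pairs if t == p)
    available = {
        "accuracy": correct / len(y_true),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }
    unknown = [m for m in metrics if m not in available]
    if unknown:
        raise ValueError(f"unsupported metrics: {unknown}")
    # Keep only the requested metrics, in the requested order.
    return {name: available[name] for name in metrics}


# Illustrative inputs only.
print(score_binary([1, 0, 1, 1], [0.9, 0.2, 0.4, 0.7]))  # {'accuracy': 0.75}
print(score_binary([1, 0, 1, 1], [0.9, 0.2, 0.4, 0.7], ["precision", "recall"]))
```

Rejecting unknown metric names up front surfaces typos immediately rather than silently returning a smaller dictionary.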

Saves the model checkpoint.

Return type:



Saves the model checkpoint.

Return type: