import logging
import os
from datetime import datetime
from typing import Callable, Dict, List, Optional, Type
import numpy as np
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from tqdm.autonotebook import trange
from pyhealth.metrics import (binary_metrics_fn, multiclass_metrics_fn,
multilabel_metrics_fn, regression_metrics_fn)
from pyhealth.utils import create_directory
# Module-level logger shared by every function and class in this file;
# set_logger() attaches a file handler to it.
logger = logging.getLogger(__name__)
def is_best(best_score: float, score: float, monitor_criterion: str) -> bool:
    """Tell whether ``score`` strictly improves on ``best_score``.

    Args:
        best_score: Best score observed so far.
        score: Newly observed score.
        monitor_criterion: ``"max"`` if larger scores are better, ``"min"`` if
            smaller scores are better.

    Returns:
        True when ``score`` is strictly better than ``best_score`` under the
        given criterion; a tie is not an improvement.

    Raises:
        ValueError: If ``monitor_criterion`` is neither ``"max"`` nor ``"min"``.
    """
    if monitor_criterion not in ("max", "min"):
        raise ValueError(f"Monitor criterion {monitor_criterion} is not supported")
    return score > best_score if monitor_criterion == "max" else score < best_score
def set_logger(log_path: str) -> None:
    """Send this module's log records to ``<log_path>/log.txt``.

    A new ``logging.FileHandler`` is added to the module-level ``logger`` on
    every call; handlers added by earlier calls are left in place.

    Args:
        log_path: Directory that will hold ``log.txt``. It is passed to
            ``create_directory`` first (presumably creating it when missing --
            confirm against ``pyhealth.utils``).
    """
    create_directory(log_path)
    file_handler = logging.FileHandler(os.path.join(log_path, "log.txt"))
    # Each line is prefixed with a second-resolution timestamp.
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s %(message)s", "%Y-%m-%d %H:%M:%S")
    )
    logger.addHandler(file_handler)
def get_metrics_fn(mode: str) -> Callable:
    """Look up the metrics function that matches a task mode.

    Args:
        mode: One of ``"binary"``, ``"multiclass"``, ``"multilabel"`` or
            ``"regression"``.

    Returns:
        The corresponding ``*_metrics_fn`` imported from ``pyhealth.metrics``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    supported_modes = ("binary", "multiclass", "multilabel", "regression")
    if mode not in supported_modes:
        raise ValueError(f"Mode {mode} is not supported")

    metrics_by_mode = {
        "binary": binary_metrics_fn,
        "multiclass": multiclass_metrics_fn,
        "multilabel": multilabel_metrics_fn,
        "regression": regression_metrics_fn,
    }
    return metrics_by_mode[mode]
class Trainer:
    """Trainer for PyTorch models.

    The model is expected to be callable as ``model(**batch)`` and to return a
    dict containing at least ``"loss"``; ``inference`` additionally reads
    ``"y_true"`` and ``"y_prob"`` from that dict, and ``evaluate`` reads the
    ``model.mode`` attribute.

    Args:
        model: PyTorch model.
        checkpoint_path: Path to a checkpoint to load into ``model``. Default
            is None, which means no checkpoint is loaded and the model keeps
            its current weights.
        metrics: List of metric names to be calculated. Default is None, which
            means the default metrics in each metrics_fn will be used.
        device: Device to be used for training. Default is None, which means
            the device will be GPU if available, otherwise CPU.
        enable_logging: Whether to create an experiment directory with a log
            file. When False, ``exp_path`` is None and no checkpoints are
            written during training. Default is True.
        output_path: Path to save the output. Default is "./output" (relative
            to the current working directory).
        exp_name: Name of the experiment. Default is current datetime.
    """

    def __init__(
        self,
        model: nn.Module,
        checkpoint_path: Optional[str] = None,
        metrics: Optional[List[str]] = None,
        device: Optional[str] = None,
        enable_logging: bool = True,
        output_path: Optional[str] = None,
        exp_name: Optional[str] = None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = model
        self.metrics = metrics
        self.device = device

        # set logger
        if enable_logging:
            if output_path is None:
                output_path = os.path.join(os.getcwd(), "output")
            if exp_name is None:
                exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
            self.exp_path = os.path.join(output_path, exp_name)
            set_logger(self.exp_path)
        else:
            self.exp_path = None

        # set device
        self.model.to(self.device)

        # logging
        logger.info(self.model)
        logger.info(f"Metrics: {self.metrics}")
        logger.info(f"Device: {self.device}")

        # load checkpoint
        if checkpoint_path is not None:
            logger.info(f"Loading checkpoint from {checkpoint_path}")
            self.load_ckpt(checkpoint_path)

        logger.info("")

    def train(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        epochs: int = 5,
        optimizer_class: Type[Optimizer] = torch.optim.Adam,
        optimizer_params: Optional[Dict[str, object]] = None,
        steps_per_epoch: Optional[int] = None,
        evaluation_steps: int = 1,
        weight_decay: float = 0.0,
        max_grad_norm: Optional[float] = None,
        monitor: Optional[str] = None,
        monitor_criterion: str = "max",
        load_best_model_at_last: bool = True,
    ):
        """Trains the model.

        Args:
            train_dataloader: Dataloader for training.
            val_dataloader: Dataloader for validation. Default is None.
            test_dataloader: Dataloader for testing. Default is None.
            epochs: Number of epochs. Default is 5.
            optimizer_class: Optimizer class. Default is torch.optim.Adam.
            optimizer_params: Parameters for the optimizer. Default is
                {"lr": 1e-3}.
            steps_per_epoch: Number of steps per epoch. Default is None, which
                means ``len(train_dataloader)``. The batch iterator is shared
                across epochs and restarted whenever it is exhausted.
            evaluation_steps: Currently unused; kept in the signature for
                backward compatibility.
            weight_decay: Weight decay. Not applied to parameters whose name
                contains "bias", "LayerNorm.bias" or "LayerNorm.weight".
                Default is 0.0.
            max_grad_norm: Maximum gradient norm for clipping. Default is
                None, which disables clipping.
            monitor: Metric name to monitor on the validation set. Default is
                None.
            monitor_criterion: Criterion to monitor, "max" or "min". Default
                is "max".
            load_best_model_at_last: Whether to load ``best.ckpt`` from the
                experiment directory (if that file exists) after the last
                epoch. Default is True.
        """
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3}

        # logging
        logger.info("Training:")
        logger.info(f"Batch size: {train_dataloader.batch_size}")
        logger.info(f"Optimizer: {optimizer_class}")
        logger.info(f"Optimizer params: {optimizer_params}")
        logger.info(f"Weight decay: {weight_decay}")
        logger.info(f"Max grad norm: {max_grad_norm}")
        logger.info(f"Val dataloader: {val_dataloader}")
        logger.info(f"Monitor: {monitor}")
        logger.info(f"Monitor criterion: {monitor_criterion}")
        logger.info(f"Epochs: {epochs}")

        # set optimizer: two parameter groups so that biases and LayerNorm
        # parameters are excluded from weight decay
        param = list(self.model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
            },
            {
                "params": [p for n, p in param if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)

        # initialize
        data_iterator = iter(train_dataloader)
        best_score = -1 * float("inf") if monitor_criterion == "max" else float("inf")
        if steps_per_epoch is None:
            steps_per_epoch = len(train_dataloader)
        global_step = 0

        # epoch training loop
        for epoch in range(epochs):
            training_loss = []
            self.model.zero_grad()
            self.model.train()
            # batch training loop
            logger.info("")
            for _ in trange(
                steps_per_epoch,
                desc=f"Epoch {epoch} / {epochs}",
                smoothing=0.05,
            ):
                try:
                    data = next(data_iterator)
                except StopIteration:
                    # dataloader exhausted mid-epoch: start a fresh pass
                    data_iterator = iter(train_dataloader)
                    data = next(data_iterator)
                # forward
                output = self.model(**data)
                loss = output["loss"]
                # backward
                loss.backward()
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), max_grad_norm
                    )
                # update
                optimizer.step()
                optimizer.zero_grad()
                training_loss.append(loss.item())
                global_step += 1
            # log and save
            logger.info(f"--- Train epoch-{epoch}, step-{global_step} ---")
            logger.info(f"loss: {sum(training_loss) / len(training_loss):.4f}")
            if self.exp_path is not None:
                self.save_ckpt(os.path.join(self.exp_path, "last.ckpt"))

            # validation
            if val_dataloader is not None:
                scores = self.evaluate(val_dataloader)
                logger.info(f"--- Eval epoch-{epoch}, step-{global_step} ---")
                for key, value in scores.items():
                    logger.info("{}: {:.4f}".format(key, value))
                # save best model
                if monitor is not None:
                    score = scores[monitor]
                    if is_best(best_score, score, monitor_criterion):
                        logger.info(
                            f"New best {monitor} score ({score:.4f}) "
                            f"at epoch-{epoch}, step-{global_step}"
                        )
                        best_score = score
                        if self.exp_path is not None:
                            self.save_ckpt(os.path.join(self.exp_path, "best.ckpt"))

        # load best model
        if load_best_model_at_last and self.exp_path is not None:
            best_ckpt_path = os.path.join(self.exp_path, "best.ckpt")
            if os.path.isfile(best_ckpt_path):
                self.load_ckpt(best_ckpt_path)
                # logged only after the load actually succeeded
                logger.info("Loaded best model")

        # test
        if test_dataloader is not None:
            scores = self.evaluate(test_dataloader)
            logger.info("--- Test ---")
            for key, value in scores.items():
                logger.info("{}: {:.4f}".format(key, value))

    def inference(
        self,
        dataloader: DataLoader,
        additional_outputs: Optional[List[str]] = None,
        return_patient_ids: bool = False,
    ) -> List[object]:
        """Model inference.

        Args:
            dataloader: Dataloader for evaluation.
            additional_outputs: List of additional output keys to collect from
                the model's output dict. Defaults to None (nothing extra).
            return_patient_ids: Whether to also collect ``batch["patient_id"]``
                from every batch. Defaults to False.

        Returns:
            A list with the following items, in order:

            y_true_all: Array of true labels, concatenated over batches.
            y_prob_all: Array of predicted probabilities, concatenated over
                batches.
            loss_mean: Mean loss over batches.
            additional_outputs (only if requested): Dict mapping each requested
                key to its outputs concatenated over batches.
            patient_ids (only if requested): List of patient ids in the same
                order as y_true_all/y_prob_all.
        """
        loss_all = []
        y_true_all = []
        y_prob_all = []
        patient_ids = []
        if additional_outputs is not None:
            additional_outputs = {k: [] for k in additional_outputs}

        # switching to eval mode once is enough; nothing in the loop changes it
        self.model.eval()
        for data in tqdm(dataloader, desc="Evaluation"):
            with torch.no_grad():
                output = self.model(**data)
                loss = output["loss"]
                y_true = output["y_true"].cpu().numpy()
                y_prob = output["y_prob"].cpu().numpy()
                loss_all.append(loss.item())
                y_true_all.append(y_true)
                y_prob_all.append(y_prob)
                if additional_outputs is not None:
                    for key in additional_outputs:
                        additional_outputs[key].append(output[key].cpu().numpy())
            if return_patient_ids:
                patient_ids.extend(data["patient_id"])

        loss_mean = sum(loss_all) / len(loss_all)
        y_true_all = np.concatenate(y_true_all, axis=0)
        y_prob_all = np.concatenate(y_prob_all, axis=0)
        outputs = [y_true_all, y_prob_all, loss_mean]
        if additional_outputs is not None:
            additional_outputs = {
                key: np.concatenate(val) for key, val in additional_outputs.items()
            }
            outputs.append(additional_outputs)
        if return_patient_ids:
            outputs.append(patient_ids)
        return outputs

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluates the model.

        When ``self.model.mode`` is not None, the metrics function for that
        mode is applied to the inference results; otherwise only the mean loss
        is reported.

        Args:
            dataloader: Dataloader for evaluation.

        Returns:
            scores: a dictionary of scores; always contains the key "loss".
        """
        if self.model.mode is not None:
            y_true_all, y_prob_all, loss_mean = self.inference(dataloader)
            mode = self.model.mode
            metrics_fn = get_metrics_fn(mode)
            scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics)
            scores["loss"] = loss_mean
        else:
            loss_all = []
            self.model.eval()
            for data in tqdm(dataloader, desc="Evaluation"):
                with torch.no_grad():
                    output = self.model(**data)
                    loss = output["loss"]
                    loss_all.append(loss.item())
            loss_mean = sum(loss_all) / len(loss_all)
            scores = {"loss": loss_mean}
        return scores

    def save_ckpt(self, ckpt_path: str) -> None:
        """Saves the model checkpoint (the model's ``state_dict``) to ``ckpt_path``."""
        state_dict = self.model.state_dict()
        torch.save(state_dict, ckpt_path)

    def load_ckpt(self, ckpt_path: str) -> None:
        """Loads the model checkpoint from ``ckpt_path`` into ``self.model``.

        Tensors are mapped onto ``self.device`` while loading.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
if __name__ == "__main__":
    # Usage example: train a small CNN on MNIST with the Trainer defined above.
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset
    from torchvision import datasets, transforms

    from pyhealth.datasets.utils import collate_fn_dict

    class MNISTDataset(Dataset):
        """MNIST wrapper whose samples are ``{"x": image, "y": label}`` dicts."""

        def __init__(self, train=True):
            to_tensor = transforms.ToTensor()
            normalize = transforms.Normalize((0.1307,), (0.3081,))
            self.dataset = datasets.MNIST(
                "../data",
                train=train,
                download=True,
                transform=transforms.Compose([to_tensor, normalize]),
            )

        def __getitem__(self, index):
            image, label = self.dataset[index]
            return {"x": image, "y": label}

        def __len__(self):
            return len(self.dataset)

    class Model(nn.Module):
        """Two-conv-layer classifier returning the dict the Trainer expects."""

        def __init__(self):
            super().__init__()
            # read by Trainer.evaluate to choose the metrics function
            self.mode = "multiclass"
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)
            self.loss = nn.CrossEntropyLoss()

        def forward(self, x, y, **kwargs):
            # the batch arrives as sequences of per-sample items
            images = torch.stack(x, dim=0).to(self.device)
            labels = torch.tensor(y).to(self.device)

            hidden = torch.relu(self.conv1(images))
            hidden = torch.relu(self.conv2(hidden))
            hidden = self.dropout1(torch.max_pool2d(hidden, 2))
            hidden = torch.flatten(hidden, 1)
            hidden = self.dropout2(torch.relu(self.fc1(hidden)))
            logits = self.fc2(hidden)

            return {
                "loss": self.loss(logits, labels),
                "y_prob": torch.softmax(logits, dim=1),
                "y_true": labels,
            }

    train_dataset = MNISTDataset(train=True)
    val_dataset = MNISTDataset(train=False)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
        collate_fn=collate_fn_dict,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=64,
        shuffle=False,
        collate_fn=collate_fn_dict,
    )

    model = Model()
    trainer = Trainer(model, device="cuda" if torch.cuda.is_available() else "cpu")
    # the validation split doubles as the test split in this example
    trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=val_dataloader,
        epochs=5,
        monitor="accuracy",
    )