Source code for pyhealth.models.base_model

from abc import ABC
from typing import Callable, Any, Optional
import inspect

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..datasets import SampleDataset
from ..processors import PROCESSOR_REGISTRY


[docs]class BaseModel(ABC, nn.Module): """Abstract class for PyTorch models. Args: dataset (SampleDataset): The dataset to train the model. It is used to query certain information such as the set of all tokens. Interpretability -------- To use a model with interpretability methods, the model must implement a method `forward_from_embedding` that takes in embeddings as input instead of raw features; for the models that already take in dense features as input, this method can simply call the existing `forward` method. For certain gradient-based interpretability methods (e.g., DeepLIFT), the model must also ensure all non-linearity (e.g. ReLU, Sigmoid, Softmax) are using nn.Module versions instead of functional versions (e.g., F.relu, F.sigmoid, F.softmax) so that hooks can be registered properly. """ def __init__(self, dataset: SampleDataset): """ Initializes the BaseModel. Args: dataset (SampleDataset): The dataset to train the model. """ super(BaseModel, self).__init__() self.dataset = dataset self.feature_keys = [] self.label_keys = [] if dataset: self.feature_keys = list(dataset.input_schema.keys()) self.label_keys = list(dataset.output_schema.keys()) # if single label, try to resolve mode for legacy trainer usage if len(self.label_keys) == 1: try: m = self._resolve_mode(dataset.output_schema[self.label_keys[0]]) if m in {"binary", "multiclass", "multilabel", "regression"}: self.mode = m except Exception: pass # used to query the device of the model self._dummy_param = nn.Parameter(torch.empty(0)) self.mode = getattr(self, "mode", None) # legacy API
[docs] def forward(self, **kwargs: torch.Tensor | tuple[torch.Tensor, ...] ) -> dict[str, torch.Tensor]: """Forward pass of the model. Args: **kwargs: A variable number of keyword arguments representing input features. Each keyword argument is a tensor or a tuple of tensors of shape (batch_size, ...). Returns: A dictionary with the following keys: logit: a tensor of predicted logits. y_prob: a tensor of predicted probabilities. loss [optional]: a scalar tensor representing the final loss, if self.label_keys in kwargs. y_true [optional]: a tensor representing the true labels, if self.label_keys in kwargs. """ raise NotImplementedError
# ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _resolve_mode(self, schema_entry: Any) -> str: """Resolve a mode string from an output_schema entry. Supports: - direct string ("binary", ...) - processor class - processor instance Returns the registered processor name if found. """ if isinstance(schema_entry, str): return schema_entry.lower() # Get class reference cls = schema_entry if inspect.isclass(schema_entry) else schema_entry.__class__ for name, registered_cls in PROCESSOR_REGISTRY.items(): if cls is registered_cls or issubclass( cls, registered_cls ): # allow subclassing return name.lower() raise ValueError( f"Cannot resolve mode from output_schema entry {schema_entry}. Use a supported string" ) @property def device(self) -> torch.device: """ Gets the device of the model. Returns: torch.device: The device on which the model is located. """ return self._dummy_param.device
[docs] def get_output_size(self) -> int: """ Gets the default output size using the label tokenizer and `self.mode`. If the mode is "binary", the output size is 1. If the mode is "multiclass" or "multilabel", the output size is the number of classes or labels. Returns: int: The output size of the model. """ assert ( len(self.label_keys) == 1 ), "Only one label key is supported if get_output_size is called" output_size = self.dataset.output_processors[self.label_keys[0]].size() return output_size
[docs] def get_loss_function(self) -> Callable: """ Gets the default loss function using `self.mode`. The default loss functions are: - binary: `F.binary_cross_entropy_with_logits` - multiclass: `F.cross_entropy` - multilabel: `F.binary_cross_entropy_with_logits` - regression: `F.mse_loss` Returns: Callable: The default loss function. """ assert ( len(self.label_keys) == 1 ), "Only one label key is supported if get_loss_function is called" label_key = self.label_keys[0] mode = self._resolve_mode(self.dataset.output_schema[label_key]) if mode == "binary": return F.binary_cross_entropy_with_logits elif mode == "multiclass": return F.cross_entropy elif mode == "multilabel": return F.binary_cross_entropy_with_logits elif mode == "regression": return F.mse_loss else: raise ValueError(f"Invalid mode: {mode}")
[docs] def prepare_y_prob(self, logits: torch.Tensor) -> torch.Tensor: """ Prepares the predicted probabilities for model evaluation. This function converts the predicted logits to predicted probabilities depending on the mode. The default formats are: - binary: a tensor of shape (batch_size, 1) with values in [0, 1], which is obtained with `torch.sigmoid()` - multiclass: a tensor of shape (batch_size, num_classes) with values in [0, 1] and sum to 1, which is obtained with `torch.softmax()` - multilabel: a tensor of shape (batch_size, num_labels) with values in [0, 1], which is obtained with `torch.sigmoid()` - regression: a tensor of shape (batch_size, 1) with raw logits Args: logits (torch.Tensor): The predicted logit tensor. Returns: torch.Tensor: The predicted probability tensor. """ assert ( len(self.label_keys) == 1 ), "Only one label key is supported if get_loss_function is called" label_key = self.label_keys[0] mode = self._resolve_mode(self.dataset.output_schema[label_key]) if mode in ["binary"]: y_prob = torch.sigmoid(logits) elif mode in ["multiclass"]: y_prob = F.softmax(logits, dim=-1) elif mode in ["multilabel"]: y_prob = torch.sigmoid(logits) elif mode in ["regression"]: y_prob = logits else: raise NotImplementedError return y_prob