pyhealth.datasets.splitter
Data splitting functions for the pyhealth.datasets module, used to obtain training / validation / test sets.
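- pyhealth.datasets.splitter.split_by_visit(dataset, ratios, seed=None)
Splits the dataset by visit (i.e., by sample): samples are assigned to the three subsets individually, so samples from the same patient may end up in different subsets.
- Parameters:
- dataset – the sample dataset to split.
- ratios – a list or tuple of three ratios for the train / val / test sets.
- seed – random seed used to shuffle the dataset (default: None).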
- Returns:
- train_dataset, val_dataset, test_dataset: three subsets of the dataset, each of type torch.utils.data.Subset.
- Return type:
- tuple of torch.utils.data.Subset
Note
- The original dataset can be accessed by train_dataset.dataset,
val_dataset.dataset, and test_dataset.dataset.
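A visit-level (sample-level) split amounts to shuffling the sample indices and cutting them at the ratio boundaries. The sketch below illustrates that idea with the standard library only; `split_indices_by_visit` is a hypothetical helper, not PyHealth's implementation, and it assumes the ratios sum to 1 and that cut points are floored with `int()`.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_indices_by_visit(
    num_samples: int,
    ratios: Sequence[float],
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: shuffle sample indices and cut them by ratio."""
    assert len(ratios) == 3, "expected (train, val, test) ratios"
    assert abs(sum(ratios) - 1.0) < 1e-8, "ratios are assumed to sum to 1"
    rng = random.Random(seed)  # local RNG; the global random state is untouched
    index = list(range(num_samples))
    rng.shuffle(index)
    # int() floors both cut points, so rounding leftovers shift to later subsets
    train_end = int(num_samples * ratios[0])
    val_end = int(num_samples * (ratios[0] + ratios[1]))
    return index[:train_end], index[train_end:val_end], index[val_end:]


train_idx, val_idx, test_idx = split_indices_by_visit(10, (0.8, 0.1, 0.1), seed=42)
print(len(train_idx), len(val_idx), len(test_idx))  # 8 1 1
```

Because every sample is placed independently, two visits of the same patient can land in different subsets, which is the main difference from a patient-level split.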
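- pyhealth.datasets.splitter.split_by_patient(dataset, ratios, seed=None)
Splits the dataset by patient: all samples belonging to a given patient are assigned to the same subset, so the ratios apply to patients rather than to individual samples.
- Parameters:
- dataset – the sample dataset to split.
- ratios – a list or tuple of three ratios for the train / val / test sets.
- seed – random seed used to shuffle the patients (default: None).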
- Returns:
- train_dataset, val_dataset, test_dataset: three subsets of the dataset, each of type torch.utils.data.Subset.
- Return type:
- tuple of torch.utils.data.Subset
Note
- The original dataset can be accessed by train_dataset.dataset,
val_dataset.dataset, and test_dataset.dataset.
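A patient-level split can be sketched by grouping sample indices by patient ID, shuffling the patients, splitting the patient list by ratio, and then expanding each group back into sample indices. This is a minimal illustration, not PyHealth's code: `split_indices_by_patient` is hypothetical, and it assumes `patient_ids[i]` holds the patient ID of sample `i` and that the ratios sum to 1.

```python
import random
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple


def split_indices_by_patient(
    patient_ids: Sequence[str],
    ratios: Sequence[float],
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: split whole patients, then expand to sample indices.

    patient_ids[i] is assumed to be the patient ID of sample i.
    """
    assert abs(sum(ratios) - 1.0) < 1e-8, "ratios are assumed to sum to 1"
    # Group sample indices by patient.
    patient_to_index: Dict[str, List[int]] = defaultdict(list)
    for i, pid in enumerate(patient_ids):
        patient_to_index[pid].append(i)
    # Shuffle patients (not samples) with a local, seeded RNG.
    patients = list(patient_to_index)
    random.Random(seed).shuffle(patients)
    # The ratios are applied to the number of patients.
    n = len(patients)
    train_end = int(n * ratios[0])
    val_end = int(n * (ratios[0] + ratios[1]))

    def expand(pids: List[str]) -> List[int]:
        return [i for pid in pids for i in patient_to_index[pid]]

    return (
        expand(patients[:train_end]),
        expand(patients[train_end:val_end]),
        expand(patients[val_end:]),
    )


# Ten patients with varying numbers of samples (made-up toy data).
ids = ["p0", "p0", "p1", "p2", "p2", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9"]
train_idx, val_idx, test_idx = split_indices_by_patient(ids, (0.8, 0.1, 0.1), seed=0)
```

Each patient ends up in exactly one subset, which prevents leakage of a patient's records between training and evaluation. The trade-off is that per-subset sample counts depend on which patients are drawn, so they only approximate the requested ratios.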
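- pyhealth.datasets.splitter.split_by_sample(dataset, ratios, seed=None, get_index=False)
Splits the dataset by sample.
- Parameters:
- dataset – the sample dataset to split.
- ratios – a list or tuple of three ratios for the train / val / test sets.
- seed – random seed used to shuffle the dataset (default: None).
- get_index – if True, return the train / val / test indices instead of Subset objects (default: False).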
- Returns:
- train_dataset, val_dataset, test_dataset: three subsets of the dataset, each of type torch.utils.data.Subset.
- Return type:
- tuple of torch.utils.data.Subset
Note
- The original dataset can be accessed by train_dataset.dataset,
val_dataset.dataset, and test_dataset.dataset.
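The Note about reaching the original dataset through `.dataset` follows from how a subset is built: it stores a reference to the parent dataset plus a list of indices. The sketch below uses a small stand-in class instead of `torch.utils.data.Subset` so it runs without PyTorch. `SimpleSubset` and `split_by_sample_sketch` are hypothetical names, and the `get_index=True` branch (returning plain index lists) is an assumption about the flag's behaviour, not a description of the library's exact return type.

```python
import random
from typing import Any, List, Optional, Sequence


class SimpleSubset:
    """Minimal stand-in for torch.utils.data.Subset (parent dataset + index list)."""

    def __init__(self, dataset: Sequence[Any], indices: List[int]) -> None:
        self.dataset = dataset  # the original dataset, still reachable from the subset
        self.indices = indices

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> Any:
        return self.dataset[self.indices[i]]


def split_by_sample_sketch(
    dataset: Sequence[Any],
    ratios: Sequence[float],
    seed: Optional[int] = None,
    get_index: bool = False,
):
    """Hypothetical sketch: sample-level random split with an optional index return."""
    assert abs(sum(ratios) - 1.0) < 1e-8, "ratios are assumed to sum to 1"
    index = list(range(len(dataset)))
    random.Random(seed).shuffle(index)
    train_end = int(len(dataset) * ratios[0])
    val_end = int(len(dataset) * (ratios[0] + ratios[1]))
    parts = (index[:train_end], index[train_end:val_end], index[val_end:])
    if get_index:
        return parts  # assumed behaviour: raw index lists instead of subsets
    return tuple(SimpleSubset(dataset, p) for p in parts)


samples = [{"sample_id": k} for k in range(20)]  # toy data
train_ds, val_ds, test_ds = split_by_sample_sketch(samples, (0.8, 0.1, 0.1), seed=1)
print(len(train_ds), len(val_ds), len(test_ds))  # 16 2 2
print(train_ds.dataset is samples)  # True
```

Keeping only indices avoids copying samples, and it is why all three subsets share the same parent object.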