pyhealth.datasets.splitter

Data splitting functions for the pyhealth.datasets module, used to obtain training / validation / test sets.

pyhealth.datasets.splitter.split_by_visit(dataset, ratios, seed=None)

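Splits the dataset by visit (i.e., by sample). Each sample is assigned to a subset regardless of which patient it belongs to, so samples from the same patient can end up in different subsets.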

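Parameters:

dataset: the sample dataset to split.
ratios: a list or tuple of three ratios for the train / validation / test subsets; the ratios should sum to 1.
seed: optional random seed that makes the split reproducible.

Returns:

train_dataset, val_dataset, test_dataset: three subsets of the dataset, each of type torch.utils.data.Subset.

Return type:

tuple of three torch.utils.data.Subset objects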

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, and test_dataset.dataset.
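The visit-level split described above can be sketched with the standard library alone, so it runs without PyHealth or PyTorch. The helper split_indices_by_sample is hypothetical and not part of pyhealth; it assumes the split is a seeded shuffle of sample indices cut at the given ratios, with the test subset taking whatever remains after rounding down. Wrapping each resulting index list in torch.utils.data.Subset(dataset, indices) would give subsets of the documented type.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_indices_by_sample(
    n_samples: int,
    ratios: Sequence[float],
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: shuffle sample indices and cut them into train/val/test."""
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-8:
        raise ValueError("ratios must be three numbers that sum to 1")
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    index = list(range(n_samples))
    rng.shuffle(index)
    # Cut points are rounded down; the test split receives the remainder.
    n_train = int(n_samples * ratios[0])
    n_val = int(n_samples * (ratios[0] + ratios[1])) - n_train
    train = index[:n_train]
    val = index[n_train:n_train + n_val]
    test = index[n_train + n_val:]
    return train, val, test


# Example: 10 samples split 80/10/10 with a fixed seed.
train_idx, val_idx, test_idx = split_indices_by_sample(10, [0.8, 0.1, 0.1], seed=42)
print(len(train_idx), len(val_idx), len(test_idx))
```

Using a private random.Random instance keeps the split reproducible for a given seed without reseeding the global generator other code may rely on.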

pyhealth.datasets.splitter.split_by_patient(dataset, ratios, seed=None)

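Splits the dataset by patient. All samples of a given patient are assigned to the same subset, so no patient appears in more than one of the train / validation / test sets.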

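Parameters:

dataset: the sample dataset to split.
ratios: a list or tuple of three ratios of patients assigned to the train / validation / test subsets; the ratios should sum to 1. Because whole patients are assigned, the resulting sample proportions may differ from these ratios.
seed: optional random seed that makes the split reproducible.

Returns:

train_dataset, val_dataset, test_dataset: three subsets of the dataset, each of type torch.utils.data.Subset.

Return type:

tuple of three torch.utils.data.Subset objects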

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, and test_dataset.dataset.
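A patient-level split can be sketched the same way, except that patients rather than samples are shuffled and cut. This stdlib sketch assumes a hypothetical patient_to_index mapping from each patient ID to the indices of that patient's samples; it is an illustration, not a documented pyhealth attribute or the library's implementation. Flattening each patient group into its sample indices is what keeps every patient in exactly one subset.

```python
import random
from itertools import chain
from typing import Dict, Hashable, List, Optional, Sequence, Tuple


def split_indices_by_patient(
    patient_to_index: Dict[Hashable, List[int]],
    ratios: Sequence[float],
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: assign whole patients to train/val/test, return sample indices."""
    if len(ratios) != 3 or abs(sum(ratios) - 1.0) > 1e-8:
        raise ValueError("ratios must be three numbers that sum to 1")
    rng = random.Random(seed)
    patients = list(patient_to_index)
    rng.shuffle(patients)
    n = len(patients)
    # Ratios are applied to the number of patients, not samples.
    n_train = int(n * ratios[0])
    n_val = int(n * (ratios[0] + ratios[1])) - n_train
    groups = (
        patients[:n_train],
        patients[n_train:n_train + n_val],
        patients[n_train + n_val:],
    )
    # Flatten each patient group into the sample indices that belong to it.
    train, val, test = (
        list(chain.from_iterable(patient_to_index[p] for p in group))
        for group in groups
    )
    return train, val, test


# Hypothetical mapping: patient ID -> indices of that patient's samples.
patient_to_index = {
    "p1": [0, 1, 2],
    "p2": [3],
    "p3": [4, 5],
    "p4": [6, 7, 8, 9],
}
train_idx, val_idx, test_idx = split_indices_by_patient(
    patient_to_index, [0.5, 0.25, 0.25], seed=0
)
print(train_idx, val_idx, test_idx)
```

With four patients and ratios of 0.5 / 0.25 / 0.25, two patients go to training and one each to validation and test, but the sample counts per subset depend on which patients are drawn.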

pyhealth.datasets.splitter.split_by_sample(dataset, ratios, seed=None, get_index=False)

Splits the dataset by sample.

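Parameters:

dataset: the sample dataset to split.
ratios: a list or tuple of three ratios for the train / validation / test subsets; the ratios should sum to 1.
seed: optional random seed that makes the split reproducible.
get_index: if True, return the indices of the three splits instead of Subset objects.

Returns:

train_dataset, val_dataset, test_dataset: three subsets of the dataset, each of type torch.utils.data.Subset.

Return type:

tuple of three torch.utils.data.Subset objects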

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, and test_dataset.dataset.
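The relationship between the returned subsets, their .dataset attribute, and the get_index option can be sketched without PyTorch. IndexSubset below is a hypothetical minimal stand-in that mirrors the dataset and indices attributes of torch.utils.data.Subset, and split_by_sample_sketch is likewise hypothetical. It assumes, as the signature suggests, that get_index=True yields the raw index splits instead of subsets; the real function may return a different index type.

```python
import random
from typing import Any, List, Optional, Sequence


class IndexSubset:
    """Hypothetical stand-in for torch.utils.data.Subset: a view of dataset at indices."""

    def __init__(self, dataset: Sequence[Any], indices: List[int]):
        self.dataset = dataset  # the original, unsplit dataset
        self.indices = indices

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> Any:
        # Position i in the subset maps to position indices[i] in the original dataset.
        return self.dataset[self.indices[i]]


def split_by_sample_sketch(
    dataset: Sequence[Any],
    ratios: Sequence[float],
    seed: Optional[int] = None,
    get_index: bool = False,
):
    """Hypothetical sample-level split returning index lists or subset views."""
    rng = random.Random(seed)
    index = list(range(len(dataset)))
    rng.shuffle(index)
    n = len(index)
    cut1 = int(n * ratios[0])
    cut2 = int(n * (ratios[0] + ratios[1]))
    splits = (index[:cut1], index[cut1:cut2], index[cut2:])
    if get_index:
        return splits  # raw index lists only
    return tuple(IndexSubset(dataset, idx) for idx in splits)


# Toy samples: two per patient.
samples = [{"patient_id": f"p{i // 2}", "label": i % 2} for i in range(8)]
train_ds, val_ds, test_ds = split_by_sample_sketch(samples, [0.5, 0.25, 0.25], seed=1)
train_idx, val_idx, test_idx = split_by_sample_sketch(
    samples, [0.5, 0.25, 0.25], seed=1, get_index=True
)
# The same seed gives the same shuffle, so both calls describe the same split.
print(train_ds.indices == train_idx, train_ds.dataset is samples)
```

Keeping subsets as index views means no sample is copied, and the full dataset stays reachable through each subset's dataset attribute.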