pyhealth.datasets.splitter

Several data splitting functions for the pyhealth.datasets module, used to obtain training / validation / test sets.

class pyhealth.datasets.splitter.chain

Bases: object

chain(*iterables) --> chain object

Return a chain object whose .__next__() method returns elements from the first iterable until it is exhausted, then elements from the next iterable, until all of the iterables are exhausted.

from_iterable()

Alternative chain() constructor taking a single iterable argument that evaluates lazily.
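A short illustration of the two constructors, using only the standard-library itertools.chain that this entry documents. The patient_to_index dictionary is made-up data, and flattening per-patient index lists is an assumed use case rather than a description of how the splitter calls chain internally:

```python
from itertools import chain

# Illustrative patient -> sample-index mapping (made-up data).
patient_to_index = {"p1": [0, 1], "p2": [2], "p3": [3, 4, 5]}

# chain(*iterables): each iterable is passed as a separate argument.
from_args = list(chain(patient_to_index["p1"], patient_to_index["p3"]))

# chain.from_iterable(iterable): a single iterable of iterables; the inner
# lists are pulled one at a time as the result is consumed.
from_one_iterable = list(
    chain.from_iterable(patient_to_index[p] for p in ["p2", "p3"])
)

print(from_args)          # [0, 1, 3, 4, 5]
print(from_one_iterable)  # [2, 3, 4, 5]
```

from_iterable is the convenient form when the number of inner iterables is only known at runtime, for example a shuffled list of patient IDs.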

class pyhealth.datasets.splitter.SampleDataset(path, dataset_name=None, task_name=None, **kwargs)

Bases: StreamingDataset

A streaming dataset that loads sample metadata and processors from disk.

SampleDataset expects the path directory to contain a schema.pkl file created by a SampleBuilder.save(…) call. The schema.pkl must include the fitted input_schema, output_schema, input_processors, output_processors, patient_to_index and record_to_index mappings.

input_schema

The configuration used to instantiate processors for input features (string aliases or processor specs).

output_schema

The configuration used to instantiate processors for output features.

input_processors

A mapping of input feature names to fitted FeatureProcessor instances.

output_processors

A mapping of output feature names to fitted FeatureProcessor instances.

patient_to_index

Dictionary mapping patient IDs to the list of sample indices associated with that patient.

record_to_index

Dictionary mapping record/visit IDs to the list of sample indices associated with that record.
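A minimal sketch of how these two mappings relate to the underlying samples, assuming each sample is a dict with hypothetical "patient_id" and "record_id" keys (the key names SampleBuilder actually uses may differ). One patient can own several records, and one record can own several samples:

```python
from collections import defaultdict
from typing import Dict, List, Tuple

IndexMap = Dict[str, List[int]]


def build_index_maps(samples: List[dict]) -> Tuple[IndexMap, IndexMap]:
    """Group sample positions by patient and by record (hypothetical keys)."""
    patient_to_index: IndexMap = defaultdict(list)
    record_to_index: IndexMap = defaultdict(list)
    for idx, sample in enumerate(samples):
        patient_to_index[sample["patient_id"]].append(idx)
        record_to_index[sample["record_id"]].append(idx)
    return dict(patient_to_index), dict(record_to_index)


samples = [
    {"patient_id": "p1", "record_id": "v1"},
    {"patient_id": "p1", "record_id": "v2"},
    {"patient_id": "p2", "record_id": "v3"},
    {"patient_id": "p1", "record_id": "v2"},  # second sample from visit v2
]
patient_to_index, record_to_index = build_index_maps(samples)
print(patient_to_index)  # {'p1': [0, 1, 3], 'p2': [2]}
print(record_to_index)   # {'v1': [0], 'v2': [1, 3], 'v3': [2]}
```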

dataset_name

Optional human-friendly dataset name.

task_name

Optional human-friendly task name.

subset(indices)

Create a new SampleDataset restricted to the provided indices.

Return type:

SampleDataset

close()

Cleans up any temporary directories used by the dataset.

Return type:

None

get_len(num_workers, batch_size)

Return type:

int

load_state_dict(state_dict)

Return type:

None

property on_demand_bytes: bool

Return type:

bool

reset()

Return type:

None

reset_state_dict()

Return type:

None

set_batch_size(batch_size)

Return type:

None

set_drop_last(drop_last)

Set the drop_last parameter.

Invalidates the shuffler cache when the parameter changes to ensure subsequent length calculations reflect the new drop_last setting.

Parameters:

drop_last (bool) – Whether to drop the last incomplete batch.

Return type:

None
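The cache invalidation described for set_drop_last (and likewise set_shuffle) follows a common pattern: a derived value is memoized and discarded only when an input it depends on actually changes. The BatchCounter class below is a hypothetical stand-in for illustration, not the real StreamingDataset shuffler; its batch-count formula is the usual single-loader one and ignores per-worker partitioning:

```python
from typing import Optional


class BatchCounter:
    """Hypothetical stand-in that caches a batch count derived from drop_last."""

    def __init__(self, num_samples: int, batch_size: int, drop_last: bool = False):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.drop_last = drop_last
        self._cached_len: Optional[int] = None  # cleared when drop_last changes
        self.recomputations = 0

    def set_drop_last(self, drop_last: bool) -> None:
        # Only invalidate the cache if the setting really changes.
        if drop_last != self.drop_last:
            self.drop_last = drop_last
            self._cached_len = None

    def __len__(self) -> int:
        if self._cached_len is None:
            self.recomputations += 1
            if self.drop_last:
                self._cached_len = self.num_samples // self.batch_size
            else:
                # Integer ceiling division: the last, partial batch is kept.
                self._cached_len = (self.num_samples + self.batch_size - 1) // self.batch_size
        return self._cached_len


counter = BatchCounter(num_samples=10, batch_size=4)
print(len(counter))           # 3 (last batch holds 2 samples)
counter.set_drop_last(True)
print(len(counter))           # 2 (incomplete batch dropped)
counter.set_drop_last(True)   # no change, so the cached value is reused
print(len(counter), counter.recomputations)  # 2 2
```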

set_epoch(current_epoch)

Set the current epoch on the dataset at the start of each epoch.

When using the StreamingDataLoader, this is done automatically.

Return type:

None

set_num_workers(num_workers)

Return type:

None

set_shuffle(shuffle)

Set the shuffle parameter.

Invalidates the shuffler cache when the parameter changes to ensure subsequent length calculations reflect the new shuffle setting.

Parameters:

shuffle (bool) – Whether to shuffle the dataset.

Return type:

None

state_dict(num_samples_yielded, num_workers, batch_size)

Return type:

dict[str, Any]

pyhealth.datasets.splitter.sample_balanced(dataset, ratio=1.0, subsample=1.0, seed=None)

Keep positives and negatives at a target ratio, then cap total size.

Parameters:
  • dataset (SampleDataset) – Dataset with patient_to_index populated.

  • ratio (float) – Negatives per positive (e.g., 1.0 keeps roughly one negative per positive). Values <= 0 keep only positives.

  • subsample (float) – Max fraction of the original dataset size to retain. If the ratio-selected set exceeds len(dataset) * subsample, both positives and negatives are downsampled proportionally while preserving the ratio as closely as possible.

  • seed (Optional[int]) – Optional RNG seed for reproducible negative sampling.

Return type:

SampleDataset

Returns:

A new SampleDataset containing all positives plus sampled negatives, with refreshed patient_to_index and record_to_index mappings.
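A minimal sketch of the ratio-then-cap selection described above, operating on a plain list of binary labels instead of a SampleDataset. The helper name balanced_indices is hypothetical, and how the real function identifies positives and rounds counts are assumptions here:

```python
import random
from typing import List, Optional


def balanced_indices(labels: List[int], ratio: float = 1.0,
                     subsample: float = 1.0, seed: Optional[int] = None) -> List[int]:
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y != 1]

    # Step 1: keep all positives and sample about `ratio` negatives per positive.
    n_neg = 0 if ratio <= 0 else min(len(neg), round(ratio * len(pos)))
    neg = rng.sample(neg, n_neg)

    # Step 2: cap the total at len(labels) * subsample, shrinking both groups
    # by the same factor so the pos/neg ratio is roughly preserved.
    cap = int(len(labels) * subsample)
    total = len(pos) + len(neg)
    if total > cap:
        scale = cap / total
        pos = rng.sample(pos, int(len(pos) * scale))
        neg = rng.sample(neg, int(len(neg) * scale))

    return sorted(pos + neg)


labels = [1] * 40 + [0] * 60
capped = balanced_indices(labels, ratio=1.0, subsample=0.5, seed=0)
print(len(capped), sum(labels[i] for i in capped))  # 50 25
```

Here the 1:1 selection yields 80 samples, which exceeds the cap of 100 * 0.5 = 50, so both groups are scaled by 50 / 80 down to 25 each.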

pyhealth.datasets.splitter.split_by_visit(dataset, ratios, seed=None)

Splits the dataset by visit (i.e., samples).

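Parameters:
  • dataset (SampleDataset) – the dataset to split.

  • ratios (Union[Tuple[float, float, float], List[float]]) – a list/tuple of ratios for train / val / test.

  • seed (Optional[int]) – random seed for shuffling the dataset.

Returns:

train_dataset, val_dataset, test_dataset: three subsets of the dataset of type torch.utils.data.Subset.

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, and test_dataset.dataset.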
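To make the ratio semantics concrete, a minimal sketch of a visit-level (sample-level) random split on plain indices. The function name random_index_split and the floor-based boundary rounding are assumptions, not necessarily what PyHealth does:

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def random_index_split(n: int, ratios: Sequence[float],
                       seed: Optional[int] = None) -> Tuple[List[int], List[int], List[int]]:
    assert len(ratios) == 3 and math.isclose(sum(ratios), 1.0)
    index = list(range(n))
    random.Random(seed).shuffle(index)  # same seed -> same permutation
    train_end = int(n * ratios[0])
    val_end = int(n * (ratios[0] + ratios[1]))
    return index[:train_end], index[train_end:val_end], index[val_end:]


train_idx, val_idx, test_idx = random_index_split(8, [0.5, 0.25, 0.25], seed=42)
print(len(train_idx), len(val_idx), len(test_idx))  # 4 2 2
```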

pyhealth.datasets.splitter.split_by_patient(dataset, ratios, seed=None)

Splits the dataset by patient.

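Parameters:
  • dataset (SampleDataset) – the dataset to split; it must have a patient_to_index mapping.

  • ratios (Union[Tuple[float, float, float], List[float]]) – a list/tuple of ratios for train / val / test.

  • seed (Optional[int]) – random seed for shuffling the patients.

Returns:

train_dataset, val_dataset, test_dataset: three subsets of the dataset of type torch.utils.data.Subset.

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, and test_dataset.dataset.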
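The difference from a visit-level split is that whole patients are assigned to a split before their samples are gathered, so the ratios apply to patient counts rather than sample counts. A minimal sketch over a plain patient_to_index dictionary (made-up data; the helper name and boundary rounding are assumptions):

```python
import random
from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple


def patient_index_split(patient_to_index: Dict[str, List[int]], ratios: Sequence[float],
                        seed: Optional[int] = None) -> Tuple[List[int], ...]:
    patients = sorted(patient_to_index)  # fixed order before shuffling
    random.Random(seed).shuffle(patients)
    n = len(patients)
    train_end = int(n * ratios[0])
    val_end = int(n * (ratios[0] + ratios[1]))
    groups = (patients[:train_end], patients[train_end:val_end], patients[val_end:])
    # Expand each group of patients into that group's sample indices.
    return tuple(list(chain.from_iterable(patient_to_index[p] for p in g)) for g in groups)


p2i = {"p1": [0, 1], "p2": [2], "p3": [3, 4, 5], "p4": [6, 7]}
train_idx, val_idx, test_idx = patient_index_split(p2i, [0.5, 0.25, 0.25], seed=0)
print(train_idx, val_idx, test_idx)
```

With four patients and ratios [0.5, 0.25, 0.25], two patients go to train and one each to val and test, but the sample counts per split depend on which patients are drawn.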

pyhealth.datasets.splitter.split_by_sample(dataset, ratios, seed=None, get_index=False)

Splits the dataset by sample.

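Parameters:
  • dataset (SampleDataset) – the dataset to split.

  • ratios (Union[Tuple[float, float, float], List[float]]) – a list/tuple of ratios for train / val / test.

  • seed (Optional[int]) – random seed for shuffling the dataset.

  • get_index (Optional[bool]) – if True, return the train / val / test indices instead of Subset objects.

Returns:

train_dataset, val_dataset, test_dataset: three subsets of the dataset of type torch.utils.data.Subset, or the three index vectors if get_index=True.

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, and test_dataset.dataset.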
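The Note above relies on torch.utils.data.Subset keeping a reference to its parent in .dataset alongside the selected positions in .indices; with get_index=True the function presumably returns just those positions. A dependency-free sketch using a hypothetical IndexSubset class that mimics only those two attributes:

```python
from typing import List, Sequence


class IndexSubset:
    """Hypothetical stand-in exposing the .dataset / .indices attributes of a Subset."""

    def __init__(self, dataset: Sequence, indices: List[int]):
        self.dataset = dataset  # reference to the original, unsplit dataset
        self.indices = indices  # positions of this subset's samples within it

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        # Position i in the subset maps to position indices[i] in the parent.
        return self.dataset[self.indices[i]]


data = ["s0", "s1", "s2", "s3", "s4", "s5"]
train = IndexSubset(data, [5, 0, 3])
print(train.dataset is data, len(train), train[0])  # True 3 s5
```

No samples are copied: the subset is a view, so every split shares the same parent object.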

pyhealth.datasets.splitter.split_by_visit_conformal(dataset, ratios, seed=None)

Splits the dataset by visit (i.e., samples) for conformal prediction.

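Parameters:
  • dataset (SampleDataset) – the dataset to split.

  • ratios (Union[Tuple[float, float, float, float], List[float]]) – a list/tuple of ratios for train / val / cal / test.

  • seed (Optional[int]) – random seed for shuffling the dataset.

Returns:

train_dataset, val_dataset, cal_dataset, test_dataset: four subsets of the dataset of type torch.utils.data.Subset.

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, cal_dataset.dataset, and test_dataset.dataset.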
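With four ratios, the cut points are the cumulative sums of all ratios except the last. A minimal sketch with a hypothetical k_way_split helper; floor-based rounding is assumed, which sends any rounding remainder to the final (test) split:

```python
import random
from itertools import accumulate
from typing import List, Optional, Sequence


def k_way_split(n: int, ratios: Sequence[float], seed: Optional[int] = None) -> List[List[int]]:
    index = list(range(n))
    random.Random(seed).shuffle(index)
    # Cumulative ratios give the end position of every split but the last.
    cuts = [int(n * c) for c in accumulate(ratios[:-1])]
    starts = [0] + cuts
    ends = cuts + [n]
    return [index[s:e] for s, e in zip(starts, ends)]


train, val, cal, test = k_way_split(20, [0.5, 0.25, 0.125, 0.125], seed=0)
print([len(s) for s in (train, val, cal, test)])  # [10, 5, 2, 3]
```

Here 20 * 0.875 = 17.5 is floored to 17, so cal gets 2 samples and test absorbs the extra one.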

pyhealth.datasets.splitter.split_by_patient_conformal(dataset, ratios, seed=None)

Splits the dataset by patient for conformal prediction.

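Parameters:
  • dataset (SampleDataset) – the dataset to split; it must have a patient_to_index mapping.

  • ratios (Union[Tuple[float, float, float, float], List[float]]) – a list/tuple of ratios for train / val / cal / test.

  • seed (Optional[int]) – random seed for shuffling the patients.

Returns:

train_dataset, val_dataset, cal_dataset, test_dataset: four subsets of the dataset of type torch.utils.data.Subset.

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, cal_dataset.dataset, and test_dataset.dataset.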
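Whichever patient-level splitter is used, the property worth checking is that no patient contributes samples to more than one split. A small hypothetical helper that verifies this from a patient_to_index mapping and the returned index lists (made-up data):

```python
from typing import Dict, List, Sequence, Set


def patients_per_split(patient_to_index: Dict[str, List[int]],
                       splits: Sequence[Sequence[int]]) -> List[Set[str]]:
    # Invert patient -> samples into sample -> patient.
    sample_to_patient = {i: p for p, idxs in patient_to_index.items() for i in idxs}
    return [{sample_to_patient[i] for i in split} for split in splits]


def has_patient_leakage(patient_to_index: Dict[str, List[int]],
                        splits: Sequence[Sequence[int]]) -> bool:
    seen: Set[str] = set()
    for group in patients_per_split(patient_to_index, splits):
        if seen & group:  # a patient already placed in an earlier split
            return True
        seen |= group
    return False


p2i = {"p1": [0, 1], "p2": [2, 3], "p3": [4], "p4": [5, 6]}
clean = [[0, 1], [2, 3], [4], [5, 6]]  # each patient in exactly one split
leaky = [[0, 2], [1, 3], [4], [5, 6]]  # p1 and p2 straddle train and val
print(has_patient_leakage(p2i, clean), has_patient_leakage(p2i, leaky))  # False True
```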

pyhealth.datasets.splitter.split_by_patient_conformal_tuh(dataset, ratios, seed=None, get_index=False)

Splits a TUH EEG dataset by PATIENT within the TUH train partition.

Unlike split_by_sample_conformal_tuh(), which shuffles windows independently, this function ensures that all windows from a given patient stay in the same split (train, val, or cal). This prevents within-patient data leakage between training and calibration. Avoiding such leakage is critical for honest evaluation of conformal prediction methods, especially NCP, which relies on nearest-neighbor search in embedding space.

The TUH corpus is designed so that each patient appears exclusively in either the training partition (265 patients for TUSZ) or the evaluation partition (50 patients), never both. The test set is therefore always the TUH eval partition (fixed, fully held-out patients).

Parameters:
  • dataset (SampleDataset) – a SampleDataset produced by EEGEventsTUEV or EEGAbnormalTUAB. Each sample must carry a "split" field ("train" or "eval") and the dataset must have a patient_to_index mapping.

  • ratios (Union[Tuple[float, float, float], List[float]]) – fraction of patients in the TUH train partition to assign to train / val / cal respectively. Must be a length-3 sequence summing to 1.0.

  • seed (Optional[int]) – random seed used to shuffle the patient list.

  • get_index (Optional[bool]) – if True, return four torch.Tensor index vectors instead of Subset objects.

Returns:

(train_dataset, val_dataset, cal_dataset, test_dataset)
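A minimal sketch of the procedure described above over plain Python data: samples whose "split" field is "eval" form the test set unchanged, and only patients from the "train" partition are shuffled and divided into train / val / cal. The sample dicts, the "patient_id" key, the helper name tuh_patient_split, and the rounding are assumptions for illustration:

```python
import math
import random
from typing import Dict, List, Optional, Sequence


def tuh_patient_split(samples: List[dict], ratios: Sequence[float],
                      seed: Optional[int] = None) -> Dict[str, List[int]]:
    assert len(ratios) == 3 and math.isclose(sum(ratios), 1.0)
    # The held-out TUH eval partition becomes the test set as-is.
    test = [i for i, s in enumerate(samples) if s["split"] == "eval"]

    # Group train-partition windows by patient.
    by_patient: Dict[str, List[int]] = {}
    for i, s in enumerate(samples):
        if s["split"] == "train":
            by_patient.setdefault(s["patient_id"], []).append(i)

    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    n = len(patients)
    a, b = int(n * ratios[0]), int(n * (ratios[0] + ratios[1]))

    def windows_of(group: List[str]) -> List[int]:
        return sorted(i for p in group for i in by_patient[p])

    return {"train": windows_of(patients[:a]), "val": windows_of(patients[a:b]),
            "cal": windows_of(patients[b:]), "test": test}


samples = ([{"patient_id": f"tr{k}", "split": "train"} for k in range(8) for _ in range(2)]
           + [{"patient_id": f"ev{k}", "split": "eval"} for k in range(2)])
out = tuh_patient_split(samples, [0.5, 0.25, 0.25], seed=0)
print({k: len(v) for k, v in out.items()})  # {'train': 8, 'val': 4, 'cal': 4, 'test': 2}
```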

pyhealth.datasets.splitter.split_by_sample_conformal_tuh(dataset, ratios, seed=None, get_index=False)

Splits a TUH EEG dataset (TUEV/TUAB) using its pre-defined train/eval split.

Parameters:
  • dataset (SampleDataset) – a SampleDataset object produced by EEGEventsTUEV or EEGAbnormalTUAB

  • ratios (Union[Tuple[float, float, float], List[float]]) – the fraction of the train pool assigned to train / val / cal respectively

  • seed (Optional[int]) – random seed for shuffling the train pool

  • get_index (Optional[bool]) – if True, return four torch.Tensor index vectors instead of Subset objects

Returns:

train_dataset, val_dataset, cal_dataset, test_dataset
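Because this function shuffles windows independently within the train pool, windows from one patient can land in both train and cal. A small simulation on synthetic data makes the effect visible; the patient and window counts are arbitrary, and a 60/20/20 window-level split stands in for the real ratios:

```python
import random

# Synthetic train pool: 20 patients x 10 windows each (arbitrary sizes).
window_patient = [p for p in range(20) for _ in range(10)]

order = list(range(len(window_patient)))
random.Random(0).shuffle(order)

# Window-level 60/20/20 split of the pool into train / val / cal.
n = len(order)
train = order[: n * 3 // 5]
cal = order[n * 4 // 5:]

train_patients = {window_patient[i] for i in train}
cal_patients = {window_patient[i] for i in cal}
shared = train_patients & cal_patients
print(f"{len(shared)} of 20 patients appear in both train and cal")
```

With 10 windows per patient, most patients are expected to be shared, which is the leakage the patient-level variant avoids.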

pyhealth.datasets.splitter.split_by_patient_tuh(dataset, ratios, seed=None, get_index=False)

Splits a TUH EEG dataset by PATIENT within the TUH train partition.

Like split_by_patient_conformal_tuh() but returns a 3-way split (train / val / test) instead of 4-way (train / val / cal / test). The test set is always the TUH eval partition; no calibration set is produced.

Ensures that all windows from a given patient stay in the same split (train or val), preventing within-patient data leakage.

Parameters:
  • dataset (SampleDataset) – a SampleDataset produced by EEGEventsTUEV or EEGAbnormalTUAB. Each sample must carry a "split" field ("train" or "eval") and the dataset must have a patient_to_index mapping.

  • ratios (Union[Tuple[float, float], List[float]]) – fraction of patients in the TUH train partition to assign to train / val respectively. Must be a length-2 sequence summing to 1.0.

  • seed (Optional[int]) – random seed used to shuffle the patient list.

  • get_index (Optional[bool]) – if True, return three torch.Tensor index vectors instead of Subset objects.

Returns:

(train_dataset, val_dataset, test_dataset)
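The TUH patient-level splitters state a strict contract on ratios: a length-2 sequence here (length 3 for the conformal variant) summing to 1.0. A hypothetical validator in that spirit; the tolerance in math.isclose is an assumption, chosen because float ratios may not sum to exactly 1.0:

```python
import math
from typing import Sequence


def check_ratios(ratios: Sequence[float], expected_len: int) -> None:
    """Hypothetical validator for the documented ratios contract."""
    if len(ratios) != expected_len:
        raise ValueError(f"expected {expected_len} ratios, got {len(ratios)}")
    if not math.isclose(sum(ratios), 1.0, abs_tol=1e-9):
        raise ValueError(f"ratios must sum to 1.0, got {sum(ratios)}")


check_ratios([0.8, 0.2], expected_len=2)  # passes silently
try:
    check_ratios([0.7, 0.2, 0.1], expected_len=2)  # wrong length for a 3-way split
except ValueError as err:
    print(err)  # expected 2 ratios, got 3
```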

pyhealth.datasets.splitter.split_by_sample_tuh(dataset, ratios, seed=None, get_index=False)

Splits a TUH EEG dataset (TUEV/TUAB) using its pre-defined split.

Like split_by_sample_conformal_tuh() but returns a 3-way split (train / val / test) instead of 4-way (train / val / cal / test). The test set is always the TUH eval partition; no calibration set is produced.

Parameters:
  • dataset (SampleDataset) – a SampleDataset object produced by EEGEventsTUEV or EEGAbnormalTUAB

  • ratios (Union[Tuple[float, float], List[float]]) – the fraction of the train pool assigned to train / val respectively. Must be a length-2 sequence summing to 1.0.

  • seed (Optional[int]) – random seed for shuffling the train pool

  • get_index (Optional[bool]) – if True, return three torch.Tensor index vectors instead of Subset objects

Returns:

train_dataset, val_dataset, test_dataset

pyhealth.datasets.splitter.split_by_sample_conformal(dataset, ratios, seed=None, get_index=False)

Splits the dataset by sample for conformal prediction.

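Parameters:
  • dataset (SampleDataset) – the dataset to split.

  • ratios (Union[Tuple[float, float, float, float], List[float]]) – a list/tuple of ratios for train / val / cal / test.

  • seed (Optional[int]) – random seed for shuffling the dataset.

  • get_index (Optional[bool]) – if True, return four torch.Tensor index vectors instead of Subset objects.

Returns:

train_dataset, val_dataset, cal_dataset, test_dataset: four subsets of the dataset of type torch.utils.data.Subset, or four tensors of indices if get_index=True.

Note

The original dataset can be accessed by train_dataset.dataset, val_dataset.dataset, cal_dataset.dataset, and test_dataset.dataset.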