Source code for pyhealth.datasets.splitter

from itertools import chain
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from .sample_dataset import SampleDataset

# TODO: train_dataset.dataset still access the whole dataset which may leak information
# TODO: add more splitting methods


def _label_to_int(label) -> int:
    """Return ``label`` as a built-in ``int``.

    Torch tensors are unwrapped with ``.item()`` first; any other value
    (Python int, NumPy scalar, ...) is passed straight to ``int()``.
    """
    raw_value = label.item() if torch.is_tensor(label) else label
    return int(raw_value)


def sample_balanced(
    dataset: SampleDataset,
    ratio: float = 1.0,
    subsample: float = 1.0,
    seed: Optional[int] = None,
) -> SampleDataset:
    """Rebalance negatives against positives, subject to a total-size cap.

    A sample is "positive" when its ``"label"`` converts to the integer 1;
    every other sample is treated as negative.

    Args:
        dataset: Dataset supporting ``len()``, integer indexing (samples are
            dict-like with a ``"label"`` key) and ``subset()``.
        ratio: Target number of negatives per positive (e.g. 1.0 -> about one
            negative per positive). ``0`` keeps positives only.
        subsample: Maximum fraction of the original dataset size to retain.
            When the ratio-selected set would exceed
            ``len(dataset) * subsample``, positives and negatives are both
            reduced while keeping the negative/positive ratio as close as
            possible.
        seed: Optional seed for the NumPy generator used for sampling.

    Returns:
        The subset returned by ``dataset.subset`` with ``patient_to_index``
        and ``record_to_index`` rebuilt for the retained samples. If the
        dataset contains no positives, ``dataset`` itself is returned
        unchanged.

    Raises:
        ValueError: If ``ratio`` is negative or ``subsample`` is outside
            ``(0, 1]``.
    """
    if ratio < 0:
        raise ValueError("ratio must be non-negative")
    if subsample <= 0 or subsample > 1:
        raise ValueError("subsample must be in (0, 1]")

    rng = np.random.default_rng(seed)

    # Read every label once, then bucket sample positions by class.
    labels = [_label_to_int(dataset[i]["label"]) for i in range(len(dataset))]
    positives: List[int] = [i for i, y in enumerate(labels) if y == 1]
    negatives: List[int] = [i for i, y in enumerate(labels) if y != 1]

    if len(positives) == 0:
        return dataset

    n_pos = len(positives)
    n_neg = min(len(negatives), int(round(n_pos * ratio)))
    max_total = max(1, int(len(dataset) * subsample))

    # Start from the ratio-driven counts and shrink only when over the cap.
    take_pos, take_neg = n_pos, n_neg
    if n_pos + n_neg > max_total:
        neg_per_pos = n_neg / n_pos if n_pos > 0 else 0.0
        take_pos = max(1, min(n_pos, int(max_total / (1 + neg_per_pos))))
        take_neg = int(round(take_pos * neg_per_pos)) if neg_per_pos > 0 else 0
        take_neg = min(take_neg, len(negatives))
        if take_pos + take_neg > max_total:
            # Rounding pushed us over the cap; trim negatives to fit.
            take_neg = max(0, max_total - take_pos)

    # Positives are drawn before negatives so the RNG stream order is fixed.
    chosen_pos = positives
    if take_pos < n_pos:
        chosen_pos = list(rng.choice(positives, size=take_pos, replace=False))

    chosen_neg = []
    if take_neg > 0:
        chosen_neg = list(rng.choice(negatives, size=take_neg, replace=False))

    balanced = dataset.subset(chosen_pos + chosen_neg)  # type: ignore

    # Rebuild patient_to_index and record_to_index for the reduced set.
    balanced.patient_to_index = {}
    balanced.record_to_index = {}
    for position in range(len(balanced)):
        entry = balanced[position]
        patient_key = entry.get("patient_id")
        record_key = entry.get("record_id", entry.get("visit_id"))
        if patient_key is not None:
            balanced.patient_to_index.setdefault(patient_key, []).append(position)
        if record_key is not None:
            balanced.record_to_index.setdefault(record_key, []).append(position)

    return balanced
def split_by_visit(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float], List[float]],
    seed: Optional[int] = None,
):
    """Randomly split the dataset by visit (i.e., by sample).

    Args:
        dataset: a `SampleDataset` object; must support ``len()`` and
            ``subset()``.
        ratios: a list/tuple of ratios for train / val / test. Must sum to
            1.0; the check uses a 1e-6 tolerance so that inputs such as
            ``[0.7, 0.1, 0.2]`` (whose floating-point sum is not exactly
            1.0) are accepted.
        seed: random seed for shuffling the dataset

    Returns:
        train_dataset, val_dataset, test_dataset: three subsets of the
        dataset, as returned by ``dataset.subset``.

    Raises:
        AssertionError: if ``ratios`` does not sum to 1.0.

    Note:
        Split boundaries are truncated with ``int()``, so any remainder
        ends up in the test split.
    """
    rng = np.random.default_rng(seed)
    # Tolerance instead of ``== 1.0``: float sums of valid ratios are often
    # off by one ulp.
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    num_samples = len(dataset)
    index = np.arange(num_samples)
    rng.shuffle(index)

    train_end = int(num_samples * ratios[0])
    val_end = int(num_samples * (ratios[0] + ratios[1]))

    train_dataset = dataset.subset(index[:train_end])  # type: ignore
    val_dataset = dataset.subset(index[train_end:val_end])  # type: ignore
    test_dataset = dataset.subset(index[val_end:])  # type: ignore
    return train_dataset, val_dataset, test_dataset
def split_by_patient(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float], List[float]],
    seed: Optional[int] = None,
):
    """Randomly split the dataset by patient.

    Patients (the keys of ``dataset.patient_to_index``) are shuffled and
    partitioned; each split then receives all sample indices of its patients,
    so no patient appears in more than one split.

    Args:
        dataset: a `SampleDataset` object with ``patient_to_index`` populated.
        ratios: a list/tuple of ratios for train / val / test, applied to the
            number of patients. Must sum to 1.0; the check uses a 1e-6
            tolerance so that inputs such as ``[0.7, 0.1, 0.2]`` (whose
            floating-point sum is not exactly 1.0) are accepted.
        seed: random seed for shuffling the patients

    Returns:
        train_dataset, val_dataset, test_dataset: three subsets of the
        dataset, as returned by ``dataset.subset``.

    Raises:
        AssertionError: if ``ratios`` does not sum to 1.0.

    Note:
        Split boundaries are truncated with ``int()``, so any remainder
        ends up in the test split.
    """
    rng = np.random.default_rng(seed)
    # Tolerance instead of ``== 1.0``: float sums of valid ratios are often
    # off by one ulp.
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    patient_indx = list(dataset.patient_to_index.keys())
    num_patients = len(patient_indx)
    rng.shuffle(patient_indx)

    train_end = int(num_patients * ratios[0])
    val_end = int(num_patients * (ratios[0] + ratios[1]))

    def _sample_indices(patients):
        # Flatten the per-patient sample index lists, keeping patient order.
        return list(
            chain.from_iterable(dataset.patient_to_index[p] for p in patients)
        )

    train_index = _sample_indices(patient_indx[:train_end])
    val_index = _sample_indices(patient_indx[train_end:val_end])
    test_index = _sample_indices(patient_indx[val_end:])

    train_dataset = dataset.subset(train_index)  # type: ignore
    val_dataset = dataset.subset(val_index)  # type: ignore
    test_dataset = dataset.subset(test_index)  # type: ignore
    return train_dataset, val_dataset, test_dataset
def split_by_sample(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float], List[float]],
    seed: Optional[int] = None,
    get_index: Optional[bool] = False,
):
    """Randomly split the dataset by sample.

    Args:
        dataset: a `SampleDataset` object; must support ``len()`` and
            ``subset()``.
        ratios: a list/tuple of ratios for train / val / test. Must sum to
            1.0; the check uses a 1e-6 tolerance so that inputs such as
            ``[0.7, 0.1, 0.2]`` (whose floating-point sum is not exactly
            1.0) are accepted.
        seed: random seed for shuffling the dataset
        get_index: if True, return the three index vectors as
            ``torch.Tensor`` objects instead of building subsets.

    Returns:
        train_dataset, val_dataset, test_dataset: three subsets of the
        dataset, as returned by ``dataset.subset``; or three index tensors
        when ``get_index`` is True.

    Raises:
        AssertionError: if ``ratios`` does not sum to 1.0.

    Note:
        Split boundaries are truncated with ``int()``, so any remainder
        ends up in the test split.
    """
    rng = np.random.default_rng(seed)
    # Tolerance instead of ``== 1.0``: float sums of valid ratios are often
    # off by one ulp.
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    num_samples = len(dataset)
    index = np.arange(num_samples)
    rng.shuffle(index)

    train_end = int(num_samples * ratios[0])
    val_end = int(num_samples * (ratios[0] + ratios[1]))
    train_index = index[:train_end]
    val_index = index[train_end:val_end]
    test_index = index[val_end:]

    if get_index:
        # Subsets are not built on this path; only the indices are needed.
        return (
            torch.tensor(train_index),
            torch.tensor(val_index),
            torch.tensor(test_index),
        )

    train_dataset = dataset.subset(train_index)  # type: ignore
    val_dataset = dataset.subset(val_index)  # type: ignore
    test_dataset = dataset.subset(test_index)  # type: ignore
    return train_dataset, val_dataset, test_dataset
def split_by_visit_conformal(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float, float], List[float]],
    seed: Optional[int] = None,
):
    """Randomly split the dataset by visit (i.e., by sample) for conformal prediction.

    Args:
        dataset: a `SampleDataset` object; must support ``len()`` and
            ``subset()``.
        ratios: a list/tuple of ratios for train / val / cal / test. Must
            have four elements summing to 1.0; the sum check uses a 1e-6
            tolerance so that valid ratios whose floating-point sum is not
            exactly 1.0 are accepted.
        seed: random seed for shuffling the dataset

    Returns:
        train_dataset, val_dataset, cal_dataset, test_dataset: four subsets
        of the dataset, as returned by ``dataset.subset``.

    Raises:
        AssertionError: if ``ratios`` does not have four elements or does
            not sum to 1.0.

    Note:
        Split boundaries are truncated with ``int()``, so any remainder
        ends up in the test split.
    """
    rng = np.random.default_rng(seed)
    assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test"
    # Tolerance instead of ``== 1.0``: float sums of valid ratios are often
    # off by one ulp.
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    num_samples = len(dataset)
    index = np.arange(num_samples)
    rng.shuffle(index)

    # Calculate split points
    train_end = int(num_samples * ratios[0])
    val_end = int(num_samples * (ratios[0] + ratios[1]))
    cal_end = int(num_samples * (ratios[0] + ratios[1] + ratios[2]))

    train_dataset = dataset.subset(index[:train_end])  # type: ignore
    val_dataset = dataset.subset(index[train_end:val_end])  # type: ignore
    cal_dataset = dataset.subset(index[val_end:cal_end])  # type: ignore
    test_dataset = dataset.subset(index[cal_end:])  # type: ignore
    return train_dataset, val_dataset, cal_dataset, test_dataset
def split_by_patient_conformal(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float, float], List[float]],
    seed: Optional[int] = None,
):
    """Randomly split the dataset by patient for conformal prediction.

    Patients (the keys of ``dataset.patient_to_index``) are shuffled and
    partitioned; each split then receives all sample indices of its patients,
    so no patient appears in more than one split.

    Args:
        dataset: a `SampleDataset` object with ``patient_to_index`` populated.
        ratios: a list/tuple of ratios for train / val / cal / test, applied
            to the number of patients. Must have four elements summing to
            1.0; the sum check uses a 1e-6 tolerance so that valid ratios
            whose floating-point sum is not exactly 1.0 are accepted.
        seed: random seed for shuffling the patients

    Returns:
        train_dataset, val_dataset, cal_dataset, test_dataset: four subsets
        of the dataset, as returned by ``dataset.subset``.

    Raises:
        AssertionError: if ``ratios`` does not have four elements or does
            not sum to 1.0.

    Note:
        Split boundaries are truncated with ``int()``, so any remainder
        ends up in the test split.
    """
    rng = np.random.default_rng(seed)
    assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test"
    # Tolerance instead of ``== 1.0``: float sums of valid ratios are often
    # off by one ulp.
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    patient_indx = list(dataset.patient_to_index.keys())
    num_patients = len(patient_indx)
    rng.shuffle(patient_indx)

    # Calculate split points
    train_end = int(num_patients * ratios[0])
    val_end = int(num_patients * (ratios[0] + ratios[1]))
    cal_end = int(num_patients * (ratios[0] + ratios[1] + ratios[2]))

    def _sample_indices(patients):
        # Flatten the per-patient sample index lists, keeping patient order.
        return list(
            chain.from_iterable(dataset.patient_to_index[p] for p in patients)
        )

    train_index = _sample_indices(patient_indx[:train_end])
    val_index = _sample_indices(patient_indx[train_end:val_end])
    cal_index = _sample_indices(patient_indx[val_end:cal_end])
    test_index = _sample_indices(patient_indx[cal_end:])

    train_dataset = dataset.subset(train_index)  # type: ignore
    val_dataset = dataset.subset(val_index)  # type: ignore
    cal_dataset = dataset.subset(cal_index)  # type: ignore
    test_dataset = dataset.subset(test_index)  # type: ignore
    return train_dataset, val_dataset, cal_dataset, test_dataset
def split_by_patient_conformal_tuh(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float], List[float]],
    seed: Optional[int] = None,
    get_index: Optional[bool] = False,
):
    """Splits a TUH EEG dataset by PATIENT within the TUH train partition.

    Unlike :func:`split_by_sample_conformal_tuh`, which shuffles samples
    independently, this function keeps **all samples of a given patient in
    the same split** (train, val, or cal), preventing within-patient leakage
    between training and calibration.

    Patients whose first sample has ``split == "train"`` form the pool that
    is shuffled and divided by ``ratios``; the samples of every other
    patient make up the test set.

    Args:
        dataset: a ``SampleDataset`` produced by ``EEGEventsTUEV`` or
            ``EEGAbnormalTUAB``. Each sample must carry a ``"split"`` field
            and the dataset must have a ``patient_to_index`` mapping.
        ratios: fraction of *patients* in the train partition to assign to
            train / val / cal respectively. Must be a length-3 sequence
            summing to 1.0.
        seed: random seed used to shuffle the patient list.
        get_index: if ``True``, return four :class:`torch.Tensor` index
            vectors instead of subsets.

    Returns:
        ``(train_dataset, val_dataset, cal_dataset, test_dataset)`` built
        with ``dataset.subset``, or four int64 index tensors when
        ``get_index`` is ``True``.

    Raises:
        AssertionError: if ``ratios`` is not length 3, does not sum to 1.0,
            or a patient's first sample lacks the ``"split"`` field.
    """
    assert len(ratios) == 3, (
        "ratios must have exactly 3 elements (train/val/cal). "
        "The test set is determined by the TUH eval partition."
    )
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    # Bucket patients by partition using dataset.patient_to_index (fast path).
    # Assumes each patient belongs to exactly one partition, so only the
    # first sample per patient is inspected.
    train_patient_to_indices: dict = {}
    test_list: List[int] = []
    for pid, indices in dataset.patient_to_index.items():
        first_sample = dataset[indices[0]]
        assert "split" in first_sample, (
            f"Patient {pid}: sample missing 'split' field. "
            "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
        )
        if first_sample["split"] == "train":
            train_patient_to_indices[pid] = list(indices)
        else:
            test_list.extend(list(indices))

    # Shuffle patients deterministically
    patient_ids = list(train_patient_to_indices.keys())
    rng = np.random.default_rng(seed)
    rng.shuffle(patient_ids)

    n = len(patient_ids)
    train_end = int(n * ratios[0])
    val_end = int(n * (ratios[0] + ratios[1]))

    def _indices_for(pids) -> np.ndarray:
        # Explicit int64: ``np.array([])`` would otherwise be float64, which
        # is unusable as an index vector when a split is empty.
        flat = list(chain.from_iterable(train_patient_to_indices[p] for p in pids))
        return np.array(flat, dtype=np.int64)

    train_index = _indices_for(patient_ids[:train_end])
    val_index = _indices_for(patient_ids[train_end:val_end])
    cal_index = _indices_for(patient_ids[val_end:])
    test_index = np.array(test_list, dtype=np.int64)

    if get_index:
        return (
            torch.tensor(train_index),
            torch.tensor(val_index),
            torch.tensor(cal_index),
            torch.tensor(test_index),
        )
    return (
        dataset.subset(train_index),  # type: ignore
        dataset.subset(val_index),  # type: ignore
        dataset.subset(cal_index),  # type: ignore
        dataset.subset(test_index),  # type: ignore
    )
def split_by_sample_conformal_tuh(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float], List[float]],
    seed: Optional[int] = None,
    get_index: Optional[bool] = False,
):
    """Splits a TUH EEG dataset (TUEV/TUAB) using its pre-defined train/eval split.

    Samples with ``split == "train"`` form a pool that is shuffled and
    divided into train / val / cal by ``ratios``; all remaining samples make
    up the test set.

    Args:
        dataset: a ``SampleDataset`` object produced by ``EEGEventsTUEV`` or
            ``EEGAbnormalTUAB``; every sample must carry a ``"split"`` field.
        ratios: the fraction of the train pool assigned to train / val / cal
            respectively. Must be a length-3 sequence summing to 1.0.
        seed: random seed for shuffling the train pool
        get_index: if True, return four ``torch.Tensor`` index vectors
            instead of subsets.

    Returns:
        train_dataset, val_dataset, cal_dataset, test_dataset built with
        ``dataset.subset``, or four int64 index tensors when ``get_index``
        is True.

    Raises:
        AssertionError: if ``ratios`` is not length 3, does not sum to 1.0,
            or any sample lacks the ``"split"`` field.
    """
    assert len(ratios) == 3, (
        "ratios must have exactly 3 elements (train/val/cal). "
        "The test set is determined by the dataset's own eval partition."
    )
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    # Validate and bucket in a single pass so each sample is loaded once.
    train_pool: List[int] = []
    test_list: List[int] = []
    for i in range(len(dataset)):
        sample = dataset[i]
        assert "split" in sample, (
            f"Sample {i} is missing the 'split' field. "
            "Make sure you used EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
        )
        if sample["split"] == "train":
            train_pool.append(i)
        else:
            test_list.append(i)

    # shuffle only the train pool
    rng = np.random.default_rng(seed)
    # Explicit int64: ``np.array([])`` would otherwise be float64, which is
    # unusable as an index vector when a partition is empty.
    train_arr = np.array(train_pool, dtype=np.int64)
    rng.shuffle(train_arr)

    # Slice into train / val / cal.
    n = len(train_arr)
    train_end = int(n * ratios[0])
    val_end = int(n * (ratios[0] + ratios[1]))
    train_index = train_arr[:train_end]
    val_index = train_arr[train_end:val_end]
    cal_index = train_arr[val_end:]
    test_index = np.array(test_list, dtype=np.int64)

    if get_index:
        return (
            torch.tensor(train_index),
            torch.tensor(val_index),
            torch.tensor(cal_index),
            torch.tensor(test_index),
        )
    return (
        dataset.subset(train_index),  # type: ignore
        dataset.subset(val_index),  # type: ignore
        dataset.subset(cal_index),  # type: ignore
        dataset.subset(test_index),  # type: ignore
    )
def split_by_patient_tuh(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float], List[float]],
    seed: Optional[int] = None,
    get_index: Optional[bool] = False,
):
    """Splits a TUH EEG dataset by PATIENT within the TUH train partition.

    Like :func:`split_by_patient_conformal_tuh` but returns a 3-way split
    (train / val / test) instead of 4-way (train / val / cal / test); no
    calibration set is produced.

    Patients whose first sample has ``split == "train"`` are shuffled and
    divided into train / val, with **all samples of a given patient kept in
    the same split**. The samples of every other patient make up the test
    set.

    Args:
        dataset: a ``SampleDataset`` produced by ``EEGEventsTUEV`` or
            ``EEGAbnormalTUAB``. Each sample must carry a ``"split"`` field
            and the dataset must have a ``patient_to_index`` mapping.
        ratios: fraction of *patients* in the train partition to assign to
            train / val respectively. Must be a length-2 sequence summing
            to 1.0.
        seed: random seed used to shuffle the patient list.
        get_index: if ``True``, return three :class:`torch.Tensor` index
            vectors instead of subsets.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` built with
        ``dataset.subset``, or three int64 index tensors when ``get_index``
        is ``True``.

    Raises:
        AssertionError: if ``ratios`` is not length 2, does not sum to 1.0,
            or a patient's first sample lacks the ``"split"`` field.
    """
    assert len(ratios) == 2, (
        "ratios must have exactly 2 elements (train/val). "
        "The test set is determined by the TUH eval partition."
    )
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    # Assumes each patient belongs to exactly one partition, so only the
    # first sample per patient is inspected.
    train_patient_to_indices: dict = {}
    test_list: List[int] = []
    for pid, indices in dataset.patient_to_index.items():
        first_sample = dataset[indices[0]]
        assert "split" in first_sample, (
            f"Patient {pid}: sample missing 'split' field. "
            "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
        )
        if first_sample["split"] == "train":
            train_patient_to_indices[pid] = list(indices)
        else:
            test_list.extend(list(indices))

    patient_ids = list(train_patient_to_indices.keys())
    rng = np.random.default_rng(seed)
    rng.shuffle(patient_ids)

    n = len(patient_ids)
    train_end = int(n * ratios[0])

    def _indices_for(pids) -> np.ndarray:
        # Explicit int64: ``np.array([])`` would otherwise be float64, which
        # is unusable as an index vector when a split is empty.
        flat = list(chain.from_iterable(train_patient_to_indices[p] for p in pids))
        return np.array(flat, dtype=np.int64)

    train_index = _indices_for(patient_ids[:train_end])
    val_index = _indices_for(patient_ids[train_end:])
    test_index = np.array(test_list, dtype=np.int64)

    if get_index:
        return (
            torch.tensor(train_index),
            torch.tensor(val_index),
            torch.tensor(test_index),
        )
    return (
        dataset.subset(train_index),  # type: ignore
        dataset.subset(val_index),  # type: ignore
        dataset.subset(test_index),  # type: ignore
    )
def split_by_sample_tuh(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float], List[float]],
    seed: Optional[int] = None,
    get_index: Optional[bool] = False,
):
    """Splits a TUH EEG dataset (TUEV/TUAB) using its pre-defined split.

    Like :func:`split_by_sample_conformal_tuh` but returns a 3-way split
    (train / val / test) instead of 4-way (train / val / cal / test); no
    calibration set is produced.

    Samples with ``split == "train"`` form a pool that is shuffled and
    divided into train / val by ``ratios``; all remaining samples make up
    the test set.

    Args:
        dataset: a ``SampleDataset`` object produced by ``EEGEventsTUEV`` or
            ``EEGAbnormalTUAB``; every sample must carry a ``"split"`` field.
        ratios: the fraction of the train pool assigned to train / val
            respectively. Must be a length-2 sequence summing to 1.0.
        seed: random seed for shuffling the train pool
        get_index: if True, return three ``torch.Tensor`` index vectors
            instead of subsets.

    Returns:
        train_dataset, val_dataset, test_dataset built with
        ``dataset.subset``, or three int64 index tensors when ``get_index``
        is True.

    Raises:
        AssertionError: if ``ratios`` is not length 2, does not sum to 1.0,
            or any sample lacks the ``"split"`` field.
    """
    assert len(ratios) == 2, (
        "ratios must have exactly 2 elements (train/val). "
        "The test set is determined by the dataset's own eval partition."
    )
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    # Validate and bucket in a single pass so each sample is loaded once.
    train_pool: List[int] = []
    test_list: List[int] = []
    for i in range(len(dataset)):
        sample = dataset[i]
        assert "split" in sample, (
            f"Sample {i} is missing the 'split' field. "
            "Use EEGEventsTUEV or EEGAbnormalTUAB to build the dataset."
        )
        if sample["split"] == "train":
            train_pool.append(i)
        else:
            test_list.append(i)

    rng = np.random.default_rng(seed)
    # Explicit int64: ``np.array([])`` would otherwise be float64, which is
    # unusable as an index vector when a partition is empty.
    train_arr = np.array(train_pool, dtype=np.int64)
    rng.shuffle(train_arr)

    n = len(train_arr)
    train_end = int(n * ratios[0])
    train_index = train_arr[:train_end]
    val_index = train_arr[train_end:]
    test_index = np.array(test_list, dtype=np.int64)

    if get_index:
        return (
            torch.tensor(train_index),
            torch.tensor(val_index),
            torch.tensor(test_index),
        )
    return (
        dataset.subset(train_index),  # type: ignore
        dataset.subset(val_index),  # type: ignore
        dataset.subset(test_index),  # type: ignore
    )
def split_by_sample_conformal(
    dataset: SampleDataset,
    ratios: Union[Tuple[float, float, float, float], List[float]],
    seed: Optional[int] = None,
    get_index: Optional[bool] = False,
):
    """Randomly split the dataset by sample for conformal prediction.

    Args:
        dataset: a `SampleDataset` object; must support ``len()`` and
            ``subset()``.
        ratios: a list/tuple of ratios for train / val / cal / test. Must
            have four elements summing to 1.0; the sum check uses a 1e-6
            tolerance so that valid ratios whose floating-point sum is not
            exactly 1.0 are accepted.
        seed: random seed for shuffling the dataset
        get_index: if True, return indices instead of subsets

    Returns:
        train_dataset, val_dataset, cal_dataset, test_dataset: four subsets
        of the dataset, as returned by ``dataset.subset``, or four tensors
        of indices if get_index=True.

    Raises:
        AssertionError: if ``ratios`` does not have four elements or does
            not sum to 1.0.

    Note:
        Split boundaries are truncated with ``int()``, so any remainder
        ends up in the test split.
    """
    rng = np.random.default_rng(seed)
    assert len(ratios) == 4, "ratios must have 4 elements for train/val/cal/test"
    # Tolerance instead of ``== 1.0``: float sums of valid ratios are often
    # off by one ulp.
    assert abs(sum(ratios) - 1.0) < 1e-6, "ratios must sum to 1.0"

    num_samples = len(dataset)
    index = np.arange(num_samples)
    rng.shuffle(index)

    # Calculate split points
    train_end = int(num_samples * ratios[0])
    val_end = int(num_samples * (ratios[0] + ratios[1]))
    cal_end = int(num_samples * (ratios[0] + ratios[1] + ratios[2]))

    train_index = index[:train_end]
    val_index = index[train_end:val_end]
    cal_index = index[val_end:cal_end]
    test_index = index[cal_end:]

    if get_index:
        return (
            torch.tensor(train_index),
            torch.tensor(val_index),
            torch.tensor(cal_index),
            torch.tensor(test_index),
        )

    train_dataset = dataset.subset(train_index)  # type: ignore
    val_dataset = dataset.subset(val_index)  # type: ignore
    cal_dataset = dataset.subset(cal_index)  # type: ignore
    test_dataset = dataset.subset(test_index)  # type: ignore
    return train_dataset, val_dataset, cal_dataset, test_dataset