pyhealth.datasets.BaseSignalDataset
This is the base signal dataset class. Every specific signal dataset inherits from this class.
- class pyhealth.datasets.BaseSignalDataset(root, dataset_name=None, dev=False, refresh_cache=False, **kwargs)
Bases: ABC
Abstract base Signal dataset class.
This abstract class defines a uniform interface for all EEG datasets (e.g., SleepEDF, SHHS).
Each specific dataset is a subclass of this abstract class and can be converted to a sample dataset for a given task by calling self.set_task().
- Parameters:
root (str) – root directory of the raw data (should contain many csv files).
dev (bool) – whether to enable dev mode (only use a small subset of the data). Default is False.
refresh_cache (bool) – whether to refresh the cache; if True, the dataset will be processed from scratch and the cache will be updated. Default is False.
- set_task(task_fn, task_name=None)
Processes the base dataset to generate the task-specific sample dataset.
This function should be called by the user after the base dataset is initialized. It iterates through all patients in the base dataset and calls task_fn on each one; task_fn is implemented for the specific task.
- Parameters:
task_fn (Callable) – a function that takes a single patient and returns a list of samples (each sample is a dict with patient_id, visit_id, and other task-specific attributes as keys). The samples will be concatenated to form the sample dataset.
task_name (Optional[str]) – the name of the task. If None, the name of the task function will be used.
- Returns:
the task-specific sample (Base) dataset.
- Return type:
sample_dataset
Note
- In task_fn, a patient may be converted to multiple samples, e.g., a patient with three visits may be converted to three samples ([visit 1], [visit 1, visit 2], [visit 1, visit 2, visit 3]). Patients can also be excluded from the task dataset by returning an empty list.
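The task_fn contract described above can be sketched in plain Python, without importing PyHealth. This is an illustration only: the per-patient structure (a list of visit dicts), the helper apply_task, the sample key history, and the "fewer than two visits" exclusion rule are all assumptions made for the example, not PyHealth's actual implementation. The real patient structure depends on the concrete dataset subclass.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical stand-in for one patient's data: a list of visit records.
Patient = List[Dict]


def cumulative_visits_task(patient: Patient) -> List[Dict]:
    """Return one sample per visit, each carrying the visit history so far."""
    # Assumed exclusion rule: drop patients with fewer than two visits
    # by returning an empty list.
    if len(patient) < 2:
        return []
    samples = []
    for i, visit in enumerate(patient):
        samples.append(
            {
                "patient_id": visit["patient_id"],
                "visit_id": visit["visit_id"],
                # Visits up to and including the current one.
                "history": [v["visit_id"] for v in patient[: i + 1]],
            }
        )
    return samples


def apply_task(
    patients: Dict[str, Patient],
    task_fn: Callable[[Patient], List[Dict]],
    task_name: Optional[str] = None,
):
    """Hypothetical helper mimicking the loop that set_task is described as running."""
    # Fall back to the function's own name when no task name is given.
    name = task_name if task_name is not None else task_fn.__name__
    samples: List[Dict] = []
    for patient in patients.values():
        # Concatenate the samples returned for each patient.
        samples.extend(task_fn(patient))
    return name, samples


patients = {
    "p1": [{"patient_id": "p1", "visit_id": f"v{k}"} for k in (1, 2, 3)],
    "p2": [{"patient_id": "p2", "visit_id": "v1"}],  # excluded: only one visit
}
task_name, samples = apply_task(patients, cumulative_visits_task)
```

With a concrete subclass, the equivalent call would be made on the dataset instance as set_task(cumulative_visits_task), with a task_fn written against that subclass's actual patient records.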