pyhealth.datasets.BaseEHRDataset
This is the basic EHR dataset class. Any specific EHR dataset will inherit from this class.
- class pyhealth.datasets.BaseEHRDataset(root, tables, dataset_name=None, code_mapping=None, dev=False, refresh_cache=False)
Bases: ABC
Abstract base dataset class.
This abstract class defines a uniform interface for all EHR datasets (e.g., MIMIC-III, MIMIC-IV, eICU, OMOP).
Each specific dataset is a subclass of this abstract class and can be converted to a task-specific sample dataset by calling self.set_task().
- Parameters:
root (str) – root directory of the raw data (should contain many csv files).
tables (List[str]) – list of tables to be loaded (e.g., ["DIAGNOSES_ICD", "PROCEDURES_ICD"]). Basic tables will be loaded by default.
code_mapping (Optional[Dict[str, Union[str, Tuple[str, Dict]]]]) – a dictionary containing the code mapping information. The key is a str of the source code vocabulary and the value takes one of two formats:
- a str of the target code vocabulary. E.g., {"NDC": "ATC"}.
- a tuple with two elements. The first element is a str of the target code vocabulary and the second element is a dict with keys "source_kwargs" or "target_kwargs" and values of the corresponding kwargs for the CrossMap.map() method. E.g., {"NDC": ("ATC", {"target_kwargs": {"level": 3}})}.
Default is an empty mapping (None in the signature), which means the original codes will be used.
dev (bool) – whether to enable dev mode (only use a small subset of the data). Default is False.
refresh_cache (bool) – whether to refresh the cache; if True, the dataset will be processed from scratch and the cache will be updated. Default is False.
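The two accepted value formats for code_mapping can be illustrated with a small standalone sketch. The helper normalize_code_mapping below is hypothetical and not part of PyHealth; it only shows how a plain string and a (target, kwargs) tuple could be reduced to one common shape before the kwargs are handed to a mapper such as CrossMap.map().

```python
from typing import Dict, Optional, Tuple, Union

CodeMapping = Dict[str, Union[str, Tuple[str, Dict]]]


def normalize_code_mapping(code_mapping: Optional[CodeMapping]) -> Dict[str, Tuple[str, Dict]]:
    """Hypothetical helper (not a PyHealth function).

    Expands both documented value formats into
    source_vocabulary -> (target_vocabulary, kwargs) entries.
    """
    normalized = {}
    # None is treated like an empty mapping: no codes are translated.
    for source, value in (code_mapping or {}).items():
        if isinstance(value, str):
            # Format 1: only the target vocabulary is given.
            target, kwargs = value, {}
        else:
            # Format 2: (target vocabulary, dict of extra kwargs).
            target, kwargs = value
        normalized[source] = (
            target,
            {
                "source_kwargs": kwargs.get("source_kwargs", {}),
                "target_kwargs": kwargs.get("target_kwargs", {}),
            },
        )
    return normalized


simple = normalize_code_mapping({"NDC": "ATC"})
detailed = normalize_code_mapping({"NDC": ("ATC", {"target_kwargs": {"level": 3}})})
```

Note that both examples are dicts keyed by the source vocabulary; writing {"NDC", "ATC"} with a comma would create a Python set rather than a mapping.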
- parse_tables()
Parses the tables in self.tables and returns a dict of patients.
Will be called in self.__init__() if the cache file does not exist or refresh_cache is True.
This function will first call self.parse_basic_info() to parse the basic patient information, and then call self.parse_[table_name]() to parse the table with name table_name. Both self.parse_basic_info() and self.parse_[table_name]() should be implemented in the subclass.
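The dispatch pattern described above can be sketched without PyHealth itself. Everything in this block is a toy stand-in: ToyBaseDataset and ToyDataset are hypothetical classes, patients are plain dicts, caching is omitted, and lower-casing the table name to find the parser method is an assumption made for this sketch.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyBaseDataset(ABC):
    """Toy stand-in for the base class (hypothetical, no caching)."""

    def __init__(self, tables: List[str]):
        self.tables = tables
        self.patients = self.parse_tables()

    @abstractmethod
    def parse_basic_info(self, patients: Dict) -> Dict:
        """Subclasses fill in basic patient and visit information."""

    def parse_tables(self) -> Dict:
        patients: Dict = {}
        # Step 1: basic patient information always comes first.
        patients = self.parse_basic_info(patients)
        # Step 2: one parser method per requested table, looked up by name.
        for table in self.tables:
            # Assumption: the method name uses the lower-cased table name.
            parser = getattr(self, f"parse_{table.lower()}", None)
            if parser is None:
                raise NotImplementedError(f"No parser for table {table}")
            patients = parser(patients)
        return patients


class ToyDataset(ToyBaseDataset):
    def parse_basic_info(self, patients: Dict) -> Dict:
        patients["p1"] = {"visits": {"v1": {}}}
        return patients

    def parse_diagnoses_icd(self, patients: Dict) -> Dict:
        patients["p1"]["visits"]["v1"]["DIAGNOSES_ICD"] = ["4280"]
        return patients


dataset = ToyDataset(tables=["DIAGNOSES_ICD"])
```

Requesting a table without a matching parse_ method raises NotImplementedError in this sketch, which makes a missing subclass implementation visible at construction time.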
- set_task(task_fn, task_name=None)
Processes the base dataset to generate the task-specific sample dataset.
This function should be called by the user after the base dataset is initialized. It will iterate through all patients in the base dataset and call task_fn which should be implemented by the specific task.
- Parameters:
task_fn (Callable) – a function that takes a single patient and returns a list of samples (each sample is a dict with patient_id, visit_id, and other task-specific attributes as keys). The samples will be concatenated to form the sample dataset.
task_name (Optional[str]) – the name of the task. If None, the name of the task function will be used.
- Returns:
the task-specific sample dataset.
- Return type:
sample_dataset
Note
In task_fn, a patient may be converted to multiple samples, e.g., a patient with three visits may be converted to three samples ([visit 1], [visit 1, visit 2], [visit 1, visit 2, visit 3]). Patients can also be excluded from the task dataset by returning an empty list.
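The rolling-history behaviour described in the note can be sketched as a standalone task function. This is a minimal illustration under stated assumptions: patients are modelled as plain dicts rather than PyHealth patient objects, and rolling_history_task, collect_samples, and the "codes" and "history" keys are hypothetical names chosen for the example.

```python
from typing import Callable, Dict, List


def rolling_history_task(patient: Dict) -> List[Dict]:
    """Hypothetical task_fn: one sample per visit, each carrying all
    visits up to and including the current one."""
    visits = patient["visits"]
    samples = []
    for i, visit in enumerate(visits):
        samples.append(
            {
                "patient_id": patient["patient_id"],
                "visit_id": visit["visit_id"],
                # Cumulative history: [v1], [v1, v2], [v1, v2, v3], ...
                "history": [v["codes"] for v in visits[: i + 1]],
            }
        )
    # A patient without visits yields an empty list and is thereby excluded.
    return samples


def collect_samples(patients: Dict[str, Dict], task_fn: Callable) -> List[Dict]:
    """Toy version of the iteration: call task_fn per patient and
    concatenate the returned samples."""
    samples: List[Dict] = []
    for patient in patients.values():
        samples.extend(task_fn(patient))
    return samples


patients = {
    "p1": {
        "patient_id": "p1",
        "visits": [
            {"visit_id": "v1", "codes": ["4280"]},
            {"visit_id": "v2", "codes": ["25000"]},
            {"visit_id": "v3", "codes": ["4019"]},
        ],
    },
    "p2": {"patient_id": "p2", "visits": []},
}

samples = collect_samples(patients, rolling_history_task)
```

Here the three-visit patient contributes three samples with growing histories, while the patient with no visits contributes none.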