pyhealth.datasets.BaseEHRDataset

This is the base EHR dataset class. Every specific EHR dataset inherits from it.

class pyhealth.datasets.BaseEHRDataset(root, tables, dataset_name=None, code_mapping=None, dev=False, refresh_cache=False)

Bases: ABC

Abstract base dataset class.

This abstract class defines a uniform interface for all EHR datasets (e.g., MIMIC-III, MIMIC-IV, eICU, OMOP).

Each specific dataset is a subclass of this abstract class and can be converted into a sample dataset for a given task by calling self.set_task().

Parameters:
  • dataset_name (Optional[str]) – name of the dataset.

  • root (str) – root directory of the raw data (should contain the dataset's CSV files).

  • tables (List[str]) – list of tables to be loaded (e.g., ["DIAGNOSES_ICD", "PROCEDURES_ICD"]). Basic tables are loaded by default.

  • code_mapping (Optional[Dict[str, Union[str, Tuple[str, Dict]]]]) –

    a dictionary describing the code mapping. Each key is the name of a source code vocabulary, and each value takes one of two formats:

    • a str naming the target code vocabulary. E.g., {"NDC": "ATC"}.

    • a tuple of two elements: a str naming the target code vocabulary, and a dict with the keys "source_kwargs" and/or "target_kwargs" whose values are the keyword arguments passed to the CrossMap.map() method. E.g., {"NDC": ("ATC", {"target_kwargs": {"level": 3}})}.

    Default is None, meaning no mapping is applied and the original codes are used.

  • dev (bool) – whether to enable dev mode (only use a small subset of the data). Default is False.

  • refresh_cache (bool) – whether to refresh the cache; if True, the dataset is processed from scratch and the cache is updated. Default is False.
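The two value formats accepted by code_mapping can be written as ordinary Python dict literals. The sketch below also defines normalize_code_mapping, a hypothetical helper that is not part of PyHealth, which folds both formats into a single (target_vocabulary, kwargs) shape. It only checks the structure; in the library itself the kwargs are forwarded to CrossMap.map(), which this sketch does not call.

```python
from typing import Dict, Optional, Tuple, Union

# The two value formats described above, as valid dict literals.
simple_mapping = {"NDC": "ATC"}
mapping_with_kwargs = {"NDC": ("ATC", {"target_kwargs": {"level": 3}})}

CodeMapping = Dict[str, Union[str, Tuple[str, Dict]]]


def normalize_code_mapping(
    code_mapping: Optional[CodeMapping],
) -> Dict[str, Tuple[str, Dict]]:
    """Hypothetical helper: convert both formats to (target_vocab, kwargs)."""
    if code_mapping is None:
        # No mapping: the original codes are kept.
        return {}
    normalized = {}
    for source, value in code_mapping.items():
        if isinstance(value, str):
            # Format 1: just the target vocabulary name.
            normalized[source] = (value, {})
        elif isinstance(value, tuple) and len(value) == 2:
            # Format 2: (target vocabulary, kwargs for CrossMap.map()).
            target, kwargs = value
            unknown = set(kwargs) - {"source_kwargs", "target_kwargs"}
            if unknown:
                raise ValueError(f"unexpected keys for {source}: {sorted(unknown)}")
            normalized[source] = (target, kwargs)
        else:
            raise TypeError(f"invalid mapping value for {source!r}: {value!r}")
    return normalized


print(normalize_code_mapping(simple_mapping))       # {'NDC': ('ATC', {})}
print(normalize_code_mapping(mapping_with_kwargs))  # {'NDC': ('ATC', {'target_kwargs': {'level': 3}})}
```

Note that {"NDC", "ATC"} (with a comma instead of a colon) would be a set, not a dict, and would be rejected by a dict-based check like this one.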

parse_tables()

Parses the tables in self.tables and returns a dict of patients.

Called in self.__init__() if the cache file does not exist or refresh_cache is True.

This method first calls self.parse_basic_info() to parse the basic patient information, then calls self.parse_[table_name]() for each table in self.tables. Both self.parse_basic_info() and the self.parse_[table_name]() methods must be implemented in the subclass.

Return type:

Dict[str, Patient]

Returns:

A dict mapping patient_id to Patient object.
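The dispatch described above can be sketched with a toy class hierarchy. This is not PyHealth's implementation: ToyBaseDataset, ToyDataset, and the plain-dict patient records are hypothetical stand-ins for BaseEHRDataset, a concrete dataset, and Patient objects, caching is omitted, and deriving the parser name by lowercasing the table name (so DIAGNOSES_ICD maps to parse_diagnoses_icd) is an assumption about the naming convention.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyBaseDataset(ABC):
    """Hypothetical, stripped-down stand-in for BaseEHRDataset.

    Patients are plain dicts here instead of Patient objects.
    """

    def __init__(self, tables: List[str]):
        self.tables = tables
        # Assumption: no cache, so tables are always parsed on init.
        self.patients = self.parse_tables()

    @abstractmethod
    def parse_basic_info(self, patients: Dict[str, dict]) -> Dict[str, dict]:
        ...

    def parse_tables(self) -> Dict[str, dict]:
        # Step 1: basic patient information.
        patients = self.parse_basic_info({})
        # Step 2: one parser per requested table, looked up by name.
        for table in self.tables:
            parser = getattr(self, f"parse_{table.lower()}", None)
            if parser is None:
                raise NotImplementedError(f"Parser for table {table} is not implemented.")
            patients = parser(patients)
        return patients


class ToyDataset(ToyBaseDataset):
    def parse_basic_info(self, patients):
        patients["p1"] = {"visits": {"v1": []}}
        return patients

    def parse_diagnoses_icd(self, patients):
        patients["p1"]["visits"]["v1"].append(("ICD9CM", "4019"))
        return patients


ds = ToyDataset(tables=["DIAGNOSES_ICD"])
print(ds.patients)  # {'p1': {'visits': {'v1': [('ICD9CM', '4019')]}}}
```

Requesting a table with no matching parser method (for example, tables=["LABEVENTS"] on ToyDataset) raises NotImplementedError in this sketch.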

property available_tables: List[str]

Returns a list of available tables for the dataset.

Return type:

List[str]

Returns:

List of available tables.

stat()

Returns summary statistics of the base dataset as a string.

Return type:

str
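The exact fields reported by stat() are not listed here. Purely as an illustration, a statistics string of this kind can be built from a nested patient/visit structure. The function name summarize, the {patient_id: {visit_id: [event codes]}} layout, and the particular counts shown are all hypothetical.

```python
from typing import Dict, List


def summarize(patients: Dict[str, Dict[str, List[str]]], dataset_name: str) -> str:
    """Hypothetical summary over {patient_id: {visit_id: [event codes]}}."""
    num_patients = len(patients)
    num_visits = sum(len(visits) for visits in patients.values())
    num_events = sum(
        len(events) for visits in patients.values() for events in visits.values()
    )
    lines = [
        f"Dataset: {dataset_name}",
        f"Number of patients: {num_patients}",
        f"Number of visits: {num_visits}",
        # max(..., 1) avoids division by zero on an empty dataset.
        f"Visits per patient: {num_visits / max(num_patients, 1):.2f}",
        f"Events per visit: {num_events / max(num_visits, 1):.2f}",
    ]
    return "\n".join(lines)


toy = {"p1": {"v1": ["4019", "25000"], "v2": ["4019"]}, "p2": {"v3": []}}
print(summarize(toy, "toy"))
```

Returning a str (rather than only printing) matches the documented return type and lets callers log or compare the summary.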

static info()

Prints the output format.

set_task(task_fn, task_name=None)

Processes the base dataset to generate the task-specific sample dataset.

This method should be called by the user after the base dataset is initialized. It iterates through all patients in the base dataset and calls task_fn, which is implemented by the specific task, on each one.

Parameters:
  • task_fn (Callable) – a function that takes a single patient and returns a list of samples (each sample is a dict with patient_id, visit_id, and other task-specific attributes as keys). The samples from all patients are concatenated to form the sample dataset.

  • task_name (Optional[str]) – the name of the task. If None, the name of the task function will be used.

Returns:

the task-specific sample dataset.

Return type:

sample_dataset

Note

In task_fn, a patient may be converted to multiple samples, e.g., a patient with three visits may be converted to three samples ([visit 1], [visit 1, visit 2], [visit 1, visit 2, visit 3]). Patients can also be excluded from the task dataset by returning an empty list.
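The loop in set_task() and the multi-sample conversion described in the note can be sketched together. This is illustrative only: patients are plain dicts with an ordered "visits" list instead of Patient objects, the returned dict stands in for the real sample dataset, and cumulative_visit_task, set_task_sketch, and the rule of skipping patients with fewer than two visits are all hypothetical. The fallback to the function's name when task_name is None follows the parameter description above.

```python
from typing import Callable, Dict, List, Optional


def cumulative_visit_task(patient: Dict) -> List[Dict]:
    """Hypothetical task_fn: one sample per visit, holding all visits so far."""
    visits = patient["visits"]  # ordered list of visit ids (toy layout)
    if len(visits) < 2:
        # Hypothetical exclusion rule: an empty list drops this patient.
        return []
    return [
        {"patient_id": patient["patient_id"], "visit_id": visit_id, "history": visits[: i + 1]}
        for i, visit_id in enumerate(visits)
    ]


def set_task_sketch(
    patients: Dict[str, Dict],
    task_fn: Callable[[Dict], List[Dict]],
    task_name: Optional[str] = None,
) -> Dict:
    """Hypothetical stand-in for set_task(): apply task_fn to every patient."""
    # Default the task name to the task function's name.
    task_name = task_name or task_fn.__name__
    samples: List[Dict] = []
    for patient in patients.values():
        # Concatenate each patient's samples into one flat list.
        samples.extend(task_fn(patient))
    return {"task_name": task_name, "samples": samples}


patients = {
    "p1": {"patient_id": "p1", "visits": ["v1", "v2", "v3"]},
    "p2": {"patient_id": "p2", "visits": ["v4"]},
}
result = set_task_sketch(patients, cumulative_visit_task)
print(result["task_name"])  # cumulative_visit_task
for sample in result["samples"]:
    print(sample["visit_id"], sample["history"])
# v1 ['v1']
# v2 ['v1', 'v2']
# v3 ['v1', 'v2', 'v3']
```

Patient p2 contributes no samples because the hypothetical task_fn returns an empty list for it, mirroring the exclusion mechanism described in the note.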