pyhealth.tasks.generate_ehr

Task that turns a longitudinal EHR dataset into per-patient, per-visit code sequences for training unconditional synthetic-EHR generators (HALO, GPT2, PromptEHR, MedGAN, CorGAN), plus helpers to flatten generated output into the long-form dataframe consumed by pyhealth.metrics.generative.

Task Classes

class pyhealth.tasks.generate_ehr.EHRGeneration(code_mapping=None)

Bases: BaseTask

Generic per-visit code-sequence task for unconditional EHR generators.

Builds one sample per qualifying patient: the ordered list of visits, each visit being the list of codes (read from code_attr on event_type events) recorded in that admission. Patients with fewer than min_visits qualifying visits are skipped.

Subclass and override the class attributes for a specific dataset, or set them on an instance. The defaults read MIMIC-III ICD-9 diagnosis codes.
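The per-patient construction and the override pattern can be sketched in plain Python. This is a minimal illustration under stated assumptions, not PyHealth's implementation. Events are assumed to be plain dicts with a hypothetical `visit_id` key and to arrive in chronological order, and a "qualifying" visit is assumed to be one with at least one matching code. `SketchEHRGeneration` and `SketchProcedureGeneration` are hypothetical stand-ins that mirror only the documented class attributes, so the example runs without pyhealth installed.

```python
from typing import Dict, List, Optional


class SketchEHRGeneration:
    """Hypothetical stand-in mirroring EHRGeneration's documented class attributes."""

    task_name: str = "ehr_generation"
    event_type: str = "diagnoses_icd"
    code_attr: str = "icd9_code"
    min_visits: int = 2

    def build_sample(self, patient_id: str, events: List[Dict]) -> Optional[Dict]:
        # Group matching codes by visit; dicts keep insertion order, so visits
        # stay in the order they first appear (assumed chronological).
        visits: Dict[str, List[str]] = {}
        for event in events:
            if event.get("event_type") != self.event_type:
                continue  # wrong event type, e.g. a procedure when reading diagnoses
            code = event.get(self.code_attr)
            if code is None:
                continue  # event lacks the configured code attribute
            visits.setdefault(event["visit_id"], []).append(str(code))

        # Only visits with at least one matching code reach this point; the
        # sketch treats those as the "qualifying" visits.
        visit_codes = list(visits.values())
        if len(visit_codes) < self.min_visits:
            return None  # too few qualifying visits: skip this patient
        return {"patient_id": patient_id, "visits": visit_codes}


class SketchProcedureGeneration(SketchEHRGeneration):
    # Subclass override: read procedure codes instead of diagnosis codes.
    task_name = "ehr_generation_procedures"
    event_type = "procedures_icd"


# Toy events for one patient (illustrative values only).
events = [
    {"visit_id": "v1", "event_type": "diagnoses_icd", "icd9_code": "4019"},
    {"visit_id": "v1", "event_type": "diagnoses_icd", "icd9_code": "25000"},
    {"visit_id": "v2", "event_type": "diagnoses_icd", "icd9_code": "4280"},
    {"visit_id": "v2", "event_type": "procedures_icd", "icd9_code": "3893"},
]

print(SketchEHRGeneration().build_sample("p1", events))
# {'patient_id': 'p1', 'visits': [['4019', '25000'], ['4280']]}

proc = SketchProcedureGeneration()
print(proc.build_sample("p1", events))  # None: only one procedure visit
proc.min_visits = 1  # instance-level override; the class default is unchanged
print(proc.build_sample("p1", events))
# {'patient_id': 'p1', 'visits': [['3893']]}
```

Overriding on a subclass fixes the setting for every instance, while assigning on an instance (as with `proc.min_visits` above) changes only that one task object.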

Parameters:
  • task_name – Name of the task.

  • input_schema – {"visits": NestedSequenceProcessor}.

  • output_schema – empty (generative task, no labels).

  • event_type – Event type to pull per admission. Default "diagnoses_icd".

  • code_attr – Event attribute holding the code string. Default "icd9_code".

  • min_visits – Minimum qualifying visits to keep a patient. Default 2.

task_name: str = 'ehr_generation'
input_schema: Dict[str, Union[str, Type]] = {'visits': <class 'pyhealth.processors.nested_sequence_processor.NestedSequenceProcessor'>}
output_schema: Dict[str, Union[str, Type]] = {}
event_type: str = 'diagnoses_icd'
code_attr: str = 'icd9_code'
min_visits: int = 2
pre_filter(df)
Return type:

LazyFrame

class pyhealth.tasks.generate_ehr.EHRGenerationMIMIC3(code_mapping=None)

Bases: EHRGeneration

EHR generation task for MIMIC-III (ICD-9 diagnosis codes).

Examples

>>> from pyhealth.datasets import MIMIC3Dataset
>>> from pyhealth.tasks import EHRGenerationMIMIC3
>>> dataset = MIMIC3Dataset(
...     root="/path/to/mimic-iii/1.4",
...     tables=["diagnoses_icd"],
... )
>>> samples = dataset.set_task(EHRGenerationMIMIC3())

task_name: str = 'ehr_generation_mimic3'
event_type: str = 'diagnoses_icd'
code_attr: str = 'icd9_code'
input_schema: Dict[str, Union[str, Type]] = {'visits': <class 'pyhealth.processors.nested_sequence_processor.NestedSequenceProcessor'>}
min_visits: int = 2
output_schema: Dict[str, Union[str, Type]] = {}
pre_filter(df)
Return type:

LazyFrame

class pyhealth.tasks.generate_ehr.EHRGenerationMIMIC4(code_mapping=None)

Bases: EHRGeneration

EHR generation task for MIMIC-IV (ICD diagnosis codes read from icd_code; the MIMIC-IV diagnoses_icd table contains both ICD-9 and ICD-10 codes).

Examples

>>> from pyhealth.datasets import MIMIC4Dataset
>>> from pyhealth.tasks import EHRGenerationMIMIC4
>>> dataset = MIMIC4Dataset(
...     ehr_root="/path/to/mimiciv/2.2/",
...     ehr_tables=["patients", "admissions", "diagnoses_icd"],
... )
>>> samples = dataset.set_task(EHRGenerationMIMIC4())

task_name: str = 'ehr_generation_mimic4'
event_type: str = 'diagnoses_icd'
code_attr: str = 'icd_code'
input_schema: Dict[str, Union[str, Type]] = {'visits': <class 'pyhealth.processors.nested_sequence_processor.NestedSequenceProcessor'>}
min_visits: int = 2
output_schema: Dict[str, Union[str, Type]] = {}
pre_filter(df)
Return type:

LazyFrame

Helper Functions

pyhealth.tasks.generate_ehr.decode_dataset(sample_dataset, feature_key='visits')

Decode a processed EHRGeneration SampleDataset back into records.

Inverts the NestedSequenceProcessor encoding using its vocabulary (skipping <pad>/<unk>), yielding one {"visits": [[code_str, ...], ...]} record per sample. Use this to build the real train/test frames that evaluate_synthetic_ehr compares the synthetic data against; the records have the shape that to_evaluation_dataframe expects.
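The inversion step can be sketched as follows. This is a minimal sketch, not the helper's actual code. It assumes the vocabulary is a plain dict mapping code strings to integer indices (the real helper reads it from the fitted NestedSequenceProcessor) and that padding shows up as `<pad>` indices, both inside visits and as fully padded trailing visits. `decode_nested` and the toy vocabulary are hypothetical, and dropping all-padding visits is an assumption of the sketch.

```python
from typing import Dict, List, Sequence

# Special tokens skipped during decoding, as described above.
SPECIAL_TOKENS = {"<pad>", "<unk>"}


def decode_nested(
    encoded: Sequence[Sequence[int]], vocab: Dict[str, int]
) -> Dict[str, List[List[str]]]:
    """Invert a nested index encoding back into code strings (sketch)."""
    # Reverse the code -> index mapping once.
    index_to_code = {idx: code for code, idx in vocab.items()}
    visits: List[List[str]] = []
    for visit in encoded:
        codes: List[str] = []
        for idx in visit:
            code = index_to_code.get(int(idx))
            # Skip special tokens and any index missing from the vocabulary.
            if code is None or code in SPECIAL_TOKENS:
                continue
            codes.append(code)
        if codes:  # assumption: visits that were pure padding are dropped
            visits.append(codes)
    return {"visits": visits}


# Toy vocabulary and padded encoding (illustrative values only).
vocab = {"<pad>": 0, "<unk>": 1, "4019": 2, "25000": 3, "4280": 4}
encoded = [[2, 3, 0], [4, 1, 0], [0, 0, 0]]
print(decode_nested(encoded, vocab))
# {'visits': [['4019', '25000'], ['4280']]}
```

Calling `int(idx)` lets the same loop accept plain Python ints or scalar tensor elements.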

Parameters:
  • sample_dataset – A SampleDataset produced by EHRGeneration.

  • feature_key (str) – Input feature key holding the nested code sequence. Default "visits".

Return type:

List[Dict]

Returns:

List of {"visits": [[code_str, ...], ...]} records.

pyhealth.tasks.generate_ehr.to_evaluation_dataframe(records, label_fn=None, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels')

Flatten EHR-generation records into the long-form evaluation dataframe.

Produces the one-row-per-(patient, visit, code) table consumed by pyhealth.metrics.generative.evaluate_synthetic_ehr() and by the lower-level functions in utils.py, privacy.py, and utility.py that it calls.

Subjects are numbered sequentially (0, 1, 2, …) in subject_col; any "patient_id" on the records is ignored, since synthetic patients do not correspond to real ones.

Parameters:
  • records – Iterable of {"visits": [[code, ...], ...]} dicts. Both the EHRGeneration task output and a generator’s generate() output have this shape.

  • label_fn (Optional[Callable[[Dict], int]]) – Optional callable mapping a record to a binary patient label (0/1), used by the utility metrics. If omitted, every patient is labeled 0.

  • subject_col (str) – Output patient-id column. Default "id".

  • visit_col (str) – Output visit-index column. Default "time".

  • code_col (str) – Output single-code column. Default "visit_codes".

  • label_col (str) – Output binary-label column. Default "labels".

Returns:

pandas.DataFrame with columns [subject_col, visit_col, code_col, label_col].
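The flattening can be sketched without pandas as a list of row dicts; with pandas installed, `pandas.DataFrame(rows)` would turn these into a frame with the same four columns. This is a minimal sketch, not the library's code. It assumes visit indices are 0-based positions within each record and that every code in a visit yields its own row. `flatten_records` and the example label function `has_code_4280` are hypothetical.

```python
from typing import Callable, Dict, Iterable, List, Optional


def flatten_records(
    records: Iterable[Dict],
    label_fn: Optional[Callable[[Dict], int]] = None,
    subject_col: str = "id",
    visit_col: str = "time",
    code_col: str = "visit_codes",
    label_col: str = "labels",
) -> List[Dict]:
    """Emit one row per (patient, visit, code), mirroring the documented columns."""
    rows: List[Dict] = []
    # enumerate() numbers subjects 0, 1, 2, ...; any "patient_id" key is ignored.
    for subject_id, record in enumerate(records):
        # One label per patient, repeated on each of that patient's rows.
        label = int(label_fn(record)) if label_fn is not None else 0
        # Assumption: visit indices are 0-based positions within the record.
        for visit_idx, visit in enumerate(record["visits"]):
            for code in visit:
                rows.append(
                    {subject_col: subject_id, visit_col: visit_idx,
                     code_col: code, label_col: label}
                )
    return rows


def has_code_4280(record: Dict) -> int:
    # Hypothetical label: 1 if any visit contains code "4280".
    return int(any("4280" in visit for visit in record["visits"]))


records = [
    {"patient_id": "real-123", "visits": [["4019", "25000"], ["4280"]]},
    {"visits": [["4019"]]},
]
rows = flatten_records(records, label_fn=has_code_4280)
print(len(rows))  # 4
print(rows[-1])   # {'id': 1, 'time': 0, 'visit_codes': '4019', 'labels': 0}
```

Note that the first record's "real-123" identifier does not appear in the output; its rows carry subject 0, consistent with the renumbering described above.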