pyhealth.tasks.generate_ehr
Task that turns a longitudinal EHR dataset into per-patient, per-visit code
sequences for training unconditional synthetic-EHR generators (HALO, GPT2,
PromptEHR, MedGAN, CorGAN), plus helpers to flatten generated output into the
long-form dataframe consumed by pyhealth.metrics.generative.
Task Classes
- class pyhealth.tasks.generate_ehr.EHRGeneration(code_mapping=None)
Bases: BaseTask
Generic per-visit code-sequence task for unconditional EHR generators.
Builds one sample per qualifying patient: the ordered list of visits, where each visit is the list of codes recorded in that admission (read from the code_attr attribute of event_type events). Patients with fewer than min_visits qualifying visits are skipped.
To target a specific dataset, subclass and override the class attributes, or set them on an instance. The defaults read MIMIC-III ICD-9 diagnosis codes.
- Parameters:
task_name – Name of the task.
input_schema – {"visits": NestedSequenceProcessor}.
output_schema – Empty (generative task, no labels).
event_type – Event type to pull per admission. Default "diagnoses_icd".
code_attr – Event attribute holding the code string. Default "icd9_code".
min_visits – Minimum number of qualifying visits required to keep a patient. Default 2.
- input_schema: Dict[str, Union[str, Type]] = {'visits': <class 'pyhealth.processors.nested_sequence_processor.NestedSequenceProcessor'>}
- pre_filter(df)
- Return type:
LazyFrame
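The per-patient sample construction that EHRGeneration describes can be sketched in plain Python. This is a minimal illustration under stated assumptions, not PyHealth's implementation: `build_visit_sample` is a hypothetical helper, events are assumed to be dicts already sorted by admission time and keyed by an assumed `"hadm_id"` admission field, and a visit is assumed to "qualify" when it contributes at least one code of the requested event type. The codes are toy data.

```python
from typing import Dict, List, Optional


def build_visit_sample(
    events: List[Dict[str, str]],
    event_type: str = "diagnoses_icd",
    code_attr: str = "icd9_code",
    min_visits: int = 2,
) -> Optional[Dict[str, List[List[str]]]]:
    """Group one patient's events into ordered per-visit code lists.

    Hypothetical helper, not part of PyHealth. Assumes `events` is already
    sorted by admission time and that each event dict carries its admission
    id under "hadm_id" (an assumed field name).
    """
    visits: Dict[str, List[str]] = {}  # dicts preserve insertion order
    for event in events:
        # Keep only events of the requested type that carry a code.
        if event.get("event_type") != event_type:
            continue
        code = event.get(code_attr)
        if not code:
            continue
        visits.setdefault(event["hadm_id"], []).append(code)

    # Assumption: a visit qualifies if it contributed at least one code.
    if len(visits) < min_visits:
        return None  # patient is skipped
    return {"visits": list(visits.values())}


events = [
    {"hadm_id": "A1", "event_type": "diagnoses_icd", "icd9_code": "4019"},
    {"hadm_id": "A1", "event_type": "diagnoses_icd", "icd9_code": "25000"},
    {"hadm_id": "A2", "event_type": "procedures_icd", "icd9_code": "3893"},
    {"hadm_id": "A3", "event_type": "diagnoses_icd", "icd9_code": "4280"},
]
print(build_visit_sample(events))  # {'visits': [['4019', '25000'], ['4280']]}
print(build_visit_sample(events, min_visits=3))  # None
```

Admission A2 only has a procedure event, so it does not count toward min_visits; raising min_visits to 3 drops the patient.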
- class pyhealth.tasks.generate_ehr.EHRGenerationMIMIC3(code_mapping=None)
Bases: EHRGeneration
EHR generation task for MIMIC-III (ICD-9 diagnosis codes).
Examples
>>> from pyhealth.datasets import MIMIC3Dataset
>>> from pyhealth.tasks import EHRGenerationMIMIC3
>>> dataset = MIMIC3Dataset(
...     root="/path/to/mimic-iii/1.4",
...     tables=["diagnoses_icd"],
... )
>>> samples = dataset.set_task(EHRGenerationMIMIC3())
- input_schema: Dict[str, Union[str, Type]] = {'visits': <class 'pyhealth.processors.nested_sequence_processor.NestedSequenceProcessor'>}
- pre_filter(df)
- Return type:
LazyFrame
- class pyhealth.tasks.generate_ehr.EHRGenerationMIMIC4(code_mapping=None)
Bases: EHRGeneration
EHR generation task for MIMIC-IV (ICD diagnosis codes).
Examples
>>> from pyhealth.datasets import MIMIC4Dataset
>>> from pyhealth.tasks import EHRGenerationMIMIC4
>>> dataset = MIMIC4Dataset(
...     ehr_root="/path/to/mimiciv/2.2/",
...     ehr_tables=["patients", "admissions", "diagnoses_icd"],
... )
>>> samples = dataset.set_task(EHRGenerationMIMIC4())
- input_schema: Dict[str, Union[str, Type]] = {'visits': <class 'pyhealth.processors.nested_sequence_processor.NestedSequenceProcessor'>}
- pre_filter(df)
- Return type:
LazyFrame
Helper Functions
- pyhealth.tasks.generate_ehr.decode_dataset(sample_dataset, feature_key='visits')
Decode a processed EHRGeneration SampleDataset back into records.
Inverts the NestedSequenceProcessor encoding using its vocabulary (skipping <pad>/<unk>), yielding one {"visits": [[code_str, ...], ...]} record per sample. Use this to build the real train/test frames that evaluate_synthetic_ehr compares against.
- Parameters:
sample_dataset – A SampleDataset produced by EHRGeneration.
feature_key (str) – Input feature key holding the nested code sequence. Default "visits".
- Return type:
- Returns:
List of {"visits": [[code_str, ...], ...]} records.
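The vocabulary inversion that decode_dataset performs can be illustrated with a small stand-alone sketch. It is not the library code: `decode_visits` is a hypothetical helper, and it assumes the processor vocabulary behaves like a plain dict mapping code string to integer index (the real processor's attribute name and padding layout are not shown in this reference). Dropping visits that decode to nothing, such as all-padding rows, is a design choice of this sketch.

```python
from typing import Dict, List, Sequence

SPECIAL_TOKENS = {"<pad>", "<unk>"}


def decode_visits(
    encoded_visits: Sequence[Sequence[int]],
    code_vocab: Dict[str, int],
) -> Dict[str, List[List[str]]]:
    """Map nested integer indices back to code strings.

    Hypothetical helper: `code_vocab` (code string -> index) stands in for
    the processor vocabulary. Special tokens are skipped.
    """
    # Invert the vocabulary once: index -> code string.
    index_to_code = {idx: code for code, idx in code_vocab.items()}
    visits: List[List[str]] = []
    for visit in encoded_visits:
        codes: List[str] = []
        for idx in visit:
            code = index_to_code.get(int(idx))
            # Skip padding, unknown-token, and out-of-vocabulary indices.
            if code is None or code in SPECIAL_TOKENS:
                continue
            codes.append(code)
        # Sketch-specific choice: drop visits that were pure padding.
        if codes:
            visits.append(codes)
    return {"visits": visits}


vocab = {"<pad>": 0, "<unk>": 1, "4019": 2, "25000": 3, "4280": 4}
encoded = [[2, 3, 0], [4, 1, 0], [0, 0, 0]]
print(decode_visits(encoded, vocab))  # {'visits': [['4019', '25000'], ['4280']]}
```

Inverting the mapping once up front keeps each lookup constant-time, which matters when decoding a whole dataset sample by sample.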
- pyhealth.tasks.generate_ehr.to_evaluation_dataframe(records, label_fn=None, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels')
Flatten EHR-generation records into the long-form evaluation dataframe.
Produces the one-row-per-(patient, visit, code) table consumed by pyhealth.metrics.generative.evaluate_synthetic_ehr() (and the utils.py/privacy.py/utility.py functions beneath it).
Subjects are numbered sequentially (0, 1, 2, …) in subject_col; any "patient_id" on the records is ignored, since synthetic patients do not correspond to real ones.
- Parameters:
records – Iterable of {"visits": [[code, ...], ...]} dicts. Both the EHRGeneration task output and a generator’s generate() output have this shape.
label_fn (Optional[Callable[[Dict], int]]) – Optional callable mapping a record to a binary patient label (0/1), used by the utility metrics. Defaults to all zeros.
subject_col (str) – Output patient-id column. Default "id".
visit_col (str) – Output visit-index column. Default "time".
code_col (str) – Output single-code column. Default "visit_codes".
label_col (str) – Output binary-label column. Default "labels".
- Returns:
pandas.DataFrame with columns [subject_col, visit_col, code_col, label_col].
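The flattening contract above (sequential subject ids, one row per code, a per-patient label repeated on every row) can be sketched with pandas. This is an illustrative re-implementation under assumptions, not the library function: `flatten_records` and the `has_code_4280` label are hypothetical, visit indices are assumed to start at 0, and a record with no codes is assumed to still consume a subject id.

```python
from typing import Callable, Dict, Iterable, List, Optional

import pandas as pd


def flatten_records(
    records: Iterable[Dict],
    label_fn: Optional[Callable[[Dict], int]] = None,
    subject_col: str = "id",
    visit_col: str = "time",
    code_col: str = "visit_codes",
    label_col: str = "labels",
) -> pd.DataFrame:
    """Hypothetical sketch: one row per (patient, visit, code)."""
    rows: List[Dict] = []
    # Sequential subject ids; any "patient_id" key on a record is ignored.
    for subject_id, record in enumerate(records):
        # Per-patient binary label, defaulting to 0 when no label_fn is given.
        label = int(label_fn(record)) if label_fn is not None else 0
        for visit_idx, visit in enumerate(record["visits"]):  # assumed 0-based
            for code in visit:
                rows.append({
                    subject_col: subject_id,
                    visit_col: visit_idx,
                    code_col: code,
                    label_col: label,
                })
    # Fixing the column order also yields the right columns for empty input.
    return pd.DataFrame(rows, columns=[subject_col, visit_col, code_col, label_col])


def has_code_4280(record: Dict) -> int:
    """Hypothetical label: 1 if any visit contains code "4280"."""
    return int(any("4280" in visit for visit in record["visits"]))


records = [
    {"patient_id": "p9", "visits": [["4019", "25000"], ["4280"]]},
    {"visits": [["V3000"]]},
]
df = flatten_records(records, label_fn=has_code_4280)
print(df)
```

Here the first record becomes subject 0 with three rows labelled 1, and the second becomes subject 1 with a single row labelled 0.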