# pyhealth.tasks.generate_ehr

"""EHR sequence-generation tasks for PyHealth generative models.

This is the shared task for every generator in
:mod:`pyhealth.models.generators` (HALO, MedGAN, CorGAN, PromptEHR, ...). It
extracts, for each patient, the ordered list of visits where each visit is the
list of medical codes recorded in that admission. The single input feature
``visits`` is processed by :class:`~pyhealth.processors.NestedSequenceProcessor`;
there is no prediction label, so ``output_schema`` is empty.

:class:`EHRGeneration` holds all the extraction logic; dataset-specific
subclasses only declare which event type and code attribute to read.
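
The extraction rule is small enough to sketch without PyHealth. This is a minimal sketch, not the task itself. ``extract_samples`` is a hypothetical helper, and it assumes admissions arrive as an ordered list of ``hadm_id`` strings and events as ``(hadm_id, code)`` pairs instead of :class:`~pyhealth.data.data.Patient` events.

```python
from typing import Dict, List, Optional, Tuple


def extract_samples(
    patient_id: str,
    admissions: List[str],
    events: List[Tuple[str, Optional[str]]],
    min_visits: int = 2,
) -> List[Dict]:
    # Hypothetical stand-in for EHRGeneration.__call__:
    # admissions are hadm_ids in order, events are (hadm_id, code) pairs.
    visits: List[List[str]] = []
    for hadm_id in admissions:
        # Codes recorded in this admission; missing or empty codes are dropped.
        codes = [code for h, code in events if h == hadm_id and code]
        if codes:  # an admission with no codes does not count as a visit
            visits.append(codes)
    if len(visits) < min_visits:
        return []  # too few qualifying visits: the patient yields no sample
    return [{"patient_id": patient_id, "visits": visits}]


admissions = ["A1", "A2", "A3"]
events = [("A1", "4019"), ("A1", "25000"), ("A2", None), ("A3", "41401")]
print(extract_samples("p1", admissions, events))
# [{'patient_id': 'p1', 'visits': [['4019', '25000'], ['41401']]}]
print(extract_samples("p1", admissions, events, min_visits=3))  # []
```

Admission ``A2`` contributes no visit because its only event has no code, so the patient has two visits and passes the default ``min_visits=2`` but not ``min_visits=3``.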

Evaluating generated data
-------------------------
The privacy/utility metrics in :mod:`pyhealth.metrics.generative` (``utils.py``,
``privacy.py``, ``utility.py`` -- exposed through ``evaluate_synthetic_ehr``)
consume **long-form** dataframes: one row per ``(patient, visit, code)`` with
columns ``id`` / ``time`` / ``visit_codes`` / ``labels``. ``id`` is the patient
identifier, ``time`` the (integer) visit index, ``visit_codes`` a single code
string, and ``labels`` a patient-level binary label (reduced via ``max`` over
the patient's rows).
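
A hand-written illustration of the long-form layout and the patient-level ``max`` reduction. It uses tuples in place of dataframe rows so it runs without pandas, and the codes and labels are toy data. Patient 0 has one row labeled 1, so the reduction marks the whole patient positive.

```python
from typing import Dict, List, Tuple

# Long-form rows (id, time, visit_codes, labels) for two toy patients.
rows: List[Tuple[int, int, str, int]] = [
    (0, 0, "4019", 0),
    (0, 1, "25000", 1),
    (1, 0, "41401", 0),
    (1, 1, "4019", 0),
]

# Patient-level label: max over each patient's rows.
labels: Dict[int, int] = {}
for subject_id, _time, _code, label in rows:
    labels[subject_id] = max(labels.get(subject_id, 0), label)
print(labels)  # {0: 1, 1: 0}
```

On a real dataframe the same reduction is ``df.groupby("id")["labels"].max()``.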

Both the real task samples and a generator's ``generate()`` output use the same
``{"visits": [[code, ...], ...]}`` record shape, so
:func:`to_evaluation_dataframe` converts either into that long-form table. A
processed ``SampleDataset`` can be turned back into records with
:func:`decode_dataset`. Subjects are renumbered sequentially (0, 1, 2, ...) in
the ``id`` column -- synthetic patients do not correspond to real ones, so any
original ``patient_id`` is ignored.
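
The renumbering is plain ``enumerate`` over the records. The sketch below mirrors the loop in :func:`to_evaluation_dataframe` but returns tuples instead of a ``pandas.DataFrame``; ``flatten`` is a hypothetical name used only for illustration.

```python
from typing import Callable, Dict, List, Optional, Tuple


def flatten(
    records: List[Dict],
    label_fn: Optional[Callable[[Dict], int]] = None,
) -> List[Tuple[int, int, str, int]]:
    # Rows are (id, time, visit_codes, labels), one per (patient, visit, code).
    rows = []
    for subject_id, record in enumerate(records):  # ids follow input order
        label = 0 if label_fn is None else int(label_fn(record))
        for visit_idx, visit in enumerate(record["visits"]):
            for code in visit:
                rows.append((subject_id, visit_idx, code, label))
    return rows


records = [
    {"patient_id": "p17", "visits": [["4019"], ["41401"]]},
    {"patient_id": "p3", "visits": [["25000"]]},
]
for row in flatten(records):
    print(row)
# (0, 0, '4019', 0)
# (0, 1, '41401', 0)
# (1, 0, '25000', 0)
```

The input ``patient_id`` values ``"p17"`` and ``"p3"`` never reach the output; only the order of the records determines ``id``.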

.. code-block:: python

    from pyhealth.tasks.generate_ehr import decode_dataset, to_evaluation_dataframe
    from pyhealth.metrics.generative import evaluate_synthetic_ehr

    # Real train/test EHR come from the processed SampleDataset(s):
    train_df = to_evaluation_dataframe(decode_dataset(train_dataset))
    test_df = to_evaluation_dataframe(decode_dataset(test_dataset))

    # Synthetic EHR comes straight from the trained generator (HALO, GPT2, ...):
    synthetic = model.generate(num_samples=len(train_dataset))
    syn_df = to_evaluation_dataframe(synthetic)

    # Privacy metrics need no labels:
    results = evaluate_synthetic_ehr(train_df, test_df, syn_df, metrics="privacy")
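
Decoding a processed sample amounts to inverting the vocabulary and discarding special tokens. The sketch below applies that logic to a hand-written padded index matrix (a nested list standing in for the tensor). The vocabulary, including ``<pad>`` at index 0 and ``<unk>`` at index 1, is an assumption for illustration and is not necessarily how :class:`~pyhealth.processors.NestedSequenceProcessor` assigns indices.

```python
from typing import Dict, List

# Toy vocabulary; the <pad>=0 / <unk>=1 layout is assumed for illustration.
code_vocab: Dict[str, int] = {"<pad>": 0, "<unk>": 1, "4019": 2, "25000": 3, "41401": 4}
index_to_code = {idx: code for code, idx in code_vocab.items()}

# Padded (visits x codes) index matrix; a nested list stands in for the tensor.
encoded = [[2, 3, 0], [4, 1, 0], [0, 0, 0]]

visits: List[List[str]] = []
for row in encoded:
    # Drop padding, unknown-code, and out-of-vocabulary indices.
    codes = [
        index_to_code[i]
        for i in row
        if index_to_code.get(i) not in (None, "<pad>", "<unk>")
    ]
    if codes:  # all-padding rows are not emitted as visits
        visits.append(codes)
print({"visits": visits})  # {'visits': [['4019', '25000'], ['41401']]}
```

The all-padding third row yields no codes and is dropped, so decoded records never contain empty visits.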

The **utility** metrics, machine-learning efficacy (MLE) and next-visit
prediction, additionally require a meaningful binary ``labels`` column. Since
this task is unconditional (no labels), pass a ``label_fn`` to
:func:`to_evaluation_dataframe` to derive one label per patient, e.g.
``label_fn=lambda r: any(c.startswith("250") for v in r["visits"] for c in v)``
for an ICD-9 diabetes flag, and apply the same ``label_fn`` when converting
both the real and the synthetic records. With no label available, restrict
evaluation to ``metrics="privacy"``.
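
A ``label_fn`` is a function from one record to 0/1, applied once per patient. The sketch below assumes dot-less ICD-9 strings as stored in MIMIC-III (e.g. ``"25000"`` for 250.00); ``has_diabetes`` and ``substring_flag`` are hypothetical names. It also shows why a prefix test is safer than a substring test: ``"250" in c`` fires on ``"4250"`` (ICD-9 425.0), which is not a diabetes code.

```python
from typing import Dict, List

Record = Dict[str, List[List[str]]]


def has_diabetes(record: Record) -> int:
    # ICD-9 250.xx diabetes codes are stored without the dot, e.g. "25000".
    return int(any(c.startswith("250") for v in record["visits"] for c in v))


def substring_flag(record: Record) -> int:
    # Naive variant: matches "250" anywhere inside a code.
    return int(any("250" in c for v in record["visits"] for c in v))


diabetic = {"visits": [["25000"], ["4019"]]}
cardiomyopathy = {"visits": [["4250"], ["4019"]]}  # ICD-9 425.0, not diabetes
print(has_diabetes(diabetic), has_diabetes(cardiomyopathy))  # 1 0
print(substring_flag(cardiomyopathy))  # 1, a false positive
```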

Note:
    The MLE component currently hard-codes the downstream task to
    next-visit prediction, which is degenerate for bag-of-codes
    generators (MedGAN, CorGAN) that emit a single aggregate visit per
    patient. A future revision will let callers plug in static-label
    tasks (e.g. mortality, readmission, "ever diagnosed with X") so MLE
    is meaningful for both sequential (HALO, GPT2, PromptEHR) and
    bag-of-codes generators. Until then, restrict bag-of-codes
    evaluation to ``metrics="privacy"`` plus the prevalence metrics.
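
To see why next-visit prediction degenerates for one-visit records, assume the downstream task pairs each visit history with the visit that follows it. That framing is an assumption for illustration; the exact construction inside the MLE code may differ. A record with ``n`` visits then yields ``n - 1`` training pairs, so a single aggregate visit yields none. ``next_visit_pairs`` is a hypothetical helper.

```python
from typing import List, Tuple

Visit = List[str]


def next_visit_pairs(visits: List[Visit]) -> List[Tuple[List[Visit], Visit]]:
    # Hypothetical framing: visits[0..t] is the input, visits[t + 1] the target.
    return [(visits[: t + 1], visits[t + 1]) for t in range(len(visits) - 1)]


sequential = [["4019"], ["25000", "4019"], ["41401"]]
bag_of_codes = [["4019", "25000", "41401"]]  # one aggregate visit per patient
print(len(next_visit_pairs(sequential)))  # 2
print(len(next_visit_pairs(bag_of_codes)))  # 0
```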
"""

import logging
from typing import Callable, Dict, List, Optional, Type, Union

from pyhealth.data.data import Patient
from pyhealth.processors import NestedSequenceProcessor

from .base_task import BaseTask

logger = logging.getLogger(__name__)


class EHRGeneration(BaseTask):
    """Generic per-visit code-sequence task for unconditional EHR generators.

    Builds one sample per qualifying patient: the ordered list of visits, each
    visit being the list of codes (read from ``code_attr`` on ``event_type``
    events) recorded in that admission. Admissions with no such codes are not
    counted as visits, and patients with fewer than ``min_visits`` qualifying
    visits are skipped.

    Subclass and override the class attributes for a specific dataset, or set
    them on an instance. The defaults read MIMIC-III ICD-9 diagnosis codes.

    Attributes:
        task_name: Name of the task.
        input_schema: ``{"visits": NestedSequenceProcessor}``.
        output_schema: Empty (generative task, no labels).
        event_type: Event type to pull per admission. Default
            ``"diagnoses_icd"``.
        code_attr: Event attribute holding the code string. Default
            ``"icd9_code"``.
        min_visits: Minimum qualifying visits to keep a patient. Default 2.
    """

    task_name: str = "ehr_generation"
    input_schema: Dict[str, Union[str, Type]] = {"visits": NestedSequenceProcessor}
    output_schema: Dict[str, Union[str, Type]] = {}
    event_type: str = "diagnoses_icd"
    code_attr: str = "icd9_code"
    min_visits: int = 2

    def __call__(self, patient: Patient) -> List[Dict]:
        """Extract the per-visit code sequence for a patient."""
        visits: List[List[str]] = []
        admissions = patient.get_events(event_type="admissions")
        for admission in admissions:
            # Events of the configured type recorded in this admission.
            events = patient.get_events(
                event_type=self.event_type,
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            # Keep only events that carry a non-empty code.
            codes = [
                getattr(event, self.code_attr)
                for event in events
                if getattr(event, self.code_attr, None)
            ]
            if codes:
                visits.append(codes)

        if len(visits) < self.min_visits:
            return []
        return [{"patient_id": patient.patient_id, "visits": visits}]


class EHRGenerationMIMIC3(EHRGeneration):
    """EHR generation task for MIMIC-III (ICD-9 diagnosis codes).

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> from pyhealth.tasks import EHRGenerationMIMIC3
        >>> dataset = MIMIC3Dataset(
        ...     root="/path/to/mimic-iii/1.4",
        ...     tables=["diagnoses_icd"],
        ... )
        >>> samples = dataset.set_task(EHRGenerationMIMIC3())
    """

    task_name: str = "ehr_generation_mimic3"
    event_type: str = "diagnoses_icd"
    code_attr: str = "icd9_code"


class EHRGenerationMIMIC4(EHRGeneration):
    """EHR generation task for MIMIC-IV (ICD diagnosis codes).

    Examples:
        >>> from pyhealth.datasets import MIMIC4Dataset
        >>> from pyhealth.tasks import EHRGenerationMIMIC4
        >>> dataset = MIMIC4Dataset(
        ...     ehr_root="/path/to/mimiciv/2.2/",
        ...     ehr_tables=["patients", "admissions", "diagnoses_icd"],
        ... )
        >>> samples = dataset.set_task(EHRGenerationMIMIC4())
    """

    task_name: str = "ehr_generation_mimic4"
    event_type: str = "diagnoses_icd"
    code_attr: str = "icd_code"


# ----------------------------------------------------------------------------
# Conversion helpers for pyhealth.metrics.generative.evaluate_synthetic_ehr
# ----------------------------------------------------------------------------


def to_evaluation_dataframe(
    records,
    label_fn: Optional[Callable[[Dict], int]] = None,
    subject_col: str = "id",
    visit_col: str = "time",
    code_col: str = "visit_codes",
    label_col: str = "labels",
):
    """Flatten EHR-generation records into the long-form evaluation dataframe.

    Produces the one-row-per-``(patient, visit, code)`` table consumed by
    :func:`pyhealth.metrics.generative.evaluate_synthetic_ehr` (and the
    ``utils.py`` / ``privacy.py`` / ``utility.py`` functions beneath it).

    Subjects are numbered **sequentially** (0, 1, 2, ...) in ``subject_col``;
    any ``"patient_id"`` on the records is ignored, since synthetic patients do
    not correspond to real ones.

    Args:
        records: Iterable of ``{"visits": [[code, ...], ...]}`` dicts. Both
            the :class:`EHRGeneration` task output and a generator's
            ``generate()`` output have this shape.
        label_fn: Optional callable mapping a record to a binary patient label
            (0/1) used by the utility metrics. Defaults to all-zeros.
        subject_col: Output patient-id column. Default ``"id"``.
        visit_col: Output visit-index column. Default ``"time"``.
        code_col: Output single-code column. Default ``"visit_codes"``.
        label_col: Output binary-label column. Default ``"labels"``.

    Returns:
        ``pandas.DataFrame`` with columns
        ``[subject_col, visit_col, code_col, label_col]``.
    """
    import pandas as pd

    rows = []
    for subject_id, record in enumerate(records):
        label = 0 if label_fn is None else int(label_fn(record))
        for visit_idx, visit in enumerate(record["visits"]):
            for code in visit:
                rows.append(
                    {
                        subject_col: subject_id,
                        visit_col: visit_idx,
                        code_col: code,
                        label_col: label,
                    }
                )
    return pd.DataFrame(
        rows, columns=[subject_col, visit_col, code_col, label_col]
    )


def decode_dataset(sample_dataset, feature_key: str = "visits") -> List[Dict]:
    """Decode a processed EHRGeneration ``SampleDataset`` back into records.

    Inverts the :class:`~pyhealth.processors.NestedSequenceProcessor` encoding
    using its vocabulary (skipping ``<pad>``/``<unk>``), yielding one
    ``{"visits": [[code_str, ...], ...]}`` record per sample. Use this to build
    the real train/test frames that ``evaluate_synthetic_ehr`` compares
    against.

    Args:
        sample_dataset: A ``SampleDataset`` produced by :class:`EHRGeneration`.
        feature_key: Input feature key holding the nested code sequence.
            Default ``"visits"``.

    Returns:
        List of ``{"visits": [[code_str, ...], ...]}`` records.
    """
    processor = sample_dataset.input_processors[feature_key]
    index_to_code = {idx: code for code, idx in processor.code_vocab.items()}

    records: List[Dict] = []
    for i in range(len(sample_dataset)):
        sample = sample_dataset[i]
        visits: List[List[str]] = []
        for row in sample[feature_key].tolist():
            codes = [
                index_to_code[int(idx)]
                for idx in row
                if index_to_code.get(int(idx)) not in (None, "<pad>", "<unk>")
            ]
            if codes:
                visits.append(codes)
        records.append({"visits": visits})
    return records