Source code for pyhealth.metrics.generative

"""Evaluation metrics for synthetic (generative) EHR data.

This subpackage provides metrics for assessing synthetic electronic health
record (EHR) data along two axes:

    - **Privacy** (:mod:`pyhealth.metrics.generative.privacy`): NNAAR,
      membership inference, and discriminator-based adversarial accuracy.
    - **Utility / fidelity** (:mod:`pyhealth.metrics.generative.utility`):
      machine learning efficacy (TRTR vs TSTR) and code-prevalence similarity.

The convenience function :func:`evaluate_synthetic_ehr` runs the full suite
and returns a single merged dictionary of ``{metric_name: (mean, std)}``.

Input format:
    Every metric consumes plain pandas dataframes in *flat / long* format --
    **one row per (patient, visit, code) event** -- so the logic stays easy to
    inspect. By default each dataframe has four columns
    ``[id, time, visit_codes, labels]`` (override the names via the
    ``subject_col`` / ``visit_col`` / ``code_col`` / ``label_col`` arguments):

        - ``id`` (``subject_col``): patient identifier. Any hashable value;
          commonly ``str`` or ``int``.
        - ``time`` (``visit_col``): visit index / timestep. Sortable, usually
          ``int``; visits are ordered per patient by this column.
        - ``visit_codes`` (``code_col``): a **single** medical code for this
          row (``str`` or ``int``). One code per row -- a visit containing *k*
          codes spans *k* rows. Cells are scalars, **not** lists or arrays.
        - ``labels`` (``label_col``): per-patient binary label (0/1, ``int``).

    The real (``train_ehr``, ``test_ehr``) and synthetic (``syn_ehr``)
    dataframes must all share this same schema. ``labels`` is ignored by the
    privacy metrics and is overwritten internally by the utility metrics, but
    is required so every dataframe has a uniform schema.

Why dataframes (and not plain ``List[...]``)?
    The flat dataframe is purely the *interchange* format -- a single, uniform
    interface shared by every metric and produced once by
    :func:`pyhealth.tasks.to_evaluation_dataframe`. Internally each family uses
    whatever representation is most natural:

        - **Privacy** metrics immediately reduce the frame to a nested
          ``List[List[set]]`` (sequence of per-visit code sets) via
          ``convert_visits_to_sets`` and do all distance work on plain Python
          lists -- no pandas in the hot loop.
        - **Utility / fidelity** metrics, together with the discriminator-based
          privacy metric, genuinely benefit from the dataframe:
          code-prevalence uses ``groupby(...).nunique()``, MLE builds the
          next-visit-prediction supervision with grouped per-patient label
          assignment, and the discriminator metric concatenates / filters /
          relabels real-vs-synthetic rows. Re-implementing these on raw lists
          would be more code for no gain.

    So the long-form frame keeps the public API consistent and the heavy
    transforms readable, while the per-metric internals are free to drop down
    to lists where that is simpler.

Note:
    The MLE (utility) component is currently hard-coded to next-visit
    prediction and is therefore only meaningful for sequential generators
    (HALO, GPT2, PromptEHR). It will be expanded to support pluggable
    downstream tasks so that bag-of-codes generators (MedGAN, CorGAN) can
    be evaluated with a static-label task (e.g. mortality, readmission).
    Until then, prefer the privacy and prevalence metrics when evaluating
    MedGAN/CorGAN output.
"""

import logging
from typing import Dict, Optional, Tuple

import pandas as pd

from .privacy import (
    calc_membership_inference,
    calc_nnaar,
    compute_discriminator_privacy,
)
from .utility import compute_mle, compute_prevalence_metrics
from .utils import train_lstm_model, train_sklearn_model

logger = logging.getLogger(__name__)

__all__ = [
    "calc_nnaar",
    "calc_membership_inference",
    "compute_discriminator_privacy",
    "compute_mle",
    "compute_prevalence_metrics",
    "evaluate_synthetic_ehr",
]


def evaluate_synthetic_ehr(
    train_ehr: pd.DataFrame,
    test_ehr: pd.DataFrame,
    syn_ehr: pd.DataFrame,
    subject_col: str = "id",
    visit_col: str = "time",
    code_col: str = "visit_codes",
    label_col: str = "labels",
    sample_size: int = 1000,
    mode: str = "lstm",
    metrics: str = "all",
    lstm_params: Optional[Dict] = None,
    sklearn_params: Optional[Dict] = None,
    n_bootstraps: int = 100,
    n_runs: int = 5,
) -> Dict[str, Tuple[float, float]]:
    """Runs the full synthetic-EHR evaluation suite.

    Computes privacy and/or utility metrics comparing synthetic EHR data
    against real train/test data, and returns a single merged dictionary.

    All three dataframes are flat / long-format (one row per
    ``(patient, visit, code)`` event) and must share the same schema. See the
    module docstring (:mod:`pyhealth.metrics.generative`) for the full column
    contract, and the example below for how to build them.

    Args:
        train_ehr: Real training EHR dataframe, flat
            ``[id, time, visit_codes, labels]`` format.
        test_ehr: Real held-out test EHR dataframe; same schema as
            ``train_ehr``.
        syn_ehr: Synthetic EHR dataframe; same schema as ``train_ehr``.
        subject_col: Column name for patient/subject identifiers.
        visit_col: Column name for visit/timestep identifiers.
        code_col: Column name for the medical codes (one code per row).
        label_col: Column name for the per-patient binary label.
        sample_size: Number of patients sampled per dataset for the privacy
            metrics.
        mode: Predictive backbone for the utility metrics; ``"lstm"`` uses the
            built-in LSTM classifier, ``"rf"`` uses a random forest.
        metrics: Which metric group to compute: ``"all"``, ``"privacy"`` or
            ``"utility"``.
        lstm_params: Optional overrides for the LSTM (``embed_dim``,
            ``hidden_dim``, ``batch_size``, ``epochs``).
        sklearn_params: Optional overrides for the sklearn model (``model``).
        n_bootstraps: Number of bootstrap resamples for the utility metrics.
        n_runs: Number of sampling runs for the privacy metrics.

    Returns:
        Dictionary mapping each metric name to a ``(mean, std)`` tuple.

    Raises:
        ValueError: If ``metrics`` or ``mode`` is not a recognized value.

    Examples:
        The inputs are flat / long-format dataframes -- one row per
        ``(patient, visit, code)`` event -- with four columns by default:

        - ``id``: patient identifier (any hashable; ``str`` or ``int``).
        - ``time``: visit index / timestep (sortable; usually ``int``).
        - ``visit_codes``: a single medical code for this row (``str`` or
          ``int``). One code per row -- a visit with *k* codes spans *k*
          rows; cells are scalars, not lists/arrays.
        - ``labels``: per-patient binary label (0/1, ``int``).

        ``train_ehr``, ``test_ehr`` and ``syn_ehr`` must all share this schema.

        >>> import pandas as pd
        >>> from pyhealth.metrics.generative import evaluate_synthetic_ehr
        >>>
        >>> # One row per (patient, visit, code). Patient "p0" has two visits
        >>> # (time 0 with two codes, time 1 with one code); "p1" has one visit.
        >>> rows = [
        ...     {"id": "p0", "time": 0, "visit_codes": "428.0", "labels": 0},
        ...     {"id": "p0", "time": 0, "visit_codes": "250.00", "labels": 0},
        ...     {"id": "p0", "time": 1, "visit_codes": "401.9", "labels": 0},
        ...     {"id": "p1", "time": 0, "visit_codes": "428.0", "labels": 0},
        ... ]
        >>> train_ehr = pd.DataFrame(rows)
        >>> test_ehr = train_ehr.copy()  # same schema; real held-out patients
        >>> syn_ehr = train_ehr.copy()   # same schema; generator output
        >>>
        >>> results = evaluate_synthetic_ehr(
        ...     train_ehr, test_ehr, syn_ehr, metrics="privacy", sample_size=2
        ... )
        >>> nnaar_mean, nnaar_std = results["nnaar"]
        >>>
        >>> # Custom column names: pass *_col to match your dataframe.
        >>> results = evaluate_synthetic_ehr(
        ...     train_ehr, test_ehr, syn_ehr,
        ...     subject_col="id", visit_col="time",
        ...     code_col="visit_codes", label_col="labels",
        ... )
    """
    if metrics not in ("all", "privacy", "utility"):
        raise ValueError(
            f"Unknown metrics group: {metrics!r}. "
            "Expected 'all', 'privacy' or 'utility'."
        )
    if mode not in ("lstm", "rf"):
        raise ValueError(f"Unknown mode: {mode!r}. Expected 'lstm' or 'rf'.")

    lstm_params = lstm_params or {}
    sklearn_params = sklearn_params or {}
    final_output: Dict[str, Tuple[float, float]] = {}

    if metrics in ("all", "privacy"):
        final_output.update(
            calc_nnaar(
                train_ehr,
                test_ehr,
                syn_ehr,
                subject_col=subject_col,
                visit_col=visit_col,
                code_col=code_col,
                label_col=label_col,
                sample_size=sample_size,
                n_runs=n_runs,
            )
        )
        final_output.update(
            calc_membership_inference(
                train_ehr,
                test_ehr,
                syn_ehr,
                subject_col=subject_col,
                visit_col=visit_col,
                code_col=code_col,
                label_col=label_col,
                num_attack_samples=sample_size,
                n_runs=n_runs,
            )
        )

    if metrics in ("all", "utility"):
        if mode == "lstm":
            train_fn = train_lstm_model
            train_kwargs = {
                "embed_dim": lstm_params.get("embed_dim", 32),
                "hidden_dim": lstm_params.get("hidden_dim", 32),
                "batch_size": lstm_params.get("batch_size", 32),
                "epochs": lstm_params.get("epochs", 5),
                "verbose": False,
            }
        else:
            train_fn = train_sklearn_model
            train_kwargs = {"model": sklearn_params.get("model", "rf")}

        final_output.update(
            compute_mle(
                train_fn=train_fn,
                train_ehr=train_ehr,
                test_ehr=test_ehr,
                syn_ehr=syn_ehr,
                subject_col=subject_col,
                visit_col=visit_col,
                code_col=code_col,
                label_col=label_col,
                n_bootstraps=n_bootstraps,
                **train_kwargs,
            )
        )
        final_output.update(
            compute_discriminator_privacy(
                train_fn=train_fn,
                train_ehr=train_ehr,
                test_ehr=test_ehr,
                syn_ehr=syn_ehr,
                subject_col=subject_col,
                visit_col=visit_col,
                code_col=code_col,
                label_col=label_col,
                n_bootstraps=n_bootstraps,
                **train_kwargs,
            )
        )
        final_output.update(
            compute_prevalence_metrics(
                train_ehr,
                syn_ehr,
                subject_col=subject_col,
                code_col=code_col,
                n_bootstraps=n_bootstraps,
            )
        )

    return final_output
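
The flat interchange format described in the module docstring, and the reduction to per-visit code sets that the privacy metrics are said to perform, can be illustrated with a minimal sketch. Both helpers below (``nested_to_flat`` and ``flat_to_visit_sets``) are hypothetical names written for this illustration; they are not part of ``pyhealth``, and ``flat_to_visit_sets`` only approximates what ``convert_visits_to_sets`` is described as doing, since that function's actual signature and behavior are not shown here.

```python
from typing import Dict, Hashable, List, Set

import pandas as pd


def nested_to_flat(
    patients: Dict[Hashable, List[List[str]]],
    labels: Dict[Hashable, int],
) -> pd.DataFrame:
    # Hypothetical helper: flatten {patient: [[codes of visit 0], ...]}
    # into one row per (patient, visit, code) event.
    rows = [
        {"id": pid, "time": t, "visit_codes": code, "labels": labels[pid]}
        for pid, visits in patients.items()
        for t, codes in enumerate(visits)
        for code in codes
    ]
    return pd.DataFrame(rows, columns=["id", "time", "visit_codes", "labels"])


def flat_to_visit_sets(
    df: pd.DataFrame,
    subject_col: str = "id",
    visit_col: str = "time",
    code_col: str = "visit_codes",
) -> List[List[Set]]:
    # Hypothetical stand-in for the privacy-side reduction: one list per
    # patient, holding one set of codes per visit, in visit order.
    ordered = df.sort_values([subject_col, visit_col])
    patients: Dict[Hashable, Dict[Hashable, Set]] = {}
    triples = ordered[[subject_col, visit_col, code_col]].itertuples(
        index=False, name=None
    )
    for pid, visit, code in triples:
        # dicts keep insertion order, so patients and visits stay sorted.
        patients.setdefault(pid, {}).setdefault(visit, set()).add(code)
    return [list(visits.values()) for visits in patients.values()]


# Same toy cohort as in the docstring example: "p0" has two visits,
# "p1" has one.
nested = {"p0": [["428.0", "250.00"], ["401.9"]], "p1": [["428.0"]]}
frame = nested_to_flat(nested, labels={"p0": 0, "p1": 0})
visit_sets = flat_to_visit_sets(frame)
print(frame)
print(visit_sets)
```

A frame built this way has the default ``[id, time, visit_codes, labels]`` columns, so it can be passed to ``evaluate_synthetic_ehr`` as ``train_ehr``, ``test_ehr`` or ``syn_ehr`` without overriding the ``*_col`` arguments.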