pyhealth.metrics.generative
Evaluation metrics for synthetic (generative) EHR data, covering privacy, utility, and statistical fidelity.
- pyhealth.metrics.generative.evaluate_synthetic_ehr(train_ehr, test_ehr, syn_ehr, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels', sample_size=1000, mode='lstm', metrics='all', lstm_params=None, sklearn_params=None, n_bootstraps=100, n_runs=5)
Runs the full synthetic-EHR evaluation suite.
Computes privacy and/or utility metrics comparing synthetic EHR data against real train/test data, and returns a single merged dictionary.
All three dataframes are flat / long-format (one row per (patient, visit, code) event) and must share the same schema. See the module docstring (pyhealth.metrics.generative) for the full column contract, and the example below for how to build them.

- Parameters:
  - train_ehr (DataFrame) – Real training EHR dataframe, flat [id, time, visit_codes, labels] format.
  - test_ehr (DataFrame) – Real held-out test EHR dataframe; same schema as train_ehr.
  - syn_ehr (DataFrame) – Synthetic EHR dataframe; same schema as train_ehr.
  - subject_col (str) – Column name for patient/subject identifiers.
  - visit_col (str) – Column name for visit/timestep identifiers.
  - code_col (str) – Column name for the medical codes (one code per row).
  - label_col (str) – Column name for the per-patient binary label.
  - sample_size (int) – Number of patients sampled per dataset for the privacy metrics.
  - mode (str) – Predictive backbone for the utility metrics; "lstm" uses the built-in LSTM classifier, "rf" uses a random forest.
  - metrics (str) – Which metric group to compute: "all", "privacy" or "utility".
  - lstm_params (Optional[Dict]) – Optional overrides for the LSTM (embed_dim, hidden_dim, batch_size, epochs).
  - sklearn_params (Optional[Dict]) – Optional overrides for the sklearn model (model).
  - n_bootstraps (int) – Number of bootstrap resamples for the utility metrics.
  - n_runs (int) – Number of sampling runs for the privacy metrics.
- Return type:
  dict
- Returns:
  Dictionary mapping each metric name to a (mean, std) tuple.
- Raises:
  ValueError – If metrics or mode is not a recognized value.
Examples
The inputs are flat / long-format dataframes – one row per (patient, visit, code) event – with four columns by default:
- id: patient identifier (any hashable; str or int).
- time: visit index / timestep (sortable; usually int).
- visit_codes: a single medical code for this row (str or int). One code per row – a visit with k codes spans k rows; cells are scalars, not lists/arrays.
- labels: per-patient binary label (0/1, int).
train_ehr, test_ehr and syn_ehr must all share this schema.

>>> import pandas as pd
>>> from pyhealth.metrics.generative import evaluate_synthetic_ehr
>>>
>>> # One row per (patient, visit, code). Patient "p0" has two visits
>>> # (time 0 with two codes, time 1 with one code); "p1" has one visit.
>>> rows = [
...     {"id": "p0", "time": 0, "visit_codes": "428.0", "labels": 0},
...     {"id": "p0", "time": 0, "visit_codes": "250.00", "labels": 0},
...     {"id": "p0", "time": 1, "visit_codes": "401.9", "labels": 0},
...     {"id": "p1", "time": 0, "visit_codes": "428.0", "labels": 0},
... ]
>>> train_ehr = pd.DataFrame(rows)
>>> test_ehr = train_ehr.copy()  # same schema; real held-out patients
>>> syn_ehr = train_ehr.copy()   # same schema; generator output
>>>
>>> results = evaluate_synthetic_ehr(
...     train_ehr, test_ehr, syn_ehr, metrics="privacy", sample_size=2
... )
>>> nnaar_mean, nnaar_std = results["nnaar"]
>>>
>>> # Custom column names: pass *_col to match your dataframe.
>>> renamed = {"id": "patient_id", "time": "visit_idx",
...            "visit_codes": "code", "labels": "label"}
>>> results = evaluate_synthetic_ehr(
...     train_ehr.rename(columns=renamed),
...     test_ehr.rename(columns=renamed),
...     syn_ehr.rename(columns=renamed),
...     subject_col="patient_id", visit_col="visit_idx",
...     code_col="code", label_col="label",
...     metrics="privacy", sample_size=2,
... )
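Writing rows by hand does not scale past a toy example. Here is a minimal sketch that assumes your data starts as a nested mapping of patient id -> list of visits -> list of codes, with each patient's visits already in chronological order. The hypothetical helper `flatten_patients` (not part of PyHealth) emits one row per code using the default column names. It uses each visit's list position as `time` and rejects non-scalar codes.

```python
from typing import Dict, List, Sequence, Union

Code = Union[str, int]


def flatten_patients(patients: Dict[str, Sequence[Sequence[Code]]],
                     labels: Dict[str, int]) -> List[dict]:
    """Hypothetical helper: nested patient -> visits -> codes to flat rows.

    Each visit's position in the patient's list becomes its "time" value and
    every code becomes its own row, so cells stay scalar.
    """
    rows = []
    for pid, visits in patients.items():
        for t, visit in enumerate(visits):
            for code in visit:
                # The column contract forbids list/array cells.
                if isinstance(code, (list, tuple, set)):
                    raise TypeError(
                        f"patient {pid!r}, visit {t}: codes must be scalars"
                    )
                rows.append({"id": pid, "time": t, "visit_codes": code,
                             "labels": labels[pid]})
    return rows


rows = flatten_patients(
    {"p0": [["428.0", "250.00"], ["401.9"]], "p1": [["428.0"]]},
    labels={"p0": 0, "p1": 0},
)
```

For this nested input, the helper reproduces the four hand-written rows of the example. The resulting list can be wrapped with `pd.DataFrame(rows)` before being passed to evaluate_synthetic_ehr.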
Privacy metrics
- pyhealth.metrics.generative.calc_nnaar(train_ehr, test_ehr, syn_ehr, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels', sample_size=1000, n_runs=5, verbose=False)
Computes the Nearest Neighbor Adversarial Accuracy Risk (NNAAR).
NNAAR measures whether the synthetic data sits closer to the real training data than to held-out test data, which would indicate memorization:
NNAAR = AA_ES - AA_TS
where AA_ES is the adversarial accuracy between test and synthetic data and AA_TS is the adversarial accuracy between train and synthetic data. Values near 0 indicate low privacy risk. A large positive value means the synthetic records are much harder to distinguish from training patients than from unseen ones.

All three dataframes are flat [id, time, visit_codes, labels] frames sharing the same schema (see pyhealth.metrics.generative).

- Parameters:
  - train_ehr (DataFrame) – Real training EHR dataframe, flat [id, time, visit_codes, labels] format.
  - test_ehr (DataFrame) – Real held-out test EHR dataframe; same schema as train_ehr.
  - syn_ehr (DataFrame) – Synthetic EHR dataframe; same schema as train_ehr.
  - subject_col (str) – Column name for patient/subject identifiers.
  - visit_col (str) – Column name for visit/timestep identifiers.
  - code_col (str) – Column name for the medical codes (one code per row).
  - label_col (str) – Column name for the label (unused, kept for a uniform API).
  - sample_size (int) – Number of patients to sample per dataset per run.
  - n_runs (int) – Number of independent sampling runs.
  - verbose (bool) – Whether to show per-run progress bars.
- Return type:
  dict
- Returns:
  Dictionary mapping "nnaar", "aa_es" and "aa_ts" to their (mean, std) across runs.
Examples
>>> from pyhealth.metrics.generative import calc_nnaar
>>> # train_ehr, test_ehr, syn_ehr are flat
>>> # [id, time, visit_codes, labels] dataframes sharing one schema --
>>> # see evaluate_synthetic_ehr for how to build them.
>>> result = calc_nnaar(train_ehr, test_ehr, syn_ehr)
>>> nnaar_mean, nnaar_std = result["nnaar"]
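The sketch below makes the AA_ES and AA_TS terms concrete using the usual nearest-neighbor adversarial accuracy. A real record counts as a hit when its nearest synthetic neighbor is farther away than its nearest other real record. Synthetic records are scored the same way against the real set, and the two hit rates are averaged.

This illustrates the formula under stated assumptions and is not PyHealth's implementation:
- each patient is reduced to a set of codes (visit structure ignored);
- distance is the Hamming distance between those sets;
- all helper names are hypothetical.

A generator that copies its training set drives AA_TS to 0. On the toy data below it yields an NNAAR of 1.0.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set


def code_sets(rows: Iterable[dict], subject_col: str = "id",
              code_col: str = "visit_codes") -> List[Set]:
    """Collapse flat rows into one set of codes per patient (visits ignored)."""
    per_patient: Dict[object, Set] = defaultdict(set)
    for r in rows:
        per_patient[r[subject_col]].add(r[code_col])
    return list(per_patient.values())


def nn_distance(x: Set, pool: List[Set], exclude: int = -1) -> int:
    """Hamming distance (size of symmetric difference) to the nearest pool
    record, skipping index `exclude` (used for leave-one-out)."""
    return min(len(x ^ y) for j, y in enumerate(pool) if j != exclude)


def adversarial_accuracy(real: List[Set], syn: List[Set]) -> float:
    """Nearest-neighbor adversarial accuracy between two record sets."""
    # A real record is a "hit" when its nearest synthetic neighbor is farther
    # than its nearest other real record; synthetic records are scored the
    # same way against the real set. The two hit rates are averaged.
    real_hits = sum(nn_distance(x, syn) > nn_distance(x, real, exclude=i)
                    for i, x in enumerate(real)) / len(real)
    syn_hits = sum(nn_distance(x, real) > nn_distance(x, syn, exclude=i)
                   for i, x in enumerate(syn)) / len(syn)
    return 0.5 * (real_hits + syn_hits)


def nnaar(train: List[Set], test: List[Set], syn: List[Set]) -> float:
    # NNAAR = AA_ES - AA_TS
    return adversarial_accuracy(test, syn) - adversarial_accuracy(train, syn)


train = [{"a"}, {"b"}, {"c"}]
syn = [set(s) for s in train]  # a "generator" that memorized its training set
test = [{"x", "y"}, {"x", "z"}, {"y", "z"}]
print(nnaar(train, test, syn))  # 1.0
```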
- pyhealth.metrics.generative.calc_membership_inference(train_ehr, test_ehr, syn_ehr, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels', num_attack_samples=1000, n_runs=5, verbose=False)
Computes Membership Inference Attack (MIA) metrics.
An attacker tries to tell members (training patients) from non-members (test patients) using proximity to the synthetic data: members are expected to be closer to synthetic records. Predictions are made by thresholding the nearest-neighbor distance at its median; F1, precision, recall and accuracy near 0.5 indicate low membership-inference risk.
All three dataframes are flat [id, time, visit_codes, labels] frames sharing the same schema (see pyhealth.metrics.generative).

- Parameters:
  - train_ehr (DataFrame) – Real training EHR dataframe (members), flat [id, time, visit_codes, labels] format.
  - test_ehr (DataFrame) – Real held-out test EHR dataframe (non-members); same schema as train_ehr.
  - syn_ehr (DataFrame) – Synthetic EHR dataframe; same schema as train_ehr.
  - subject_col (str) – Column name for patient/subject identifiers.
  - visit_col (str) – Column name for visit/timestep identifiers.
  - code_col (str) – Column name for the medical codes (one code per row).
  - label_col (str) – Column name for the label (unused, kept for a uniform API).
  - num_attack_samples (int) – Total attack-set size (half members, half non-members).
  - n_runs (int) – Number of independent sampling runs.
  - verbose (bool) – Whether to show per-run progress bars.
- Return type:
  dict
- Returns:
  Dictionary mapping "MIA_F1", "MIA_Precision", "MIA_Recall" and "MIA_Accuracy" to their (mean, std) across runs.
Examples
>>> from pyhealth.metrics.generative import calc_membership_inference
>>> # train_ehr, test_ehr, syn_ehr are flat
>>> # [id, time, visit_codes, labels] dataframes sharing one schema --
>>> # see evaluate_synthetic_ehr for how to build them.
>>> result = calc_membership_inference(train_ehr, test_ehr, syn_ehr)
>>> f1_mean, f1_std = result["MIA_F1"]
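The median-threshold attack described above can be sketched in a few lines. The following assumptions are hypothetical and not taken from PyHealth's code:
- patients are sets of codes;
- distance is the Hamming distance to the nearest synthetic patient;
- a record at or below the median distance is guessed to be a member (ties go to "member");
- members and non-members arrive already sampled.

```python
import statistics
from typing import Dict, List, Set


def nn_distance(x: Set, pool: List[Set]) -> int:
    """Hamming distance (size of symmetric difference) to the nearest record."""
    return min(len(x ^ y) for y in pool)


def membership_inference(members: List[Set], non_members: List[Set],
                         syn: List[Set]) -> Dict[str, float]:
    records = members + non_members
    truth = [1] * len(members) + [0] * len(non_members)  # 1 = training member
    dists = [nn_distance(x, syn) for x in records]
    threshold = statistics.median(dists)
    # Guess "member" when a record is at least as close as the median distance.
    pred = [1 if d <= threshold else 0 for d in dists]

    pairs = list(zip(pred, truth))
    tp = sum(p == 1 and t == 1 for p, t in pairs)
    fp = sum(p == 1 and t == 0 for p, t in pairs)
    fn = sum(p == 0 and t == 1 for p, t in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    accuracy = sum(p == t for p, t in pairs) / len(pairs)
    return {"f1": f1, "precision": precision, "recall": recall,
            "accuracy": accuracy}


members = [{"a"}, {"b"}]                 # training patients
non_members = [{"x", "y"}, {"z", "w"}]   # held-out patients
syn = [{"a"}, {"b"}]                     # synthetic records that copy the members
print(membership_inference(members, non_members, syn))
```

Here the synthetic records copy the members exactly, so every metric comes out at 1.0. A private generator should push all four toward 0.5.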
- pyhealth.metrics.generative.compute_discriminator_privacy(train_fn, train_ehr, test_ehr, syn_ehr, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels', n_bootstraps=5, seed=4, **kwargs)
Computes a discriminator-based adversarial-accuracy privacy score.
A classifier is trained to predict whether a record is real (1) or synthetic (0). An accuracy near 0.5 means real and synthetic data are indistinguishable (good privacy); accuracy well above 0.5 means the synthetic data is easy to tell apart (poor privacy). The Privacy_Score rescales accuracy so that 1.0 is perfect privacy and 0.0 is none.

- Parameters:
  - train_fn (Callable) – A training function such as pyhealth.metrics.generative.utils.train_lstm_model() or train_sklearn_model. It must accept train_ehr, test_ehr and the four column-name arguments, and return (model, y_true, y_pred).
  - train_ehr (DataFrame) – Real training EHR dataframe, flat [id, time, visit_codes, labels] format.
  - test_ehr (DataFrame) – Real held-out test EHR dataframe (unused; kept for a uniform API with the other metrics); same schema as train_ehr.
  - syn_ehr (DataFrame) – Synthetic EHR dataframe; same schema as train_ehr.
  - subject_col (str) – Column name for patient/subject identifiers.
  - visit_col (str) – Column name for visit/timestep identifiers.
  - code_col (str) – Column name for the medical codes (one code per row).
  - label_col (str) – Column name for the original label (unused; the discriminator target replaces it).
  - n_bootstraps (int) – Number of bootstrap resamples of the predictions.
  - seed (int) – Random seed for the patient-level train/test split.
  - **kwargs – Extra keyword arguments forwarded to train_fn.
- Return type:
  dict
- Returns:
  Dictionary mapping "Privacy_Discriminator_Accuracy" and "Privacy_Score" to their (mean, std) across bootstraps.
Examples
>>> from pyhealth.metrics.generative import compute_discriminator_privacy
>>> from pyhealth.metrics.generative.utils import train_lstm_model
>>> # train_ehr, test_ehr, syn_ehr are flat
>>> # [id, time, visit_codes, labels] dataframes sharing one schema --
>>> # see evaluate_synthetic_ehr for how to build them.
>>> result = compute_discriminator_privacy(
...     train_lstm_model, train_ehr, test_ehr, syn_ehr
... )
>>> score_mean, score_std = result["Privacy_Score"]
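The bootstrap step and the accuracy-to-score mapping can be sketched as follows. The sketch reads "bootstrap resamples of the predictions" as resampling (y_true, y_pred) pairs with replacement and reports the population standard deviation.

`hypothetical_privacy_score` is an assumed linear rescaling. It meets the stated endpoints (1.0 at chance accuracy, 0.0 for a perfect discriminator), but PyHealth's actual formula is not given here and may differ. The RNG seed belongs to this sketch only and is unrelated to the `seed` parameter, which controls the patient-level split.

```python
import random
import statistics
from typing import List, Tuple


def bootstrap_accuracy(y_true: List[int], y_pred: List[int],
                       n_bootstraps: int = 5,
                       rng_seed: int = 0) -> Tuple[float, float]:
    """(mean, std) of accuracy over bootstrap resamples of the predictions."""
    rng = random.Random(rng_seed)
    n = len(y_true)
    accs = []
    for _ in range(n_bootstraps):
        idx = [rng.randrange(n) for _ in range(n)]  # resample with replacement
        accs.append(sum(y_true[i] == y_pred[i] for i in idx) / n)
    return statistics.mean(accs), statistics.pstdev(accs)


def hypothetical_privacy_score(accuracy: float) -> float:
    """ASSUMED rescaling: 1.0 at chance accuracy (0.5), 0.0 at 0.0 or 1.0."""
    return 1.0 - 2.0 * abs(accuracy - 0.5)


# Discriminator targets: 1 = real record, 0 = synthetic record.
y_true = [1, 1, 0, 0]
y_pred = [1, 1, 0, 0]  # perfect discriminator: synthetic data is easy to spot
acc_mean, acc_std = bootstrap_accuracy(y_true, y_pred)
print(acc_mean, acc_std, hypothetical_privacy_score(acc_mean))  # 1.0 0.0 0.0
```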
Utility and fidelity metrics
- pyhealth.metrics.generative.compute_mle(train_fn, train_ehr, test_ehr, syn_ehr, subject_col='id', visit_col='time', code_col='visit_codes', label_col='labels', n_bootstraps=5, **kwargs)
Computes Machine Learning Efficacy (utility) for synthetic data.
Two classifiers are trained on a next-visit prediction task: one on real training data (Train-Real-Test-Real, TRTR) and one on synthetic data (Train-Synthetic-Test-Real, TSTR). Both are evaluated on the same real test set. Synthetic accuracy/F1 close to real accuracy/F1 indicates high utility.
Note
The current implementation hard-codes the downstream task to next-visit prediction (built via build_next_visit_prediction_dataset()). This is degenerate for bag-of-codes generators such as MedGAN and CorGAN: they emit a single aggregate visit per patient, so every patient receives label 0. A future revision will let callers plug in static-label tasks (mortality, readmission, “ever diagnosed with X”, …) so that MLE is meaningful for both sequential (HALO, GPT2, PromptEHR) and bag-of-codes (MedGAN, CorGAN) generators.

- Parameters:
  - train_fn (Callable) – A training function such as pyhealth.metrics.generative.utils.train_lstm_model() or train_sklearn_model, returning (model, y_true, y_pred).
  - train_ehr (DataFrame) – Real training EHR dataframe, flat [id, time, visit_codes, labels] format.
  - test_ehr (DataFrame) – Real held-out test EHR dataframe; same schema as train_ehr.
  - syn_ehr (DataFrame) – Synthetic EHR dataframe; same schema as train_ehr.
  - subject_col (str) – Column name for patient/subject identifiers.
  - visit_col (str) – Column name for visit/timestep identifiers.
  - code_col (str) – Column name for the medical codes (one code per row).
  - label_col (str) – Column name for the label (overwritten by the next-visit prediction label).
  - n_bootstraps (int) – Number of bootstrap resamples of the predictions.
  - **kwargs – Extra keyword arguments forwarded to train_fn.
- Return type:
  dict
- Returns:
  Dictionary mapping the MLE metrics (real/synthetic accuracy and F1, their difference and ratio) to their (mean, std) across bootstraps.
Examples
>>> from pyhealth.metrics.generative.utility import compute_mle
>>> from pyhealth.metrics.generative.utils import train_lstm_model
>>> # train_ehr, test_ehr, syn_ehr are flat
>>> # [id, time, visit_codes, labels] dataframes sharing one schema --
>>> # see evaluate_synthetic_ehr for how to build them.
>>> result = compute_mle(train_lstm_model, train_ehr, test_ehr, syn_ehr)
>>> synth_acc_mean, synth_acc_std = result["MLE_Synth_Accuracy"]
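Once the TRTR and TSTR models have predicted on the shared real test set, the comparison reduces to a few scalar metrics. The sketch below assumes binary labels. It defines the difference as real minus synthetic and the ratio as synthetic over real. These conventions, the helper names and the dictionary keys are illustrative assumptions, not PyHealth's result keys.

```python
from typing import Dict, List


def accuracy(y_true: List[int], y_pred: List[int]) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1(y_true: List[int], y_pred: List[int]) -> float:
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    # F1 = 2TP / (2TP + FP + FN); 0.0 when there are no positives at all.
    return 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0


def mle_summary(y_true: List[int], pred_trtr: List[int],
                pred_tstr: List[int]) -> Dict[str, float]:
    """Compare a real-trained (TRTR) and a synthetic-trained (TSTR) model on
    the same real test labels. Key names are illustrative only."""
    real_acc, synth_acc = accuracy(y_true, pred_trtr), accuracy(y_true, pred_tstr)
    real_f1, synth_f1 = f1(y_true, pred_trtr), f1(y_true, pred_tstr)
    return {
        "real_acc": real_acc, "synth_acc": synth_acc,
        "acc_diff": real_acc - synth_acc,
        "acc_ratio": synth_acc / real_acc if real_acc else float("nan"),
        "real_f1": real_f1, "synth_f1": synth_f1,
        "f1_diff": real_f1 - synth_f1,
        "f1_ratio": synth_f1 / real_f1 if real_f1 else float("nan"),
    }


y_true = [1, 0, 1, 0]      # real test-set next-visit labels
pred_trtr = [1, 0, 1, 0]   # model trained on real data
pred_tstr = [1, 0, 0, 0]   # model trained on synthetic data
summary = mle_summary(y_true, pred_trtr, pred_tstr)
print(summary["acc_ratio"], round(summary["synth_f1"], 3))  # 0.75 0.667
```

A ratio near 1.0 means the synthetic-trained model loses little relative to the real-trained one.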
- pyhealth.metrics.generative.compute_prevalence_metrics(train_ehr, syn_ehr, subject_col='id', code_col='visit_codes', n_bootstraps=5)
Compares per-code patient-level prevalence of real vs synthetic data.
For every code, prevalence is the fraction of unique patients who have that code at least once. The real and synthetic prevalence vectors are compared with R-squared, Pearson correlation and RMSE; bootstrap resampling is over codes.
This metric only reads subject_col and code_col, but train_ehr and syn_ehr are expected to be the same flat [id, time, visit_codes, labels] frames used by the other metrics.

- Parameters:
  - train_ehr (DataFrame) – Real training EHR dataframe, flat [id, time, visit_codes, labels] format.
  - syn_ehr (DataFrame) – Synthetic EHR dataframe; same schema as train_ehr.
  - subject_col (str) – Column name for patient/subject identifiers.
  - code_col (str) – Column name for the medical codes (one code per row).
  - n_bootstraps (int) – Number of bootstrap resamples over codes.
- Return type:
  dict
- Returns:
  Dictionary mapping "Prevalence_R2", "Prevalence_Pearson" and "Prevalence_RMSE" to their (mean, std) across bootstraps.
Examples
>>> from pyhealth.metrics.generative.utility import (
...     compute_prevalence_metrics,
... )
>>> # train_ehr and syn_ehr are flat [id, time, visit_codes, labels]
>>> # dataframes sharing one schema -- see evaluate_synthetic_ehr for
>>> # how to build them.
>>> result = compute_prevalence_metrics(train_ehr, syn_ehr)
>>> r2_mean, r2_std = result["Prevalence_R2"]
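The prevalence comparison is simple enough to sketch directly on flat rows. Function names are hypothetical, and the sketch makes these assumptions:
- a code absent from one dataset gets prevalence 0 there;
- R^2 is the coefficient of determination with real prevalence as the target and synthetic prevalence as the prediction, so it can be negative;
- the reported std is the population standard deviation;
- degenerate bootstrap draws, where every sampled code has the same real prevalence, give NaN for R^2 and Pearson.

```python
import math
import random
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def prevalence(rows: Iterable[dict], subject_col: str = "id",
               code_col: str = "visit_codes") -> Dict[object, float]:
    """Fraction of unique patients that have each code at least once."""
    rows = list(rows)
    n_patients = len({r[subject_col] for r in rows})
    holders = defaultdict(set)  # code -> set of patients who have it
    for r in rows:
        holders[r[code_col]].add(r[subject_col])
    return {code: len(p) / n_patients for code, p in holders.items()}


def compare(x: List[float], y: List[float]) -> Tuple[float, float, float]:
    """(R^2, Pearson r, RMSE) of synthetic prevalence y against real x."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    ss_res = sum((a - b) ** 2 for a, b in zip(x, y))
    r2 = 1 - ss_res / sxx if sxx else math.nan
    pearson = sxy / math.sqrt(sxx * syy) if sxx and syy else math.nan
    rmse = math.sqrt(ss_res / len(x))
    return r2, pearson, rmse


def mean_std(values: List[float]) -> Tuple[float, float]:
    m = sum(values) / len(values)
    return m, math.sqrt(sum((v - m) ** 2 for v in values) / len(values))


def prevalence_metrics(real_rows, syn_rows, n_bootstraps: int = 5, seed: int = 0):
    real, syn = prevalence(real_rows), prevalence(syn_rows)
    codes = sorted(set(real) | set(syn), key=str)  # absent code => prevalence 0
    x = [real.get(c, 0.0) for c in codes]
    y = [syn.get(c, 0.0) for c in codes]
    rng = random.Random(seed)
    draws = []
    for _ in range(n_bootstraps):
        idx = [rng.randrange(len(codes)) for _ in codes]  # resample codes
        draws.append(compare([x[i] for i in idx], [y[i] for i in idx]))
    return {name: mean_std(list(col))
            for name, col in zip(("R2", "Pearson", "RMSE"), zip(*draws))}


real_rows = [
    {"id": "p0", "time": 0, "visit_codes": "428.0", "labels": 0},
    {"id": "p0", "time": 0, "visit_codes": "250.00", "labels": 0},
    {"id": "p0", "time": 1, "visit_codes": "401.9", "labels": 0},
    {"id": "p1", "time": 0, "visit_codes": "428.0", "labels": 0},
]
syn_rows = [
    {"id": "s0", "time": 0, "visit_codes": "428.0", "labels": 0},
    {"id": "s1", "time": 0, "visit_codes": "428.0", "labels": 0},
    {"id": "s1", "time": 0, "visit_codes": "401.9", "labels": 0},
]
print(prevalence_metrics(real_rows, syn_rows)["RMSE"])
```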