# Source code for pyhealth.models.generators.medgan

"""MedGAN: Medical Generative Adversarial Network for synthetic EHR generation.

This is a port of the reference implementation
(https://github.com/mp2893/medgan and the PyTorch reimplementation under
``reference/cor-gan/Generative/medGAN/MIMIC/pytorch/MLP/medGAN.py``) wrapped
as a PyHealth ``BaseModel`` so it consumes the standard
``dataset -> SampleDataset -> model`` pipeline.

MedGAN treats each patient as a flat bag-of-codes (no visit structure), so it
expects an input feature named ``visits`` backed by a ``MultiHotProcessor``.
The training procedure has two phases (mirroring the reference):

* a **linear autoencoder** is pre-trained with binary cross-entropy
  reconstruction loss, and
* an **adversarial training** phase where the generator emits latent codes,
  the autoencoder's decoder projects them back to a multi-hot patient vector,
  and a discriminator with optional minibatch averaging tries to distinguish
  real from synthetic.

The ``MedGANAutoencoder``, ``MedGANGenerator`` and ``MedGANDiscriminator``
modules below mirror the reference. The public ``MedGAN`` class follows the
same API style as :class:`pyhealth.models.generators.HALO`
(``train_model`` / ``generate`` / ``save_model`` / ``load_model``).
"""

import os
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, RandomSampler
from tqdm import tqdm

from pyhealth.models import BaseModel


# ----------------------------------------------------------------------------
# Building blocks (ported from reference medgan.py / PyTorch reimplementation)
# ----------------------------------------------------------------------------
class _MultiHotDataset(Dataset):
    """Map-style dataset that serves the rows of a multi-hot matrix.

    The incoming matrix is converted to ``float32`` once at construction
    time; indexing returns the selected row wrapped with
    ``torch.from_numpy``.

    Args:
        data: Numpy array whose first axis enumerates the samples.
    """

    def __init__(self, data: np.ndarray):
        float_matrix = data.astype(np.float32)
        self.data = float_matrix

    def __len__(self) -> int:
        # Number of samples == length of the first axis.
        n_rows = len(self.data)
        return n_rows

    def __getitem__(self, idx):
        row = self.data[idx]
        return torch.from_numpy(row)


class MedGANAutoencoder(nn.Module):
    """Single-layer autoencoder used for the MedGAN pretraining phase.

    The encoder is ``Linear -> Tanh`` and the decoder is
    ``Linear -> Sigmoid``, following the reference implementation.

    Args:
        input_dim: Vocabulary size (number of distinct codes).
        embedding_dim: Latent dimensionality. Default: 128.
    """

    def __init__(self, input_dim: int, embedding_dim: int = 128):
        super().__init__()
        # The encoder is created before the decoder so the order in which
        # parameters are instantiated (and therefore seeded) is preserved.
        encoder_layers = [nn.Linear(input_dim, embedding_dim), nn.Tanh()]
        self.encoder = nn.Sequential(*encoder_layers)
        decoder_layers = [nn.Linear(embedding_dim, input_dim), nn.Sigmoid()]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Project ``x`` into the latent space."""
        latent = self.encoder(x)
        return latent

    def decode(self, x):
        """Project a latent code ``x`` back to the input space."""
        reconstruction = self.decoder(x)
        return reconstruction

    def forward(self, x):
        """Encode then decode ``x`` and return the reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)


class MedGANGenerator(nn.Module):
    """Residual two-layer MLP generator, following the reference.

    Each layer is ``Linear -> BatchNorm1d -> activation`` and its output is
    added to that layer's input (ReLU for the first layer, Tanh for the
    second).

    Args:
        latent_dim: Width of the noise input. Default: 128.
        hidden_dim: Width of both layers. Default: 128.
    """

    def __init__(self, latent_dim: int = 128, hidden_dim: int = 128):
        super().__init__()
        bn_kwargs = {"eps": 0.001, "momentum": 0.01}

        # First residual layer.
        self.linear1 = nn.Linear(latent_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim, **bn_kwargs)
        self.act1 = nn.ReLU()

        # Second residual layer.
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim, **bn_kwargs)
        self.act2 = nn.Tanh()

    def forward(self, x):
        """Map noise ``x`` to a latent code through two residual layers."""
        hidden = self.linear1(x)
        hidden = self.bn1(hidden)
        first = self.act1(hidden) + x

        hidden = self.linear2(first)
        hidden = self.bn2(hidden)
        return self.act2(hidden) + first


class MedGANDiscriminator(nn.Module):
    """Three-layer MLP discriminator with optional minibatch averaging.

    When ``minibatch_averaging`` is enabled, the per-feature mean of the
    batch is appended to every sample before the MLP, so the first linear
    layer takes ``2 * input_dim`` features.

    Args:
        input_dim: Width of one (real or synthetic) sample.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            uses ``hidden_dim // 2``. Default: 256.
        minibatch_averaging: Whether to append the batch mean. Default: True.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        minibatch_averaging: bool = True,
    ):
        super().__init__()
        self.minibatch_averaging = minibatch_averaging

        if minibatch_averaging:
            first_in = input_dim * 2
        else:
            first_in = input_dim
        half_hidden = hidden_dim // 2

        layers = [
            nn.Linear(first_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the probability that each row of ``x`` is real."""
        if not self.minibatch_averaging:
            return self.model(x)

        # Tile the batch mean so that every sample carries it alongside its
        # own features, as the reference implementation does.
        batch_mean = x.mean(dim=0)
        tiled_mean = batch_mean.unsqueeze(0).expand(x.size(0), -1)
        augmented = torch.cat((x, tiled_mean), dim=1)
        return self.model(augmented)


def _weights_init(m):
    """Initialise one module in place; meant for ``nn.Module.apply``.

    ``nn.Linear`` gets Xavier-uniform weights and a zero bias (when it has
    one). ``nn.BatchNorm1d`` gets weights drawn from N(1, 0.02) and a zero
    bias. Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is None:
            return
        nn.init.zeros_(m.bias)
        return

    if isinstance(m, nn.BatchNorm1d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)


def _autoencoder_loss(x_output, y_target):
    """Binary cross-entropy summed over features and averaged over the batch.

    This equals ``BCELoss(reduction='sum') / batch_size`` and matches the
    reference. ``BCELoss(reduction='mean')`` would additionally average over
    the feature axis, which dilutes the signal for sparse code vectors.

    Args:
        x_output: Reconstructed probabilities; dim 1 is the feature axis.
        y_target: Binary targets with the same shape as ``x_output``.

    Returns:
        The loss reduced over dim 1 (sum) and then dim 0 (mean).
    """
    epsilon = 1e-12  # keeps log() finite at exactly 0 or 1
    positive_part = y_target * torch.log(x_output + epsilon)
    negative_part = (1.0 - y_target) * torch.log(1.0 - x_output + epsilon)
    log_likelihood = positive_part + negative_part
    per_sample = -torch.sum(log_likelihood, dim=1)
    return torch.mean(per_sample, dim=0)


# ----------------------------------------------------------------------------
# PyHealth BaseModel wrapper
# ----------------------------------------------------------------------------
class MedGAN(BaseModel):
    """MedGAN synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``.

    Generates synthetic binary EHR records via the two-phase procedure from
    Choi et al. (MLHC 2017): pretrain a linear autoencoder, then run BCE-GAN
    adversarial training where the generator maps noise to the autoencoder's
    latent space and the decoder projects back to a multi-hot patient vector.

    Generation is **unconditional**: each synthetic patient is a flat bag of
    codes (no visit structure), matching the ``multi_hot`` input schema.

    Args:
        dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains
            ``{"visits": "multi_hot"}`` and whose ``output_schema`` is empty.
        latent_dim: Generator noise dimensionality. Default: 128. The
            generator's residual connection requires
            ``latent_dim == hidden_dim``; if they differ, ``latent_dim`` is
            silently aligned to ``hidden_dim``.
        hidden_dim: Generator hidden width (also the autoencoder embedding
            dimension). Default: 128.
        discriminator_hidden_dim: Discriminator hidden width. Default: 256.
        minibatch_averaging: Concatenate per-batch mean to each discriminator
            input. Default: True.
        batch_size: Training batch size. Default: 512.
        ae_epochs: Autoencoder pre-training epochs. Default: 100.
        gan_epochs: Adversarial training epochs. Default: 200.
        ae_lr: Autoencoder learning rate. Default: 1e-3.
        gan_lr: GAN learning rate. Default: 1e-3.
        save_dir: Checkpoint directory used by ``train_model``.
            Default: ``"./save/medgan/"``.

    Raises:
        ValueError: If ``dataset`` has no ``visits`` input processor.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {"patient_id": "p1", "visits": ["A", "B", "C"]},
        ...     {"patient_id": "p2", "visits": ["A", "C", "D"]},
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"visits": "multi_hot"},
        ...     output_schema={},
        ... )
        >>> model = MedGAN(dataset, latent_dim=16, hidden_dim=16, batch_size=2)
        >>> isinstance(model, MedGAN)
        True
    """

    def __init__(
        self,
        dataset,
        latent_dim: int = 128,
        hidden_dim: int = 128,
        discriminator_hidden_dim: int = 256,
        minibatch_averaging: bool = True,
        batch_size: int = 512,
        ae_epochs: int = 100,
        gan_epochs: int = 200,
        ae_lr: float = 1e-3,
        gan_lr: float = 1e-3,
        save_dir: str = "./save/medgan/",
    ) -> None:
        super().__init__(dataset)

        if "visits" not in dataset.input_processors:
            raise ValueError(
                "MedGAN expects an input feature named 'visits' backed by a "
                "MultiHotProcessor."
            )

        # The generator's residual connection (``out + residual`` with
        # ``residual`` being the noise input) requires latent_dim == hidden_dim.
        # Align silently if the user mismatched, mirroring CorGAN.
        if latent_dim != hidden_dim:
            latent_dim = hidden_dim

        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self._batch_size = batch_size
        self._ae_epochs = ae_epochs
        self._gan_epochs = gan_epochs
        self._ae_lr = ae_lr
        self._gan_lr = gan_lr
        self.save_dir = save_dir

        # Code vocab from the MultiHotProcessor's label_vocab.
        self.visits_processor = dataset.input_processors["visits"]
        self.input_dim = self.visits_processor.size()
        self._idx_to_code: List[Optional[str]] = [None] * self.input_dim
        for code, idx in self.visits_processor.label_vocab.items():
            self._idx_to_code[idx] = code

        self.autoencoder = MedGANAutoencoder(
            input_dim=self.input_dim,
            embedding_dim=hidden_dim,
        )
        self.generator = MedGANGenerator(
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
        )
        self.discriminator = MedGANDiscriminator(
            input_dim=self.input_dim,
            hidden_dim=discriminator_hidden_dim,
            minibatch_averaging=minibatch_averaging,
        )

        self.autoencoder.apply(_weights_init)
        self.generator.apply(_weights_init)
        self.discriminator.apply(_weights_init)

    # ------------------------------------------------------------------
    # forward -- required by BaseModel
    # ------------------------------------------------------------------
    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """MedGAN does not have a single supervised forward pass.

        Use :meth:`train_model` for training and :meth:`generate` for
        synthesis. ``forward`` is implemented only to satisfy the
        ``BaseModel`` abstract contract.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "MedGAN is a GAN: use train_model() and generate() instead of "
            "forward()."
        )

    # ------------------------------------------------------------------
    # Custom training loop
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_device(device=None) -> torch.device:
        """Resolve a user-supplied device, defaulting to CUDA when available."""
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    def _build_dataloader(self, dataset) -> DataLoader:
        """Stack the multi-hot tensors of ``dataset`` into a DataLoader.

        The fitted ``MultiHotProcessor`` has already converted each patient's
        ``visits`` field into a ``(input_dim,)`` float32 tensor, so we can
        simply stack and wrap.

        Batches are drawn with replacement and incomplete batches are
        dropped. To make sure a dataset smaller than ``batch_size`` still
        produces a batch, the sampler draws at least ``batch_size`` indices
        per epoch.

        Args:
            dataset: Indexable dataset whose items expose a ``"visits"``
                tensor.

        Returns:
            A ``DataLoader`` yielding ``(batch_size, input_dim)`` float32
            batches.

        Raises:
            ValueError: If ``dataset`` is empty.
        """
        n_samples = len(dataset)
        if n_samples == 0:
            raise ValueError("MedGAN cannot be trained on an empty dataset.")

        tensors = [dataset[i]["visits"] for i in range(n_samples)]
        matrix = torch.stack(tensors).numpy()
        wrapped = _MultiHotDataset(matrix)
        # With ``drop_last=True`` a sampler that only draws ``n_samples``
        # indices yields zero batches whenever n_samples < batch_size, which
        # would make training a silent no-op. Sampling is with replacement
        # already, so draw enough indices for one full batch in that case.
        sampler = RandomSampler(
            wrapped,
            replacement=True,
            num_samples=max(n_samples, self._batch_size),
        )
        return DataLoader(
            wrapped,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=0,
            drop_last=True,
            sampler=sampler,
        )

    def train_model(self, train_dataset, val_dataset=None, device=None) -> None:
        """Train MedGAN with a custom two-phase loop.

        Named ``train_model`` (not ``train``) to avoid shadowing
        ``nn.Module.train()``. Phase 1 pre-trains the autoencoder with
        sparse-friendly BCE reconstruction loss; phase 2 runs standard
        BCE-GAN adversarial training where the generator+decoder are
        optimised against a binary discriminator.

        Checkpoints are written to ``save_dir``: ``best.pt`` whenever the
        epoch-average discriminator loss reaches a new minimum, and
        ``final.pt`` once training ends.

        Args:
            train_dataset: ``SampleDataset`` for training.
            val_dataset: Unused; accepted for API symmetry with other
                PyHealth trainers.
            device: Device to train on (``"cuda"``, ``"cpu"``, etc.). If
                ``None``, uses CUDA when available.
        """
        device = self._resolve_device(device)
        self.to(device)
        print(f"Training MedGAN on: {device}")

        os.makedirs(self.save_dir, exist_ok=True)
        dataloader = self._build_dataloader(train_dataset)

        # ---- Phase 1: Autoencoder pretraining ----
        optimizer_ae = torch.optim.Adam(
            self.autoencoder.parameters(), lr=self._ae_lr
        )
        ae_bar = tqdm(range(self._ae_epochs), desc="AE pretrain")
        for _ in ae_bar:
            self.autoencoder.train()
            total_loss, n_batches = 0.0, 0
            for batch in dataloader:
                real = batch.to(self.device)
                recon = self.autoencoder(real)
                loss = _autoencoder_loss(recon, real)
                optimizer_ae.zero_grad()
                loss.backward()
                optimizer_ae.step()
                total_loss += loss.item()
                n_batches += 1
            # Surface the epoch-average reconstruction loss on the bar.
            ae_bar.set_postfix(loss=total_loss / max(n_batches, 1))

        # ---- Phase 2: Adversarial training ----
        # Generator + the autoencoder's decoder are trained jointly, matching
        # the reference (the decoder is what makes synthetic samples valid).
        optimizer_g = torch.optim.Adam(
            list(self.generator.parameters())
            + list(self.autoencoder.decoder.parameters()),
            lr=self._gan_lr,
        )
        optimizer_d = torch.optim.Adam(
            self.discriminator.parameters(), lr=self._gan_lr
        )

        best_d_loss = float("inf")
        gan_bar = tqdm(range(self._gan_epochs), desc="GAN train")
        for _ in gan_bar:
            self.generator.train()
            self.discriminator.train()
            self.autoencoder.eval()
            self.autoencoder.decoder.train()
            epoch_d_loss, epoch_g_loss, n_batches = 0.0, 0.0, 0
            for batch in dataloader:
                real = batch.to(self.device)
                bs = real.size(0)

                # --- Train Discriminator ---
                optimizer_d.zero_grad()
                noise = torch.randn(bs, self.latent_dim, device=self.device)
                fake = self.autoencoder.decode(self.generator(noise))
                real_pred = self.discriminator(real)
                # detach(): the discriminator step must not update the
                # generator or decoder.
                fake_pred = self.discriminator(fake.detach())
                d_loss = F.binary_cross_entropy(
                    real_pred, torch.ones_like(real_pred)
                ) + F.binary_cross_entropy(
                    fake_pred, torch.zeros_like(fake_pred)
                )
                d_loss.backward()
                optimizer_d.step()

                # --- Train Generator (+ decoder) ---
                optimizer_g.zero_grad()
                fake_pred = self.discriminator(fake)
                g_loss = F.binary_cross_entropy(
                    fake_pred, torch.ones_like(fake_pred)
                )
                g_loss.backward()
                optimizer_g.step()

                epoch_d_loss += d_loss.item()
                epoch_g_loss += g_loss.item()
                n_batches += 1

            avg_d = epoch_d_loss / max(n_batches, 1)
            avg_g = epoch_g_loss / max(n_batches, 1)
            gan_bar.set_postfix(d_loss=avg_d, g_loss=avg_g)
            if avg_d < best_d_loss:
                best_d_loss = avg_d
                self.save_model(os.path.join(self.save_dir, "best.pt"))

        self.save_model(os.path.join(self.save_dir, "final.pt"))

    # ------------------------------------------------------------------
    # Synthesis
    # ------------------------------------------------------------------
    def generate(
        self,
        num_samples: int,
        random_sampling: bool = False,
        device=None,
    ) -> List[Dict]:
        """Generate synthetic patient records.

        Each synthetic patient is decoded from a generated multi-hot vector
        by thresholding (or, optionally, Bernoulli sampling) at 0.5 and
        mapping the indices back to code strings.

        Args:
            num_samples: Number of synthetic patients to generate.
            random_sampling: If True, Bernoulli-sample the decoder output;
                otherwise threshold at 0.5 (the reference's behaviour).
                Default: False.
            device: Device to generate on. If ``None``, uses CUDA when
                available.

        Returns:
            List of dicts
            ``{"patient_id": "synthetic_i", "visits": [[code, ...]]}``.
            ``visits`` is a list containing a **single** visit (matching
            HALO's nested-list output structure). MedGAN is a bag-of-codes
            model -- following the reference ``process_mimic.py``, each
            patient is represented by the union of codes across all of their
            historical visits -- so the single inner list is that aggregate
            bag. The inner list may be empty if the generator produced an
            all-zero vector.
        """
        device = self._resolve_device(device)
        self.to(device)
        self.generator.eval()
        self.autoencoder.eval()

        bs = min(self._batch_size, max(num_samples, 1))
        rows = np.zeros((num_samples, self.input_dim), dtype=np.float32)

        # The context manager closes the progress bar even if sampling fails.
        with tqdm(total=num_samples, desc="Generating patients") as pbar:
            with torch.no_grad():
                for start in range(0, num_samples, bs):
                    cur = min(bs, num_samples - start)
                    z = torch.randn(cur, self.latent_dim, device=self.device)
                    probs = self.autoencoder.decode(self.generator(z))
                    if random_sampling:
                        sample = torch.bernoulli(probs)
                    else:
                        sample = (probs >= 0.5).float()
                    rows[start : start + cur] = sample.cpu().numpy()
                    pbar.update(cur)

        # Vocabulary slots that never correspond to a real code.
        skipped = (None, "<pad>", "<unk>")
        results: List[Dict] = []
        for i in range(num_samples):
            codes = [
                self._idx_to_code[idx]
                for idx in np.nonzero(rows[i])[0]
                if self._idx_to_code[idx] not in skipped
            ]
            # Wrap in a single-visit list to mirror HALO's nested output.
            # MedGAN models the patient as one aggregate bag of codes.
            results.append({"patient_id": f"synthetic_{i}", "visits": [codes]})
        return results

    # ------------------------------------------------------------------
    # Checkpoint I/O
    # ------------------------------------------------------------------
    def save_model(self, path: str) -> None:
        """Save weights and the code vocabulary needed for decoding.

        The parent directory of ``path`` is created when it does not exist.

        Args:
            path: Destination checkpoint file.
        """
        # Only path-like destinations have a parent directory to create.
        if isinstance(path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(path))
            if parent:
                os.makedirs(parent, exist_ok=True)

        torch.save(
            {
                "autoencoder": self.autoencoder.state_dict(),
                "generator": self.generator.state_dict(),
                "discriminator": self.discriminator.state_dict(),
                "input_dim": self.input_dim,
                "latent_dim": self.latent_dim,
                "idx_to_code": self._idx_to_code,
            },
            path,
        )

    def load_model(self, path: str) -> None:
        """Load weights and the code vocabulary from a checkpoint.

        Args:
            path: Checkpoint file previously written by :meth:`save_model`.
        """
        # NOTE(review): ``torch.load`` unpickles the file; only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(path, map_location=self.device)
        self.autoencoder.load_state_dict(ckpt["autoencoder"])
        self.generator.load_state_dict(ckpt["generator"])
        self.discriminator.load_state_dict(ckpt["discriminator"])
        # Checkpoints without a stored vocabulary keep the one built from
        # the dataset in ``__init__``.
        if "idx_to_code" in ckpt:
            self._idx_to_code = ckpt["idx_to_code"]