# Source code for pyhealth.models.generators.promptehr

"""PromptEHR: prompt-learning BART for synthetic EHR generation.

This is a PyHealth ``BaseModel`` port of PromptEHR (Wang & Sun, EMNLP'22,
https://github.com/RyanWangZf/PromptEHR), wrapped so it consumes the standard
``dataset -> set_task -> SampleDataset -> model`` pipeline and shares the same
:class:`~pyhealth.tasks.EHRGeneration` task as
:class:`~pyhealth.models.HALO` and :class:`~pyhealth.models.GPT2`.

PromptEHR treats sequential EHRs as a *neural database* and learns to fill in
patient records with a conditional **BART** (sequence-to-sequence denoising
autoencoder) trained with **prompt learning**. The three ideas that define the
reference implementation are preserved here:

* **BART seq2seq core.** Generation is encoder-decoder, not decoder-only. The
  reference subclasses ``BartForEHRSimulation`` from ``BartPretrainedModel``;
  this port wraps :class:`transformers.BartForConditionalGeneration`, mirroring
  the way :class:`~pyhealth.models.GPT2` wraps ``GPT2LMHeadModel``.
* **Prompt learning.** The reference reparameterizes a learnable prompt from
  patient baseline demographics and prepends it to the encoder/decoder
  (``ConditionalPrompt``). PyHealth's :class:`~pyhealth.tasks.EHRGeneration`
  task is *unconditional* (only ``visits``, no baseline features -- exactly like
  HALO/GPT2), so the prompt reduces to a learnable continuous **soft prefix**
  prepended to the encoder. This is the prompt-tuning core without the
  demographic reparameterization.
* **Span-infilling objective.** The reference learns by masking spans of codes
  and reconstructing them. Here the encoder sees a BART-style span-infilled
  copy of the patient's code stream -- random non-overlapping spans with
  lengths drawn from ``Poisson(mean_span_len)`` are each replaced by a single
  ``[MASK]`` sentinel until roughly ``mask_prob`` of the stream is covered --
  and the decoder reconstructs the full stream. This matches the original BART
  text-infilling objective used by PromptEHR rather than per-token masking.

Each patient's visits are serialized into a single code stream::

    [CODE_PROMPT]  <codes of visit 1>  [VISIT_DELIM]  <codes of visit 2>  ...  [EOS]

The reference handles several code types (diagnosis / procedure / drug / lab)
each with its own modality prompt token; the PyHealth ``EHRGeneration`` task
exposes a single ``visits`` modality, so a single ``[CODE_PROMPT]`` token marks
it. The code vocabulary is taken from the dataset's
``NestedSequenceProcessor`` (which already reserves index 0 for ``<pad>`` and
index 1 for ``<unk>``); five special tokens (BOS, EOS, VISIT_DELIM, MASK,
CODE_PROMPT) are appended, and ``<pad>`` (index 0) is reused as the pad token.
"""

import os
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import BartConfig, BartForConditionalGeneration

from pyhealth.datasets import get_dataloader
from pyhealth.models import BaseModel


class PromptEHR(BaseModel):
    """PromptEHR synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``.

    Trains a BART denoising autoencoder with a learnable soft prompt on patient
    visit-code streams, then generates synthetic patients by prompt-conditioned
    encoder-decoder sampling. Generation is **unconditional** (no demographic
    conditioning), matching the :class:`~pyhealth.tasks.EHRGeneration` task.

    Args:
        dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains
            ``{"visits": NestedSequenceProcessor}`` and whose
            ``output_schema`` is empty.
        embed_dim: BART model dimension (``d_model``). Must be divisible by
            ``n_heads``. Default: 256.
        n_heads: Number of attention heads (encoder and decoder). Default: 8.
        n_layers: Number of encoder and decoder layers each. Default: 6.
        ffn_dim: Feed-forward dimension. Default: 4 * ``embed_dim``.
        prompt_length: Number of learnable soft-prompt positions prepended to
            the encoder. Default: 8.
        max_len: Maximum code-stream length; streams are truncated to this
            length. BART's position table is sized
            ``max_len + prompt_length + 1`` so that the soft prompt and the
            leading ``[BOS]`` also fit in the encoder. Default: 512.
        mask_prob: Target fraction of the (non-sentinel) code stream covered
            by masked spans in the encoder input. Default: 0.15.
        mean_span_len: Mean of the Poisson distribution used to sample span
            lengths for BART-style span infilling. Default: 3.0.
        batch_size: Training batch size. Default: 16.
        epochs: Number of training epochs. Default: 50.
        lr: Learning rate for the Adam optimizer. Default: 1e-4.
        save_dir: Directory for checkpoints written by ``train_model``.
            Default: ``"./save/"``.

    Raises:
        ValueError: If ``dataset`` has no ``"visits"`` input processor.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {"patient_id": "p1", "visits": [["A", "B"], ["C"]]},
        ...     {"patient_id": "p2", "visits": [["A"], ["B", "C"]]},
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"visits": "nested_sequence"},
        ...     output_schema={},
        ... )
        >>> model = PromptEHR(
        ...     dataset, embed_dim=16, n_heads=2, n_layers=2, max_len=64
        ... )
        >>> isinstance(model, PromptEHR)
        True
    """

    def __init__(
        self,
        dataset,
        embed_dim: int = 256,
        n_heads: int = 8,
        n_layers: int = 6,
        ffn_dim: Optional[int] = None,
        prompt_length: int = 8,
        max_len: int = 512,
        mask_prob: float = 0.15,
        mean_span_len: float = 3.0,
        batch_size: int = 16,
        epochs: int = 50,
        lr: float = 1e-4,
        save_dir: str = "./save/",
    ) -> None:
        super().__init__(dataset)

        if "visits" not in dataset.input_processors:
            raise ValueError(
                "PromptEHR expects an input feature named 'visits' backed by a "
                "NestedSequenceProcessor."
            )

        self.save_dir = save_dir
        self._batch_size = batch_size
        self._epochs = epochs
        self._lr = lr
        self.max_len = max_len
        self.mask_prob = mask_prob
        self.mean_span_len = mean_span_len
        self.prompt_length = prompt_length

        # Code vocab from the NestedSequenceProcessor (includes <pad>=0, <unk>=1).
        self.visits_processor = dataset.input_processors["visits"]
        self.code_vocab_size = self.visits_processor.vocab_size()

        # Append five special tokens after the code vocab; reuse <pad>=0 as PAD.
        self.bos_id = self.code_vocab_size
        self.eos_id = self.code_vocab_size + 1
        self.delim_id = self.code_vocab_size + 2  # visit separator
        self.mask_id = self.code_vocab_size + 3  # denoising mask token
        self.code_prompt_id = self.code_vocab_size + 4  # modality prompt token
        self.pad_id = 0
        total_vocab_size = self.code_vocab_size + 5

        ffn_dim = ffn_dim if ffn_dim is not None else 4 * embed_dim
        config = BartConfig(
            vocab_size=total_vocab_size,
            # The encoder consumes `prompt_length` soft-prompt slots plus a
            # leading [BOS] on top of a stream of up to `max_len` tokens, so
            # the learned position table must cover all of them. Sizing it to
            # `max_len` alone overflows on near-maximum-length streams whose
            # masked spans do not compress the input enough.
            # NOTE: checkpoints saved with the smaller table will not load.
            max_position_embeddings=max_len + prompt_length + 1,
            d_model=embed_dim,
            encoder_layers=n_layers,
            decoder_layers=n_layers,
            encoder_attention_heads=n_heads,
            decoder_attention_heads=n_heads,
            encoder_ffn_dim=ffn_dim,
            decoder_ffn_dim=ffn_dim,
            pad_token_id=self.pad_id,
            bos_token_id=self.bos_id,
            eos_token_id=self.eos_id,
            decoder_start_token_id=self.eos_id,  # BART convention
            forced_bos_token_id=None,
            forced_eos_token_id=None,
        )

        # Registered as sub-modules so .parameters()/.to() work.
        self.bart = BartForConditionalGeneration(config)
        # Learnable soft prompt (prompt learning), prepended to the encoder.
        self.soft_prompt = nn.Parameter(torch.zeros(prompt_length, embed_dim))
        nn.init.normal_(self.soft_prompt, std=config.init_std)

    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_device(device=None) -> torch.device:
        """Resolve a user-supplied device, defaulting to CUDA when available."""
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    # ------------------------------------------------------------------
    # Visit index tensor -> denoising seq2seq tensors
    # ------------------------------------------------------------------
    def _serialize(self, visits: torch.Tensor) -> List[List[int]]:
        """Serialize each patient's visits into a flat code stream.

        Layout (decoder target):
        ``[CODE_PROMPT] codes_v1 [VISIT_DELIM] codes_v2 ... [EOS]``.
        Index 0 (``<pad>``) is skipped, and visits with no codes left after
        that are skipped entirely, so an empty visit row never produces a
        doubled delimiter or hides a later visit.

        Args:
            visits: Padded index tensor, indexed as
                ``[patient, visit, code]``.

        Returns:
            One token-id list per patient, at most ``self.max_len`` long and
            always ending in ``[EOS]``.
        """
        streams: List[List[int]] = []
        # A single .tolist() moves the whole batch to the host at once rather
        # than synchronizing with the device per patient and per visit.
        for patient in visits.tolist():
            seq: List[int] = [self.code_prompt_id]
            emitted_visit = False
            for visit in patient:
                codes = [int(c) for c in visit if c > 0]
                if not codes:
                    continue
                if emitted_visit:
                    seq.append(self.delim_id)
                seq.extend(codes)
                emitted_visit = True
            seq.append(self.eos_id)
            # Truncate but always keep the trailing EOS.
            if len(seq) > self.max_len:
                seq = seq[: self.max_len - 1] + [self.eos_id]
            streams.append(seq)
        return streams

    def _corrupt(self, stream: List[int]) -> List[int]:
        """Build the encoder input: BOS + a BART span-infilled copy of the stream.

        Selects random non-overlapping spans inside the stream (excluding the
        leading ``[CODE_PROMPT]`` modality marker and the trailing ``[EOS]``)
        with lengths drawn from ``Poisson(mean_span_len)`` until roughly
        ``mask_prob`` of the stream is covered, then replaces each span with a
        single ``[MASK]`` sentinel. Visit separators and code tokens are both
        eligible for masking, matching the original BART text-infilling
        objective used by PromptEHR.

        Args:
            stream: A serialized stream as produced by ``_serialize``.

        Returns:
            ``[BOS]`` followed by the corrupted stream. The stream is returned
            uncorrupted when it has no maskable interior, when the masking
            budget rounds to zero, or when no span could be placed.
        """
        bos = [self.bos_id]
        n = len(stream)
        # Inner range excludes the leading [CODE_PROMPT] and trailing [EOS].
        inner_lo, inner_hi = 1, n - 1
        inner_len = inner_hi - inner_lo
        if inner_len <= 0:
            return bos + list(stream)

        target_masked = int(round(self.mask_prob * inner_len))
        if target_masked <= 0:
            return bos + list(stream)

        spans: List[tuple] = []  # (start, end), half-open, in stream coords
        masked = 0
        # Cap attempts to avoid pathological loops on tiny / fully-packed streams.
        for _ in range(4 * inner_len):
            if masked >= target_masked:
                break
            # Poisson can return 0; a span must cover at least one token and
            # cannot be longer than the maskable interior.
            span_len = max(1, int(np.random.poisson(self.mean_span_len)))
            span_len = min(span_len, inner_len)
            # randint's upper bound is exclusive, so `end` never passes inner_hi.
            start = int(np.random.randint(inner_lo, inner_hi - span_len + 1))
            end = start + span_len
            if any(start < e and end > s for (s, e) in spans):
                continue  # overlaps an accepted span; resample
            spans.append((start, end))
            masked += span_len

        if not spans:
            return bos + list(stream)

        spans.sort()
        out: List[int] = bos + list(stream[:inner_lo])
        cur = inner_lo
        for (s, e) in spans:
            out.extend(stream[cur:s])
            out.append(self.mask_id)  # the whole span collapses to one sentinel
            cur = e
        out.extend(stream[cur:inner_hi])
        out.extend(stream[inner_hi:])
        return out

    def _encode_batch(self, visits: torch.Tensor):
        """Convert padded visit indices to encoder inputs and decoder labels.

        Args:
            visits: Padded visit index tensor, as accepted by ``_serialize``.

        Returns:
            enc_input_ids: LongTensor ``(batch, L_enc)`` corrupted streams.
            enc_attention_mask: LongTensor ``(batch, L_enc)``.
            labels: LongTensor ``(batch, L_dec)`` full streams, padding -> -100.
        """
        streams = self._serialize(visits)
        enc_streams = [self._corrupt(s) for s in streams]

        enc_input_ids = self._pad_stack(enc_streams, self.pad_id)
        enc_attention_mask = (enc_input_ids != self.pad_id).long()
        # Position 0 is BOS, never masked out by the pad check; force it on.
        enc_attention_mask[:, 0] = 1

        labels = self._pad_stack(streams, self.pad_id)
        # -100 is passed as the label for padded positions.
        labels[labels == self.pad_id] = -100
        return enc_input_ids, enc_attention_mask, labels

    def _pad_stack(self, seqs: List[List[int]], pad_value: int) -> torch.Tensor:
        """Right-pad a list of int lists into a 2D LongTensor on ``self.device``.

        Args:
            seqs: Non-empty list of token-id lists.
            pad_value: Fill value for positions past the end of each list.

        Returns:
            LongTensor of shape ``(len(seqs), max_len_in_batch)``.
        """
        length = max(len(s) for s in seqs)
        # Assemble on the host and transfer once, instead of creating one
        # small device tensor per row.
        out = torch.full((len(seqs), length), pad_value, dtype=torch.long)
        for i, s in enumerate(seqs):
            out[i, : len(s)] = torch.tensor(s, dtype=torch.long)
        return out.to(self.device)

    def _encoder_inputs_embeds(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ):
        """Prepend the learnable soft prompt to the encoder token embeddings.

        Args:
            input_ids: LongTensor ``(batch, L_enc)`` of encoder token ids.
            attention_mask: Tensor ``(batch, L_enc)`` matching ``input_ids``.

        Returns:
            The prompt-augmented ``inputs_embeds`` of shape
            ``(batch, prompt_length + L_enc, embed_dim)`` and the matching
            attention mask (soft-prompt positions are always attended to).
        """
        token_embeds = self.bart.get_input_embeddings()(input_ids)
        bsz = input_ids.shape[0]
        # expand() broadcasts the shared prompt across the batch without copying.
        prompt = self.soft_prompt.unsqueeze(0).expand(bsz, -1, -1)
        inputs_embeds = torch.cat([prompt, token_embeds], dim=1)
        prompt_mask = torch.ones(
            bsz,
            self.prompt_length,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        return inputs_embeds, attention_mask

    # ------------------------------------------------------------------
    # forward -- required by BaseModel
    # ------------------------------------------------------------------
    def forward(self, visits: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass (denoising seq2seq reconstruction).

        Args:
            visits: LongTensor ``(batch, max_visits, max_codes_per_visit)``
                from the ``NestedSequenceProcessor``.
            **kwargs: Any other batch keys are ignored.

        Returns:
            Dict with ``loss`` (scalar seq2seq cross-entropy) and ``y_prob``
            (decoder next-token probabilities, shape
            ``(batch, L_dec, vocab)``).
        """
        visits = visits.to(self.device)
        enc_input_ids, enc_attention_mask, labels = self._encode_batch(visits)
        inputs_embeds, enc_attention_mask = self._encoder_inputs_embeds(
            enc_input_ids, enc_attention_mask
        )
        # Only `labels` is supplied for the decoder side; no decoder_input_ids
        # are passed explicitly.
        out = self.bart(
            inputs_embeds=inputs_embeds,
            attention_mask=enc_attention_mask,
            labels=labels,
        )
        return {"loss": out.loss, "y_prob": F.softmax(out.logits, dim=-1)}

    # ------------------------------------------------------------------
    # Custom training loop
    # ------------------------------------------------------------------
    def train_model(self, train_dataset, val_dataset=None, device=None) -> None:
        """Train PromptEHR with a custom loop.

        Named ``train_model`` (not ``train``) to avoid shadowing
        ``nn.Module.train()``. Uses the standard ``get_dataloader``, an Adam
        optimizer, and the BART denoising loss. When ``val_dataset`` is given,
        validation loss is computed after each epoch and the best checkpoint
        is saved to ``self.save_dir``. Without a validation set no checkpoint
        is written. If a checkpoint already exists in ``self.save_dir``, its
        model and optimizer states are loaded before training starts.

        Args:
            train_dataset: ``SampleDataset`` for training.
            val_dataset: Optional ``SampleDataset`` for validation.
            device: Device to train on, e.g. ``"cuda"``, ``"cuda:1"``, or
                ``"cpu"``. If ``None`` (default), uses CUDA when available and
                falls back to CPU.
        """
        device = self._resolve_device(device)
        self.to(device)
        print(f"Training on: {device}")

        os.makedirs(self.save_dir, exist_ok=True)
        optimizer = torch.optim.Adam(self.parameters(), lr=self._lr)

        checkpoint_path = os.path.join(self.save_dir, "promptehr_model")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])

        train_loader = get_dataloader(
            train_dataset, batch_size=self._batch_size, shuffle=True
        )
        # The validation loader does not change between epochs; build it once.
        val_loader = None
        if val_dataset is not None:
            val_loader = get_dataloader(
                val_dataset, batch_size=self._batch_size, shuffle=False
            )

        global_loss = float("inf")
        for epoch in tqdm(range(self._epochs), desc="Epochs"):
            self.bart.train()
            batch_iter = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
            for batch in batch_iter:
                visits = batch["visits"].to(self.device)
                optimizer.zero_grad()
                ret = self.forward(visits=visits)
                loss = ret["loss"]
                loss.backward()
                optimizer.step()
                batch_iter.set_postfix(loss=f"{loss.item():.4f}")

            if val_loader is None:
                continue

            self.bart.eval()
            val_losses = []
            with torch.no_grad():
                for val_batch in val_loader:
                    visits = val_batch["visits"].to(self.device)
                    val_losses.append(self.forward(visits=visits)["loss"].item())
            cur_val_loss = float(np.mean(val_losses))
            print(f"Epoch {epoch} Validation Loss: {cur_val_loss:.7f}")
            if cur_val_loss < global_loss:
                global_loss = cur_val_loss
                state = {
                    "model": self.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                }
                torch.save(state, checkpoint_path)
                print("------------ Save best model ------------")

    # ------------------------------------------------------------------
    # Synthesis
    # ------------------------------------------------------------------
    def _decode_ids(
        self, ids: List[int], index_to_code: Dict[int, str]
    ) -> List[List[str]]:
        """Decode a generated decoder token stream into per-visit code lists.

        Args:
            ids: Generated token ids, without the leading decoder-start token.
            index_to_code: Mapping from code index to code string.

        Returns:
            One list of code strings per non-empty visit. Decoding stops at
            the first ``[EOS]``; BOS, PAD, ``[CODE_PROMPT]`` and ``[MASK]``
            tokens are ignored, as are ``<pad>``/``<unk>`` and unknown indices.
        """
        skipped = {self.bos_id, self.pad_id, self.code_prompt_id, self.mask_id}
        visits_out: List[List[str]] = []
        current: List[str] = []
        for tid in ids:
            if tid in skipped:
                continue
            if tid == self.eos_id:
                break
            if tid == self.delim_id:
                # Close the current visit; consecutive delimiters add nothing.
                if current:
                    visits_out.append(current)
                    current = []
                continue
            if tid < self.code_vocab_size:
                code = index_to_code.get(int(tid))
                if code not in (None, "<pad>", "<unk>"):
                    current.append(code)
        if current:
            visits_out.append(current)
        return visits_out

    def generate(
        self,
        num_samples: int,
        device=None,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> List[Dict]:
        """Generate synthetic patients with the trained PromptEHR model.

        Feeds the encoder a fully-masked seed stream (so generation is driven
        by the learned soft prompt), precomputes the prompt-augmented encoder
        states, and autoregressively samples a decoder stream with
        ``top_k``/``top_p`` sampling, then decodes it into per-visit code
        lists.

        Args:
            num_samples: Number of synthetic patients to generate.
            device: Device to generate on, e.g. ``"cuda"``, ``"cuda:1"``, or
                ``"cpu"``. If ``None`` (default), uses CUDA when available and
                falls back to CPU.
            top_k: Top-k sampling cutoff. Default: 50.
            top_p: Nucleus (top-p) sampling cutoff. Default: 0.95.

        Returns:
            List of dicts, each ``{"patient_id": "synthetic_i", "visits":
            [[code, ...], ...]}`` with decoded code strings.
        """
        device = self._resolve_device(device)
        self.to(device)

        index_to_code = {v: k for k, v in self.visits_processor.code_vocab.items()}

        self.bart.eval()
        synthetic_dataset: List[Dict] = []
        sample_batch_size = min(num_samples, 256)
        generated = 0

        # Fully-masked seed: [BOS] [CODE_PROMPT] [MASK] [EOS].
        seed = [self.bos_id, self.code_prompt_id, self.mask_id, self.eos_id]

        # The context manager closes the progress bar even if sampling raises.
        with tqdm(total=num_samples, desc="Generating patients") as pbar:
            with torch.no_grad():
                while generated < num_samples:
                    bs = min(sample_batch_size, num_samples - generated)
                    enc_input_ids = torch.tensor(
                        [seed] * bs, dtype=torch.long, device=self.device
                    )
                    enc_attention_mask = torch.ones_like(enc_input_ids)
                    inputs_embeds, enc_attention_mask = self._encoder_inputs_embeds(
                        enc_input_ids, enc_attention_mask
                    )
                    # Run the encoder here so the soft prompt is included;
                    # generate() then reuses these states.
                    encoder_outputs = self.bart.get_encoder()(
                        inputs_embeds=inputs_embeds,
                        attention_mask=enc_attention_mask,
                        return_dict=True,
                    )
                    out_ids = self.bart.generate(
                        encoder_outputs=encoder_outputs,
                        attention_mask=enc_attention_mask,
                        max_length=self.max_len,
                        do_sample=True,
                        top_k=top_k,
                        top_p=top_p,
                        num_beams=1,
                        pad_token_id=self.pad_id,
                        eos_token_id=self.eos_id,
                        decoder_start_token_id=self.eos_id,
                    )

                    for i in range(bs):
                        # BART's generate prepends decoder_start_token_id
                        # (= eos_id) at position 0; skip it so the real eos
                        # terminates decoding.
                        visits_out = self._decode_ids(
                            out_ids[i].tolist()[1:], index_to_code
                        )
                        synthetic_dataset.append(
                            {
                                "patient_id": f"synthetic_{generated + i}",
                                "visits": visits_out,
                            }
                        )
                    generated += bs
                    pbar.update(bs)

        return synthetic_dataset