"""HALO: Hierarchical Autoregressive Language mOdel for synthetic EHR generation.
This is a faithful port of the reference implementation
(https://github.com/Brandon-Theodorou/HALO_Inpatient) wrapped as a PyHealth
``BaseModel`` so it consumes the standard
``dataset -> set_task -> SampleDataset -> model`` pipeline.
HALO is a two-level model:
* a GPT-2-style **coarse** transformer operates over visit-level multi-hot
vectors, and
* a **fine** autoregressive head predicts the (multi-label) set of codes within
each visit.
The transformer/head classes below (``LayerNorm``, ``Conv1D``, ``Attention``,
``MLP``, ``Block``, ``CoarseTransformerModel``, ``AutoregressiveLinear``,
``FineAutoregressiveHead``, ``HALOModel``) are ported verbatim from the
reference ``model.py``. The only behavioural change is that PyHealth's HALO is
**unconditional** (``label_vocab_size = 0``): it generates visit-code sequences
without conditioning on CCS labels.
"""
import copy
import math
import os
from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from pyhealth.datasets import get_dataloader
from pyhealth.models import BaseModel
# ----------------------------------------------------------------------------
# Configuration (plain class, not a dataclass; mirrors reference config.py)
# ----------------------------------------------------------------------------
class HALOConfig:
    """Hyperparameter container for the HALO transformer.

    Deliberately a plain class whose ``__init__`` assigns every setting as an
    instance attribute (mirroring the reference ``config.py``); the low-level
    modules read values such as ``config.n_embd`` or ``config.n_head``
    directly from it.

    Args:
        total_vocab_size: Width of each multi-hot visit vector.
        code_vocab_size: Number of code indices at the front of that vector.
        label_vocab_size: Number of label indices after the codes (0 for
            unconditional generation).
        special_vocab_size: Number of special-token indices.
        n_positions: Size of the learned position-embedding table.
        n_ctx: Visit context length (size of the causal attention mask).
        n_embd: Transformer hidden size.
        n_layer: Number of transformer blocks.
        n_head: Number of attention heads.
        layer_norm_epsilon: Epsilon used by every ``LayerNorm``.
        initializer_range: Stored initializer scale.
        batch_size: Training batch size.
        epoch: Number of training epochs.
        pos_loss_weight: Optional positive-class BCE weight.
        lr: Learning rate.
    """

    def __init__(
        self,
        total_vocab_size: int,
        code_vocab_size: int,
        label_vocab_size: int = 0,
        special_vocab_size: int = 3,
        n_positions: int = 56,
        n_ctx: int = 48,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        batch_size: int = 48,
        epoch: int = 50,
        pos_loss_weight: Optional[float] = None,
        lr: float = 1e-4,
    ) -> None:
        # --- vocabulary layout -------------------------------------------
        self.total_vocab_size = total_vocab_size
        self.code_vocab_size = code_vocab_size
        self.label_vocab_size = label_vocab_size
        self.special_vocab_size = special_vocab_size
        # --- transformer architecture --------------------------------------
        self.n_positions, self.n_ctx = n_positions, n_ctx
        self.n_embd, self.n_layer, self.n_head = n_embd, n_layer, n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        # --- optimisation ------------------------------------------------
        self.batch_size, self.epoch = batch_size, epoch
        self.pos_loss_weight, self.lr = pos_loss_weight, lr
# ----------------------------------------------------------------------------
# Transformer building blocks (ported verbatim from reference model.py)
# ----------------------------------------------------------------------------
class LayerNorm(nn.Module):
    """TF-style layer normalisation over the last dimension.

    Normalises with the biased (population) variance and adds ``eps`` inside
    the square root, then applies a learnable per-feature scale and shift.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside sqrt)."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class Conv1D(nn.Module):
    """GPT-2 style "Conv1D": an affine map applied over the last dimension.

    The weight is stored as ``(nx, nf)`` (input features first) and the layer
    computes ``x @ weight + bias`` for inputs of any leading shape.
    """

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        weight = torch.empty(nx, nf)
        nn.init.normal_(weight, std=0.02)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        leading_shape = tuple(x.size()[:-1])
        flat = x.view(-1, x.size(-1))
        projected = torch.addmm(self.bias, flat, self.weight)
        return projected.view(*leading_shape, self.nf)
class Attention(nn.Module):
    """GPT-2 style multi-head causal self-attention.

    A single ``Conv1D`` projects the input to concatenated query/key/value; a
    lower-triangular ``bias`` buffer masks out future positions. ``forward``
    also returns the stacked (key, value) pair as ``present``.
    """

    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        assert nx % config.n_head == 0
        causal = torch.tril(torch.ones(n_ctx, n_ctx))
        self.register_buffer("bias", causal.view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = nx
        self.scale = scale
        self.c_attn = Conv1D(nx * 3, nx)
        self.c_proj = Conv1D(nx, nx)

    def _attn(self, q, k, v):
        scores = q @ k
        if self.scale:
            scores = scores / math.sqrt(v.size(-1))
        q_len, k_len = scores.size(-2), scores.size(-1)
        keep = self.bias[:, :, k_len - q_len:k_len, :k_len]
        # Masked entries are pushed to -1e10 so softmax gives them ~0 weight.
        scores = scores * keep - 1e10 * (1 - keep)
        return torch.softmax(scores, dim=-1) @ v

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(*x.size()[:-2], x.size(-2) * x.size(-1))

    def split_heads(self, x, k=False):
        head_dim = x.size(-1) // self.n_head
        x = x.view(*x.size()[:-1], self.n_head, head_dim)
        # Keys come out pre-transposed so _attn can do a plain q @ k.
        order = (0, 2, 3, 1) if k else (0, 2, 1, 3)
        return x.permute(*order)

    def forward(self, x, layer_past=None):
        query, key, value = self.c_attn(x).split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key = layer_past[0].transpose(-2, -1)
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))
        merged = self.merge_heads(self._attn(query, key, value))
        return self.c_proj(merged), present
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer of a transformer block.

    Expands ``n_embd`` to ``n_state`` features, applies tanh-approximate GELU
    (as in the reference HALO implementation), and projects back.
    """

    def __init__(self, n_state, config):  # in MLP: n_state=4 * n_embd
        super().__init__()
        width = config.n_embd
        self.c_fc = Conv1D(n_state, width)
        self.c_proj = Conv1D(width, n_state)

    def forward(self, x):
        activated = F.gelu(self.c_fc(x), approximate="tanh")
        return self.c_proj(activated)
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each residual."""

    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        width = config.n_embd
        eps = config.layer_norm_epsilon
        self.ln_1 = LayerNorm(width, eps=eps)
        self.attn = Attention(width, n_ctx, config, scale)
        self.ln_2 = LayerNorm(width, eps=eps)
        self.mlp = MLP(4 * width, config)

    def forward(self, x, layer_past=None):
        attn_out, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + attn_out
        return x + self.mlp(self.ln_2(x)), present
class CoarseTransformerModel(nn.Module):
    """Visit-level GPT-2-style transformer.

    Embeds each multi-hot visit vector with a bias-free linear layer, adds a
    learned position embedding, and runs ``n_layer`` causal blocks followed by
    a final ``LayerNorm``. Returns one hidden state per visit position.
    """

    def __init__(self, config):
        super(CoarseTransformerModel, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.total_vocab_size
        self.vis_embed_mat = nn.Linear(
            config.total_vocab_size, config.n_embd, bias=False
        )
        self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd)
        block = Block(config.n_ctx, config, scale=True)
        # Deep copies of one block, as in the reference implementation.
        self.h = nn.ModuleList(
            [copy.deepcopy(block) for _ in range(config.n_layer)]
        )
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, input_visits, position_ids=None, past=None):
        """Encode a batch of visit sequences.

        Args:
            input_visits: Tensor ``(batch, seq_len, total_vocab_size)``.
            position_ids: Optional position indices. A 1-D ``(seq_len,)``
                tensor is broadcast over the batch; a 2-D ``(batch, seq_len)``
                or ``(1, seq_len)`` tensor is used as given. When ``None``,
                positions ``past_length .. past_length + seq_len - 1`` are used.
            past: Optional per-layer cached (key, value) pairs.

        Returns:
            Hidden states ``(batch, seq_len, n_embd)``.
        """
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_visits.size(1) + past_length,
                dtype=torch.long,
                device=input_visits.device,
            )
        # Only 1-D ids need broadcasting over the batch; expanding 2-D ids
        # (as the previous code did unconditionally) raised a shape error.
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0).expand(
                input_visits.size(0), input_visits.size(1)
            )
        inputs_embeds = self.vis_embed_mat(input_visits)
        position_embeds = self.pos_embed_mat(position_ids)
        hidden_states = inputs_embeds + position_embeds
        for block, layer_past in zip(self.h, past):
            hidden_states, _ = block(hidden_states, layer_past)
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
class AutoregressiveLinear(nn.Linear):
    """Linear layer whose weight is masked so output ``i`` sees inputs ``0..i``.

    The mask is a lower-triangular ``(out_features, in_features)`` buffer,
    i.e. the same shape as ``self.weight``. For the square layers HALO uses
    this is identical to the reference ``tril(ones(in, out))`` buffer, so
    existing checkpoints load unchanged. For rectangular layers the
    reference-shaped mask could not be multiplied with the weight.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer(
            "mask", torch.tril(torch.ones(out_features, in_features)).int()
        )

    def forward(self, input):
        return F.linear(input, self.mask * self.weight, self.bias)
class FineAutoregressiveHead(nn.Module):
    """Within-visit autoregressive code predictor.

    Concatenates the transformer state at position ``t`` with the multi-hot
    row at ``t + 1`` and passes it through two masked (autoregressive) linear
    layers. The output slice ``[n_embd - 1 : -1]`` yields one logit per
    vocabulary index, each depending only on the history and lower indices.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd + config.total_vocab_size
        self.auto1 = AutoregressiveLinear(width, width)
        self.auto2 = AutoregressiveLinear(width, width)
        self.n_embd = config.n_embd
        self.tot_vocab = config.total_vocab_size

    def _project(self, joint):
        """Apply both masked layers and keep the per-code logit slice."""
        hidden = torch.relu(self.auto1(joint))
        return self.auto2(hidden)[:, :, self.n_embd - 1:-1]

    def forward(self, history, input_visits):
        joint = torch.cat((history[:, :-1, :], input_visits[:, 1:, :]), dim=2)
        return self._project(joint)

    def sample(self, history, input_visits):
        joint = torch.cat((history[:, :-1, :], input_visits[:, 1:, :]), dim=2)
        # Only the final (currently generated) visit is scored.
        return self._project(joint[:, -1, :].unsqueeze(1))
class HALOModel(nn.Module):
    """Low-level HALO transformer + autoregressive head (ported verbatim).

    ``transformer`` (a ``CoarseTransformerModel``) produces one hidden state
    per visit position; ``ehr_head`` (a ``FineAutoregressiveHead``) combines
    the hidden state at position ``t`` with the multi-hot row at ``t + 1`` to
    score every vocabulary index of visit ``t + 1``.
    """
    def __init__(self, config):
        super(HALOModel, self).__init__()
        self.transformer = CoarseTransformerModel(config)
        self.ehr_head = FineAutoregressiveHead(config)
    def forward(
        self,
        input_visits,
        position_ids=None,
        ehr_labels=None,
        ehr_masks=None,
        past=None,
        pos_loss_weight=None,
    ):
        """Score the next visit at every position and optionally compute BCE.

        Args:
            input_visits: Multi-hot tensor ``(batch, seq_len,
                total_vocab_size)``.
            position_ids: Forwarded to the transformer.
            ehr_labels: If given, targets; they are shifted by one visit
                (``ehr_labels[..., 1:, :]``) to align with the predictions.
            ehr_masks: Optional multiplicative mask broadcast against the
                shifted predictions/targets; presumably ``(batch, seq_len - 1,
                1)`` as produced by the wrapper -- TODO confirm for other
                callers.
            past: Forwarded to the transformer.
            pos_loss_weight: If given, targets equal to 1 get this BCE weight
                and targets equal to 0 keep weight 1.

        Returns:
            ``(loss, code_probs, shift_labels)`` when ``ehr_labels`` is given
            (``code_probs`` and ``shift_labels`` are the masked versions),
            otherwise ``code_probs`` of shape ``(batch, seq_len - 1,
            total_vocab_size)``.
        """
        hidden_states = self.transformer(input_visits, position_ids, past)
        code_logits = self.ehr_head(hidden_states, input_visits)
        sig = nn.Sigmoid()
        code_probs = sig(code_logits)
        if ehr_labels is not None:
            # Prediction at position t targets the codes of row t + 1.
            shift_labels = ehr_labels[..., 1:, :].contiguous()
            loss_weights = None
            if pos_loss_weight is not None:
                # 1 everywhere, pos_loss_weight where the target is positive.
                loss_weights = torch.ones(
                    code_probs.shape, device=code_probs.device
                )
                loss_weights = loss_weights + (pos_loss_weight - 1) * shift_labels
            if ehr_masks is not None:
                # Masked entries become prob 0 / label 0 (and weight 0), so
                # they presumably add ~0 loss but still count in BCELoss's
                # default mean reduction -- confirm against nn.BCELoss docs.
                code_probs = code_probs * ehr_masks
                shift_labels = shift_labels * ehr_masks
                if pos_loss_weight is not None:
                    loss_weights = loss_weights * ehr_masks
            bce = nn.BCELoss(weight=loss_weights)
            loss = bce(code_probs, shift_labels)
            return loss, code_probs, shift_labels
        return code_probs
    def sample(self, input_visits, random=True):
        """Fill in the codes of the last row of ``input_visits`` in place.

        Repeatedly scores the last row with the fine head, draws a full
        candidate visit (Bernoulli when ``random`` else rounding), and commits
        only the lowest code index at or after ``i`` that came up 1 in any
        batch row; the head is then re-evaluated so later indices are
        conditioned on the committed ones. ``input_visits`` is mutated and
        also returned.
        """
        sig = nn.Sigmoid()
        # The head only reads hidden states up to the second-to-last position,
        # which the causal mask keeps independent of the row being filled, so
        # they are computed once outside the loop.
        hidden_states = self.transformer(input_visits)
        i = 0  # first vocabulary index not yet decided
        while i < self.ehr_head.tot_vocab:
            next_logits = self.ehr_head.sample(hidden_states, input_visits)
            next_probs = sig(next_logits)
            if random:
                visit = torch.bernoulli(next_probs)
            else:
                visit = torch.round(next_probs)
            # Column offsets (pooled over all batch rows) of positive draws
            # at index i or later.
            remaining_visit = visit[:, 0, i:]
            nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1]
            if nonzero.numel() == 0:
                # No row drew another code: the remaining indices stay 0.
                break
            # Smallest positive column across the whole batch. Every row
            # commits its own draw (0 or 1) at that column; columns between i
            # and it stay 0, which matches every row's draw there.
            first_nonzero = nonzero.min()
            input_visits[:, -1, i + first_nonzero] = visit[:, 0, i + first_nonzero]
            # Continue after the committed column (i becomes a 0-d tensor).
            i = i + first_nonzero + 1
        return input_visits
# ----------------------------------------------------------------------------
# PyHealth BaseModel wrapper
# ----------------------------------------------------------------------------
class HALO(BaseModel):
    """HALO synthetic-EHR generator, wrapped as a PyHealth ``BaseModel``.

    Trains a GPT-2-style transformer on patient visit-code sequences and
    generates synthetic patients by autoregressive sampling. Generation is
    **unconditional** (no label conditioning).

    The model infers its code vocabulary from the fitted ``SampleDataset``:
    ``code_vocab_size = dataset.input_processors["visits"].vocab_size()``
    (the ``NestedSequenceProcessor`` vocab, which already reserves index 0 for
    ``<pad>`` and index 1 for ``<unk>``). Three special tokens are appended for
    start-of-sequence, end-of-sequence, and pad-visit.

    Sequence layout (identical in training and generation): row 0 carries the
    start token, row 1 is always left empty, and visits occupy rows
    ``2 .. n_ctx - 1``. The end token is set on the last visit's row and the
    pad-visit token on every row after it.

    Args:
        dataset: A fitted ``SampleDataset`` whose ``input_schema`` contains
            ``{"visits": NestedSequenceProcessor}`` and whose ``output_schema``
            is empty.
        embed_dim: Transformer embedding dimension (``n_embd``). Default: 768.
        n_heads: Number of attention heads. Must divide ``embed_dim``.
            Default: 12.
        n_layers: Number of transformer layers. Default: 12.
        n_ctx: Number of sequence rows (context length), including the start
            row and the empty row 1; must be at least 3. Default: 48.
        batch_size: Training batch size. Default: 48.
        epochs: Number of training epochs. Default: 50.
        pos_loss_weight: Positive-class weight for the BCE loss. ``None`` means
            no weighting. Default: None.
        lr: Learning rate for the Adam optimizer. Default: 1e-4.
        save_dir: Directory for checkpoints written by ``train_model``.
            Default: ``"./save/"``.

    Raises:
        ValueError: If the dataset has no ``"visits"`` input or ``n_ctx < 3``.

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {"patient_id": "p1", "visits": [["A", "B"], ["C"]]},
        ...     {"patient_id": "p2", "visits": [["A"], ["B", "C"]]},
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"visits": "nested_sequence"},
        ...     output_schema={},
        ... )
        >>> model = HALO(dataset, embed_dim=16, n_heads=2, n_layers=2, n_ctx=8)
        >>> isinstance(model, HALO)
        True
    """

    def __init__(
        self,
        dataset,
        embed_dim: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        n_ctx: int = 48,
        batch_size: int = 48,
        epochs: int = 50,
        pos_loss_weight: Optional[float] = None,
        lr: float = 1e-4,
        save_dir: str = "./save/",
    ) -> None:
        super(HALO, self).__init__(dataset)
        if "visits" not in dataset.input_processors:
            raise ValueError(
                "HALO expects an input feature named 'visits' backed by a "
                "NestedSequenceProcessor."
            )
        if n_ctx < 3:
            # Row 0 = start token, row 1 = empty, so at least one visit row
            # needs n_ctx >= 3.
            raise ValueError(f"n_ctx must be at least 3, got {n_ctx}.")
        self.save_dir = save_dir
        self._batch_size = batch_size
        self._epochs = epochs
        self._lr = lr
        # Code vocab from the NestedSequenceProcessor (includes <pad>, <unk>).
        self.visits_processor = dataset.input_processors["visits"]
        code_vocab_size = self.visits_processor.vocab_size()
        label_vocab_size = 0  # unconditional generation -- no output labels
        # +3 special tokens: start-of-sequence, end-of-sequence, pad-visit.
        total_vocab_size = code_vocab_size + label_vocab_size + 3
        self.config = HALOConfig(
            total_vocab_size=total_vocab_size,
            code_vocab_size=code_vocab_size,
            label_vocab_size=label_vocab_size,
            special_vocab_size=3,
            n_positions=n_ctx + 8,  # position table needs a little slack
            n_ctx=n_ctx,
            n_embd=embed_dim,
            n_layer=n_layers,
            n_head=n_heads,
            batch_size=batch_size,
            epoch=epochs,
            pos_loss_weight=pos_loss_weight,
            lr=lr,
        )
        # Registered as a sub-module so .parameters()/.to() work.
        self.halo_model = HALOModel(self.config)

    # ------------------------------------------------------------------
    # Multi-hot encoding helper
    # ------------------------------------------------------------------
    def _encode_visits(self, visits: torch.Tensor):
        """Convert a padded index tensor to HALO multi-hot format.

        ``NestedSequenceProcessor`` returns code indices; the transformer
        expects multi-hot vectors of shape ``(batch, n_ctx, total_vocab_size)``
        with special tokens. Layout (mirrors the reference): row 0 is the
        start token, row 1 is empty, visits occupy rows 2+, the end token is
        placed on the last visit's row, and the pad token fills the remaining
        rows. A patient's visit count is the index of its last visit holding
        any non-``<pad>`` code plus one (truncated to ``n_ctx - 2``), so an
        empty visit in the middle of a history is kept as an empty row.

        Args:
            visits: LongTensor ``(batch, max_visits, max_codes_per_visit)``.
                Index 0 is ``<pad>`` and is skipped.

        Returns:
            batch_ehr: FloatTensor ``(batch, n_ctx, total_vocab_size)``.
            batch_mask: FloatTensor ``(batch, n_ctx - 1, 1)``, shifted to align
                with the autoregressive prediction targets.
        """
        cfg = self.config
        visits = visits.to(self.device)
        batch_size, max_visits = visits.shape[0], visits.shape[1]
        start_idx = cfg.code_vocab_size + cfg.label_vocab_size
        end_idx = start_idx + 1
        pad_idx = start_idx + 2

        batch_ehr = torch.zeros(
            batch_size, cfg.n_ctx, cfg.total_vocab_size, device=self.device
        )
        batch_mask = torch.zeros(batch_size, cfg.n_ctx, 1, device=self.device)

        code_present = visits > 0  # (B, V, C): non-<pad> entries
        visit_pos = torch.arange(max_visits, device=self.device)
        if max_visits > 0:
            visit_present = code_present.any(dim=-1)  # (B, V)
            n_visits = torch.where(
                visit_present, visit_pos + 1, torch.zeros_like(visit_pos)
            ).amax(dim=1)
        else:
            n_visits = torch.zeros(batch_size, dtype=torch.long, device=self.device)
        n_visits = n_visits.clamp(max=cfg.n_ctx - 2)

        # Rows that belong to each patient's history get mask 1.
        in_history = visit_pos.unsqueeze(0) < n_visits.unsqueeze(1)  # (B, V)
        b_idx, v_idx = in_history.nonzero(as_tuple=True)
        batch_mask[b_idx, v_idx + 2] = 1

        # Scatter every non-<pad> code of those visits into its row.
        keep = code_present & in_history.unsqueeze(-1)
        b_idx, v_idx, c_idx = keep.nonzero(as_tuple=True)
        batch_ehr[b_idx, v_idx + 2, visits[b_idx, v_idx, c_idx]] = 1

        rows = torch.arange(batch_size, device=self.device)
        batch_ehr[:, 0, start_idx] = 1  # start token
        batch_ehr[rows, n_visits + 1, end_idx] = 1  # end token (on last visit)
        ctx_pos = torch.arange(cfg.n_ctx, device=self.device)
        pad_rows = ctx_pos.unsqueeze(0) >= (n_visits + 2).unsqueeze(1)
        batch_ehr[:, :, pad_idx] = pad_rows.to(batch_ehr.dtype)  # pad visits

        batch_mask = batch_mask[:, 1:, :]  # shift to align with shifted targets
        return batch_ehr, batch_mask

    def _halo_loss(self, visits: torch.Tensor):
        """Encode ``visits`` and return ``(loss, masked code_probs)``."""
        batch_ehr, batch_mask = self._encode_visits(visits)
        loss, code_probs, _ = self.halo_model(
            batch_ehr,
            position_ids=None,
            ehr_labels=batch_ehr,
            ehr_masks=batch_mask,
            pos_loss_weight=self.config.pos_loss_weight,
        )
        return loss, code_probs

    # ------------------------------------------------------------------
    # forward -- required by BaseModel
    # ------------------------------------------------------------------
    def forward(self, visits: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass.

        Args:
            visits: LongTensor ``(batch, max_visits, max_codes_per_visit)`` from
                the ``NestedSequenceProcessor``.
            **kwargs: Any other batch keys are ignored.

        Returns:
            Dict with ``loss`` (scalar BCE) and ``y_prob`` (code probabilities,
            shape ``(batch, n_ctx - 1, total_vocab_size)``).
        """
        loss, code_probs = self._halo_loss(visits.to(self.device))
        return {"loss": loss, "y_prob": code_probs}

    # ------------------------------------------------------------------
    # Custom training loop
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_device(device=None) -> torch.device:
        """Resolve a user-supplied device, defaulting to CUDA when available.

        Args:
            device: ``None``, a device string (e.g. ``"cuda"``, ``"cuda:1"``,
                ``"cpu"``), or a ``torch.device``. When ``None``, CUDA is used
                if available, otherwise CPU.
        """
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    def _validation_loss(self, val_loader) -> Optional[float]:
        """Mean HALO loss over ``val_loader`` in eval mode; ``None`` if empty."""
        self.halo_model.eval()
        losses = []
        with torch.no_grad():
            for val_batch in val_loader:
                val_loss, _ = self._halo_loss(val_batch["visits"].to(self.device))
                losses.append(val_loss.item())
        if not losses:
            return None
        return float(np.mean(losses))

    def train_model(self, train_dataset, val_dataset=None, device=None) -> None:
        """Train the HALO model with a custom loop.

        Named ``train_model`` (not ``train``) to avoid shadowing
        ``nn.Module.train()``. Uses the standard ``get_dataloader`` (which pads
        the variable visit dimension for us), an Adam optimizer, and BCE loss.
        If a checkpoint already exists at ``<save_dir>/halo_model``, model and
        optimizer state are loaded from it before training. When
        ``val_dataset`` is given, validation loss is computed after each epoch
        and the best checkpoint is saved to ``self.save_dir``; without it no
        checkpoint is written.

        Args:
            train_dataset: ``SampleDataset`` for training.
            val_dataset: Optional ``SampleDataset`` for validation.
            device: Device to train on, e.g. ``"cuda"``, ``"cuda:1"``, or
                ``"cpu"``. If ``None`` (default), uses CUDA when available and
                falls back to CPU.
        """
        device = self._resolve_device(device)
        self.to(device)
        print(f"Training on: {device}")
        os.makedirs(self.save_dir, exist_ok=True)
        optimizer = torch.optim.Adam(self.halo_model.parameters(), lr=self._lr)
        checkpoint_path = os.path.join(self.save_dir, "halo_model")
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.halo_model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
        train_loader = get_dataloader(
            train_dataset, batch_size=self._batch_size, shuffle=True
        )
        # Built once: the validation order is fixed (shuffle=False).
        val_loader = None
        if val_dataset is not None:
            val_loader = get_dataloader(
                val_dataset, batch_size=self._batch_size, shuffle=False
            )
        best_val_loss = float("inf")
        for epoch in tqdm(range(self._epochs), desc="Epochs"):
            self.halo_model.train()
            batch_iter = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)
            for batch in batch_iter:
                optimizer.zero_grad()
                loss, _ = self._halo_loss(batch["visits"].to(self.device))
                loss.backward()
                optimizer.step()
                batch_iter.set_postfix(loss=f"{loss.item():.4f}")
            if val_loader is None:
                continue
            cur_val_loss = self._validation_loss(val_loader)
            if cur_val_loss is None:
                # Empty validation set: nothing to compare, skip checkpointing.
                continue
            print(f"Epoch {epoch} Validation Loss: {cur_val_loss:.7f}")
            if cur_val_loss < best_val_loss:
                best_val_loss = cur_val_loss
                state = {
                    "model": self.halo_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                }
                torch.save(state, checkpoint_path)
                print("------------ Save best model ------------")

    # ------------------------------------------------------------------
    # Synthesis
    # ------------------------------------------------------------------
    def _sample_ehr_batch(
        self, bs: int, start_token_idx: int, end_token_idx: int, random_sampling: bool
    ) -> torch.Tensor:
        """Autoregressively sample ``bs`` multi-hot sequences.

        Rows 0 (start token) and 1 (empty) are seeded exactly as
        ``_encode_visits`` lays them out during training; row 1 is never
        sampled because its targets are masked out of the training loss.
        Sampling stops when every sequence carries an end token or ``n_ctx``
        rows exist.
        """
        cfg = self.config
        prev = torch.zeros(
            bs, 2, cfg.total_vocab_size, device=self.device, dtype=torch.float32
        )
        prev[:, 0, start_token_idx] = 1
        empty = torch.zeros(
            bs, 1, cfg.total_vocab_size, device=self.device, dtype=torch.float32
        )
        for _ in range(cfg.n_ctx - 2):
            prev = self.halo_model.sample(
                torch.cat((prev, empty), dim=1), random_sampling
            )
            has_end = prev[:, :, end_token_idx].sum(dim=1).bool()
            if has_end.all():
                break
        return prev

    @staticmethod
    def _decode_ehr(
        ehr: np.ndarray, index_to_code: Dict, code_vocab_size: int, end_token_idx: int
    ) -> List[List[str]]:
        """Turn one sampled ``(seq_len, total_vocab_size)`` array into visits.

        Visit rows start at index 2. Codes ``<pad>``/``<unk>``/unknown indices
        are dropped, empty visits are skipped, and decoding stops after the
        row carrying the end token.
        """
        visits_out: List[List[str]] = []
        for row in ehr[2:]:
            visit_codes: List[str] = []
            hit_end = False
            for idx in np.nonzero(row)[0]:
                if idx < code_vocab_size:
                    code = index_to_code.get(int(idx))
                    if code not in (None, "<pad>", "<unk>"):
                        visit_codes.append(code)
                elif idx == end_token_idx:
                    hit_end = True
            if visit_codes:
                visits_out.append(visit_codes)
            if hit_end:
                break
        return visits_out

    def generate(
        self, num_samples: int, random_sampling: bool = True, device=None
    ) -> List[Dict]:
        """Generate synthetic patients using the trained HALO model.

        Autoregressive sampling: seed the start row and the empty row 1, then
        repeatedly call ``halo_model.sample`` until an end token is produced
        or ``n_ctx`` rows exist, then decode code indices (rows 2+) back to
        code strings.

        Args:
            num_samples: Number of synthetic patients to generate.
            random_sampling: If True, Bernoulli sampling (stochastic). If False,
                rounding (deterministic). Default: True.
            device: Device to generate on, e.g. ``"cuda"``, ``"cuda:1"``, or
                ``"cpu"``. If ``None`` (default), uses CUDA when available and
                falls back to CPU.

        Returns:
            List of dicts, each ``{"patient_id": "synthetic_i",
            "visits": [[code, ...], ...]}`` with decoded code strings.
        """
        device = self._resolve_device(device)
        self.to(device)
        cfg = self.config
        index_to_code = {v: k for k, v in self.visits_processor.code_vocab.items()}
        start_token_idx = cfg.code_vocab_size + cfg.label_vocab_size
        end_token_idx = start_token_idx + 1
        self.halo_model.eval()
        synthetic_dataset: List[Dict] = []
        sample_batch_size = min(num_samples, 256)
        generated = 0
        with torch.no_grad(), tqdm(
            total=num_samples, desc="Generating patients"
        ) as pbar:
            while generated < num_samples:
                bs = min(sample_batch_size, num_samples - generated)
                batch_ehrs = self._sample_ehr_batch(
                    bs, start_token_idx, end_token_idx, random_sampling
                ).cpu().numpy()
                for i, ehr in enumerate(batch_ehrs):
                    synthetic_dataset.append(
                        {
                            "patient_id": f"synthetic_{generated + i}",
                            "visits": self._decode_ehr(
                                ehr, index_to_code, cfg.code_vocab_size, end_token_idx
                            ),
                        }
                    )
                generated += bs
                pbar.update(bs)
        return synthetic_dataset