"""Dr. Agent model for PyHealth 2.0.
Author: Joshua Steier
Paper: Dr. Agent: Clinical predictive model via mimicked second opinions
Link: https://doi.org/10.1093/jamia/ocaa074
Description: Multi-agent reinforcement learning model with dynamic skip
connections for clinical prediction tasks. Uses two policy gradient
agents (primary and second-opinion) to capture long-term dependencies
in patient EHR sequences.
"""
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from pyhealth.datasets import SampleDataset
from pyhealth.models.base_model import BaseModel
from pyhealth.models.embedding import EmbeddingModel
from pyhealth.models.utils import get_last_visit
class AgentLayer(nn.Module):
    """Dr. Agent layer with dual-agent dynamic skip connections.

    This layer implements the core mechanism from the Dr. Agent paper:
    two policy gradient agents each pick one of the last ``n_actions``
    hidden states, and the average of the two picks is blended with the
    current hidden state before every RNN step.

    Args:
        input_dim: Input feature dimension.
        static_dim: Static feature dimension. If 0, no static features used.
        cell: RNN cell type, one of "gru" or "lstm". Default is "gru".
        use_baseline: Whether to use baseline for variance reduction in
            REINFORCE. Default is True.
        n_actions: Number of historical states to consider (K in paper).
            Default is 10.
        n_units: Hidden units in agent MLPs. Default is 64.
        n_hidden: Hidden units in RNN cell. Default is 128.
        dropout: Dropout rate applied to final output. Default is 0.5.
        lamda: Weight for combining agent-selected state with current state.
            h_combined = lamda * h_agent + (1 - lamda) * h_current.
            Default is 0.5.

    Raises:
        ValueError: If ``cell`` is neither "gru" nor "lstm".

    Examples:
        >>> from pyhealth.models.agent import AgentLayer
        >>> layer = AgentLayer(input_dim=64, static_dim=12)
        >>> x = torch.randn(32, 50, 64)  # [batch, seq_len, features]
        >>> static = torch.randn(32, 12)  # [batch, static_dim]
        >>> last_out, all_out = layer(x, static=static)
        >>> last_out.shape
        torch.Size([32, 128])
    """

    def __init__(
        self,
        input_dim: int,
        static_dim: int = 0,
        cell: str = "gru",
        use_baseline: bool = True,
        n_actions: int = 10,
        n_units: int = 64,
        n_hidden: int = 128,
        dropout: float = 0.5,
        lamda: float = 0.5,
    ):
        super().__init__()
        if cell not in ["gru", "lstm"]:
            raise ValueError("cell must be 'gru' or 'lstm'")

        self.cell = cell
        self.use_baseline = use_baseline
        self.n_actions = n_actions
        self.n_units = n_units
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.lamda = lamda
        self.fusion_dim = n_hidden
        self.static_dim = static_dim

        # Per-timestep agent records; cleared at the start of every forward.
        self._reset_agent_state()

        # NOTE: the sub-modules below are created in a fixed order so that
        # seeded parameter initialisation stays reproducible.

        # Agent 1 (history agent): observes mean of historical hidden states
        self.agent1_fc1 = nn.Linear(self.n_hidden + self.static_dim, self.n_units)
        self.agent1_fc2 = nn.Linear(self.n_units, self.n_actions)

        # Agent 2 (primary agent): observes current input
        self.agent2_fc1 = nn.Linear(self.input_dim + self.static_dim, self.n_units)
        self.agent2_fc2 = nn.Linear(self.n_units, self.n_actions)

        # Baseline networks for variance reduction
        if use_baseline:
            self.agent1_value = nn.Linear(self.n_units, 1)
            self.agent2_value = nn.Linear(self.n_units, 1)

        # RNN cell
        if self.cell == "lstm":
            self.rnn = nn.LSTMCell(self.input_dim, self.n_hidden)
        else:
            self.rnn = nn.GRUCell(self.input_dim, self.n_hidden)

        # Zero biases, orthogonal weight matrices for the recurrent cell.
        for name, param in self.rnn.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.orthogonal_(param)

        # Dropout layer (only exists when dropout is enabled)
        if dropout > 0.0:
            self.nn_dropout = nn.Dropout(p=dropout)

        # Static feature integration layers
        if self.static_dim > 0:
            self.init_h = nn.Linear(self.static_dim, self.n_hidden)
            self.init_c = nn.Linear(self.static_dim, self.n_hidden)
            self.fusion = nn.Linear(self.n_hidden + self.static_dim, self.fusion_dim)

        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def _reset_agent_state(self) -> None:
        """Clear the per-timestep records kept for both agents."""
        self.agent1_action: List[torch.Tensor] = []
        self.agent1_prob: List[torch.Tensor] = []
        self.agent1_entropy: List[torch.Tensor] = []
        self.agent1_baseline: List[torch.Tensor] = []
        self.agent2_action: List[torch.Tensor] = []
        self.agent2_prob: List[torch.Tensor] = []
        self.agent2_entropy: List[torch.Tensor] = []
        self.agent2_baseline: List[torch.Tensor] = []

    def _choose_action(
        self,
        observation: torch.Tensor,
        agent: int = 1,
    ) -> torch.Tensor:
        """Select action (history index) based on observation.

        The observation is detached, passed through the chosen agent's
        two-layer MLP, and an action is sampled from the resulting
        categorical distribution over the ``n_actions`` history slots.
        The sampled action, its log-probability, the distribution entropy
        and (when ``use_baseline`` is set) the baseline value are appended
        to that agent's record lists.

        Args:
            observation: Environment observation of shape [batch, obs_dim].
            agent: Agent identifier (1=history agent; any other value
                selects the primary agent).

        Returns:
            Selected action indices of shape [batch, 1].
        """
        # Agents must not back-propagate into whatever produced the observation.
        observation = observation.detach()

        if agent == 1:
            hidden = self.tanh(self.agent1_fc1(observation))
            logits = self.agent1_fc2(hidden)
            if self.use_baseline:
                self.agent1_baseline.append(self.agent1_value(hidden))
        else:
            hidden = self.tanh(self.agent2_fc1(observation))
            logits = self.agent2_fc2(hidden)
            if self.use_baseline:
                self.agent2_baseline.append(self.agent2_value(hidden))

        probs = self.softmax(logits)
        dist = torch.distributions.Categorical(probs)
        actions = dist.sample()
        chosen = actions.unsqueeze(-1)

        if agent == 1:
            self.agent1_entropy.append(dist.entropy())
            self.agent1_action.append(chosen)
            self.agent1_prob.append(dist.log_prob(actions))
        else:
            self.agent2_entropy.append(dist.entropy())
            self.agent2_action.append(chosen)
            self.agent2_prob.append(dist.log_prob(actions))

        return chosen

    def forward(
        self,
        x: torch.Tensor,
        static: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward propagation through the Dr. Agent layer.

        Args:
            x: Input tensor of shape [batch, seq_len, input_dim].
            static: Static features of shape [batch, static_dim]. Required
                when the layer was built with ``static_dim > 0``; ignored
                when ``static_dim == 0``.
            mask: Optional mask of shape [batch, seq_len] where True/1
                indicates valid timesteps. It is only forwarded to
                ``get_last_visit``.

        Returns:
            last_output: Final hidden state of shape [batch, fusion_dim].
            all_outputs: All hidden states of shape [batch, seq_len, fusion_dim].

        Raises:
            ValueError: If ``static_dim > 0`` but ``static`` is None, or if
                ``x`` has no timesteps.
        """
        batch_size = x.size(0)
        time_step = x.size(1)

        use_static = self.static_dim > 0
        if use_static and static is None:
            # The agent MLPs are sized for n_hidden/input_dim + static_dim, so
            # running without static features would otherwise die with an
            # opaque matmul shape error inside _choose_action.
            raise ValueError(
                f"AgentLayer was built with static_dim={self.static_dim} "
                "but no static features were passed to forward()"
            )
        if time_step == 0:
            raise ValueError("x must contain at least one timestep")

        self._reset_agent_state()
        is_lstm = self.cell == "lstm"

        # Initial recurrent state: projected static features when available,
        # otherwise zeros matching the dtype/device of the input.
        cur_c = None
        if use_static:
            cur_h = self.init_h(static)
            if is_lstm:
                cur_c = self.init_c(static)
        else:
            cur_h = x.new_zeros(batch_size, self.n_hidden)
            if is_lstm:
                cur_c = x.new_zeros(batch_size, self.n_hidden)

        # Column vector of batch indices used to gather one history slot per
        # sample; invariant across timesteps.
        batch_idx = torch.arange(
            batch_size, dtype=torch.long, device=x.device
        ).unsqueeze(-1)

        observed_h = None
        observed_c = None
        h_list = []
        for t in range(time_step):
            cur_input = x[:, t, :]

            if t == 0:
                # First timestep: agents still act (so every timestep has a
                # recorded log-prob), but their choice is not used because
                # the history buffer is empty.
                obs_1 = cur_h
                obs_2 = cur_input
                if use_static:
                    obs_1 = torch.cat((obs_1, static), dim=1)
                    obs_2 = torch.cat((obs_2, static), dim=1)
                self._choose_action(obs_1, agent=1)
                self._choose_action(obs_2, agent=2)

                # History buffer of shape [n_actions, batch, n_hidden].
                observed_h = cur_h.new_zeros(
                    self.n_actions, batch_size, self.n_hidden
                )
                action_h = cur_h
                if is_lstm:
                    observed_c = cur_c.new_zeros(
                        self.n_actions, batch_size, self.n_hidden
                    )
                    action_c = cur_c
            else:
                # Slide the window: drop the oldest state, append the newest.
                observed_h = torch.cat((observed_h[1:], cur_h.unsqueeze(0)), dim=0)

                # Agent observations
                obs_1 = observed_h.mean(dim=0)
                obs_2 = cur_input
                if use_static:
                    obs_1 = torch.cat((obs_1, static), dim=1)
                    obs_2 = torch.cat((obs_2, static), dim=1)

                # Select actions (agent 1 first, then agent 2)
                act_idx1 = self._choose_action(obs_1, agent=1).long()
                act_idx2 = self._choose_action(obs_2, agent=2).long()

                # Gather the selected history slot per sample and average
                # the two agents' picks.
                action_h1 = observed_h[act_idx1, batch_idx, :].squeeze(1)
                action_h2 = observed_h[act_idx2, batch_idx, :].squeeze(1)
                action_h = (action_h1 + action_h2) / 2

                if is_lstm:
                    observed_c = torch.cat(
                        (observed_c[1:], cur_c.unsqueeze(0)), dim=0
                    )
                    action_c1 = observed_c[act_idx1, batch_idx, :].squeeze(1)
                    action_c2 = observed_c[act_idx2, batch_idx, :].squeeze(1)
                    action_c = (action_c1 + action_c2) / 2

            # Combine agent-selected state with current state
            weighted_h = self.lamda * action_h + (1 - self.lamda) * cur_h
            if is_lstm:
                weighted_c = self.lamda * action_c + (1 - self.lamda) * cur_c
                cur_h, cur_c = self.rnn(cur_input, (weighted_h, weighted_c))
            else:
                cur_h = self.rnn(cur_input, weighted_h)
            h_list.append(cur_h)

        # Stack all hidden states -> [batch, seq_len, n_hidden]
        all_outputs = torch.stack(h_list, dim=1)

        # Fuse with static features if available
        if use_static:
            static_expanded = static.unsqueeze(1).expand(-1, time_step, -1)
            all_outputs = torch.cat((all_outputs, static_expanded), dim=2)
            all_outputs = self.fusion(all_outputs)

        # get_last_visit is defined elsewhere; presumably it picks the last
        # valid timestep per sample according to `mask` -- verify there.
        last_output = get_last_visit(all_outputs, mask)
        if self.dropout > 0.0:
            last_output = self.nn_dropout(last_output)

        return last_output, all_outputs
class Agent(BaseModel):
    """Dr. Agent model for clinical prediction tasks.

    This model uses two reinforcement learning agents with dynamic skip
    connections to capture long-term dependencies in patient EHR sequences.
    The primary agent focuses on current health status while the second-opinion
    agent considers historical context.

    Paper: Junyi Gao et al. Dr. Agent: Clinical predictive model via mimicked
    second opinions. JAMIA 2020.

    Args:
        dataset: SampleDataset with fitted input/output processors.
        embedding_dim: Embedding dimension for input features. Default is 128.
        hidden_dim: Hidden dimension for RNN and output. Default is 128.
        static_key: Key for static features (e.g., demographics). These are
            passed directly to AgentLayer, not through EmbeddingModel.
            The static dimension is read from the first dataset sample; if
            the key is absent there, static features are disabled.
            Default is None.
        use_baseline: Whether to use baseline for RL variance reduction.
            Default is True.
        **kwargs: Additional arguments passed to AgentLayer (e.g., n_actions,
            n_units, dropout, lamda, cell).

    Raises:
        ValueError: If ``input_dim`` or ``n_hidden`` is passed in ``kwargs``.
        AssertionError: If the dataset has more than one label key.

    Example:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> from pyhealth.models import Agent
        >>> samples = [
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v0",
        ...         "conditions": [["A01", "A02"], ["B01"]],
        ...         "procedures": [["P1"], ["P2", "P3"]],
        ...         "demographic": [65.0, 1.0, 25.5],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "p1",
        ...         "visit_id": "v1",
        ...         "conditions": [["C01"]],
        ...         "procedures": [["P4"]],
        ...         "demographic": [45.0, 0.0, 22.1],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"conditions": "nested_sequence", "procedures": "nested_sequence"},
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test",
        ... )
        >>> model = Agent(dataset, static_key="demographic")
        >>> # Forward pass with a batch
        >>> from pyhealth.datasets import get_dataloader
        >>> loader = get_dataloader(dataset, batch_size=2, shuffle=False)
        >>> batch = next(iter(loader))
        >>> output = model(**batch)
        >>> output["loss"].backward()
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        static_key: Optional[str] = None,
        use_baseline: bool = True,
        **kwargs,
    ):
        super().__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.static_key = static_key
        self.use_baseline = use_baseline

        # These AgentLayer arguments are derived from this model's own
        # hyper-parameters and must not be overridden.
        if "input_dim" in kwargs:
            raise ValueError("input_dim is determined by embedding_dim")
        if "n_hidden" in kwargs:
            raise ValueError("n_hidden is determined by hidden_dim")

        # Single label key required
        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        # Infer the static feature width from the first sample, if present.
        self.static_dim = 0
        if self.static_key is not None:
            first_sample = self.dataset[0]
            if self.static_key in first_sample:
                static_val = first_sample[self.static_key]
                if isinstance(static_val, torch.Tensor):
                    self.static_dim = static_val.shape[-1]
                else:
                    self.static_dim = len(static_val)

        # Sequence feature keys (exclude static)
        self.seq_feature_keys = [
            k for k in self.feature_keys if k != self.static_key
        ]

        # Embedding model for sequence features
        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # One AgentLayer per sequence feature
        self.agent = nn.ModuleDict()
        for feature_key in self.seq_feature_keys:
            self.agent[feature_key] = AgentLayer(
                input_dim=embedding_dim,
                static_dim=self.static_dim,
                n_hidden=hidden_dim,
                use_baseline=use_baseline,
                **kwargs,
            )

        # Output layer over the concatenated per-feature representations
        output_size = self.get_output_size()
        self.fc = nn.Linear(len(self.seq_feature_keys) * hidden_dim, output_size)

    def _prediction_rewards(
        self, pred: torch.Tensor, true: torch.Tensor
    ) -> torch.Tensor:
        """Return one reward per sample measuring how well ``pred`` fits ``true``."""
        if self.mode == "binary":
            pred_prob = torch.sigmoid(pred)
            # Align label shape with the prediction so that [batch] labels do
            # not broadcast against [batch, 1] predictions into [batch, batch].
            target = true.float().reshape(pred_prob.shape)
            # reward = 1 - |true - pred_prob|: high when the prediction
            # matches the label, for both classes.
            return (1 - torch.abs(target - pred_prob)).squeeze(dim=-1)
        if self.mode == "multiclass":
            pred_prob = torch.softmax(pred, dim=-1)
            # Probability assigned to the true class; accepts labels shaped
            # [batch] or [batch, 1].
            class_idx = true.reshape(-1, 1).long()
            return pred_prob.gather(1, class_idx).squeeze(1)
        if self.mode == "multilabel":
            pred_prob = torch.sigmoid(pred)
            # Average agreement across all labels.
            return (1 - torch.abs(true.float() - pred_prob)).mean(dim=-1)
        raise ValueError(f"Unsupported mode: {self.mode}")

    def _masked_steps(
        self,
        steps: List[torch.Tensor],
        mask_f: torch.Tensor,
        squeeze_last: bool = False,
    ) -> torch.Tensor:
        """Stack per-timestep records into [batch, seq_len] and zero padded steps."""
        stacked = torch.stack(steps)
        if squeeze_last:
            # Baseline values are recorded as [batch, 1] per step.
            stacked = stacked.squeeze(-1)
        return stacked.permute(1, 0).to(self.device) * mask_f

    def _compute_rl_loss(
        self,
        agent_layer: AgentLayer,
        pred: torch.Tensor,
        true: torch.Tensor,
        mask: torch.Tensor,
        gamma: float = 0.9,
        entropy_term: float = 0.01,
    ) -> torch.Tensor:
        """Compute REINFORCE loss for agent optimization.

        A single per-sample reward is derived from the final prediction and
        propagated backwards through time with discount ``gamma``. Each
        agent's masked log-probabilities are weighted by that reward (minus
        the baseline when enabled), and an entropy bonus is added.

        Args:
            agent_layer: AgentLayer instance with stored action probabilities.
            pred: Predicted logits of shape [batch, output_size].
            true: Ground truth labels.
            mask: Valid timestep mask of shape [batch, seq_len].
            gamma: Discount factor for long-term rewards. Default is 0.9.
            entropy_term: Entropy bonus coefficient. Default is 0.01.

        Returns:
            Combined RL loss (policy loss + value loss if using baseline).

        Raises:
            ValueError: If ``self.mode`` is not binary/multiclass/multilabel.
        """
        rewards = self._prediction_rewards(pred, true)
        mask_f = mask.float()

        # Masked per-step log-probabilities and entropies, [batch, seq_len].
        act_prob1 = self._masked_steps(agent_layer.agent1_prob, mask_f)
        act_entropy1 = self._masked_steps(agent_layer.agent1_entropy, mask_f)
        act_prob2 = self._masked_steps(agent_layer.agent2_prob, mask_f)
        act_entropy2 = self._masked_steps(agent_layer.agent2_entropy, mask_f)

        # Discounted reward: the reward is granted at the last position of
        # the sequence and decays by `gamma` per step going backwards, i.e.
        # reward * gamma ** (seq_len - 1 - t).
        # NOTE(review): the anchor is the last *padded* position, so shorter
        # sequences in a batch see a more heavily discounted reward. Confirm
        # whether anchoring at each sample's last valid step is intended.
        seq_len = act_prob1.size(1)
        exponents = torch.arange(
            seq_len - 1, -1, -1, device=rewards.device, dtype=rewards.dtype
        )
        rewards_tensor = (rewards.unsqueeze(1) * gamma**exponents).detach()

        mask_sum = torch.sum(mask_f, dim=1).clamp(min=1.0)

        def policy_loss(log_prob, weight, entropy):
            # Negative (weighted log-prob + entropy bonus), averaged over the
            # valid steps of each sample and then over the batch.
            per_sample = torch.sum(log_prob * weight + entropy_term * entropy, dim=1)
            return torch.mean(-per_sample / mask_sum)

        if not self.use_baseline:
            return policy_loss(act_prob1, rewards_tensor, act_entropy1) + policy_loss(
                act_prob2, rewards_tensor, act_entropy2
            )

        act_baseline1 = self._masked_steps(
            agent_layer.agent1_baseline, mask_f, squeeze_last=True
        )
        act_baseline2 = self._masked_steps(
            agent_layer.agent2_baseline, mask_f, squeeze_last=True
        )

        def value_loss(baseline):
            # Squared error restricted to valid steps. Padded steps have a
            # zeroed baseline, so masking removes only a gradient-free
            # constant from the reported loss.
            sq_err = mask_f * (rewards_tensor - baseline) ** 2
            return torch.mean(torch.sum(sq_err, dim=1) / mask_sum)

        # NOTE(review): the advantage is not detached, so the policy term
        # also back-propagates into the baseline heads. Kept as-is to
        # preserve training behaviour; confirm against the reference
        # implementation before changing.
        advantage1 = rewards_tensor - act_baseline1
        advantage2 = rewards_tensor - act_baseline2

        loss_rl1 = policy_loss(act_prob1, advantage1, act_entropy1)
        loss_rl2 = policy_loss(act_prob2, advantage2, act_entropy2)
        return loss_rl1 + loss_rl2 + value_loss(act_baseline1) + value_loss(
            act_baseline2
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: Keyword arguments containing input features and labels.
                Must include all keys from input_schema and output_schema.

        Returns:
            Dictionary containing:
                - loss: Combined task loss and RL loss.
                - y_prob: Predicted probabilities.
                - y_true: Ground truth labels.
                - logit: Raw logits.
                - embed (optional): Patient embeddings if embed=True in kwargs.

        Raises:
            ValueError: If the model was built with static features but the
                batch does not contain ``static_key``, or if an embedded
                feature is neither 3D nor 4D.
        """
        # Get static features
        static = None
        if self.static_key is not None and self.static_key in kwargs:
            static_data = kwargs[self.static_key]
            if isinstance(static_data, torch.Tensor):
                static = static_data.float().to(self.device)
            else:
                static = torch.tensor(
                    static_data, dtype=torch.float, device=self.device
                )
        if self.static_dim > 0 and static is None:
            # The agent layers were sized for static features; without them
            # the layer would fail with an opaque shape mismatch.
            raise ValueError(
                f"Static features '{self.static_key}' are required by this "
                "model but were not found in the batch"
            )

        # Get embeddings for sequence features
        embedded = self.embedding_model(kwargs)

        patient_emb = []
        mask_dict = {}
        for feature_key in self.seq_feature_keys:
            x = embedded[feature_key]

            # Nested sequences arrive as [batch, visits, codes, embed_dim];
            # AgentLayer expects [batch, seq_len, input_dim].
            if x.dim() == 4:
                # Sum across the codes dimension (like RETAIN, AdaCare, SafeDrug)
                x = torch.sum(x, dim=2)  # -> [batch, visits, embed_dim]
            elif x.dim() != 3:
                raise ValueError(
                    f"Expected 3D or 4D tensor for {feature_key}, got {x.dim()}D"
                )

            # Visit-level mask: a visit counts as valid when its embedding
            # is not all zeros.
            mask = x.abs().sum(dim=-1) > 0
            mask_dict[feature_key] = mask

            # Forward through agent layer
            out, _ = self.agent[feature_key](x, static=static, mask=mask)
            patient_emb.append(out)

        # Concatenate embeddings and predict
        patient_emb = torch.cat(patient_emb, dim=1)
        logits = self.fc(patient_emb)

        # Compute task loss
        y_true = kwargs[self.label_key]
        if not isinstance(y_true, torch.Tensor):
            y_true = torch.tensor(y_true, device=self.device)
        y_true = y_true.to(self.device)
        loss_task = self.get_loss_function()(logits, y_true)

        # Add the RL loss of every per-feature agent layer
        loss_rl = torch.tensor(0.0, device=self.device)
        for feature_key in self.seq_feature_keys:
            loss_rl = loss_rl + self._compute_rl_loss(
                self.agent[feature_key],
                logits,
                y_true,
                mask_dict[feature_key],
            )

        loss = loss_task + loss_rl
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
if __name__ == "__main__":
    # Smoke test: build a tiny synthetic dataset, run one forward pass
    # through the model and back-propagate the combined loss.
    from pyhealth.datasets import create_sample_dataset, get_dataloader

    first_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "conditions": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],
        "procedures": [["P1", "P2"], ["P3"]],
        "demographic": [1.0, 2.0, 1.3],
        "label": 1,
    }
    second_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-1",
        "conditions": [["A04A", "B035", "C129"]],
        "procedures": [["P4", "P5"]],
        "demographic": [1.0, 2.0, 1.3],
        "label": 0,
    }

    # Both code sequences are nested (visits -> codes); the label is binary.
    demo_dataset = create_sample_dataset(
        samples=[first_visit, second_visit],
        input_schema={"conditions": "nested_sequence", "procedures": "nested_sequence"},
        output_schema={"label": "binary"},
        dataset_name="test",
    )
    demo_loader = get_dataloader(demo_dataset, batch_size=2, shuffle=True)

    demo_model = Agent(
        dataset=demo_dataset,
        static_key="demographic",
        embedding_dim=128,
        hidden_dim=128,
    )

    # One batch is enough to exercise forward and backward.
    first_batch = next(iter(demo_loader))
    outputs = demo_model(**first_batch)
    print(outputs)
    outputs["loss"].backward()