Source code for pyhealth.models.agent

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit


class AgentLayer(nn.Module):
    """Dr. Agent layer.

    Paper: Junyi Gao et al. Dr. Agent: Clinical predictive model via
    mimicked second opinions. JAMIA.

    This layer is used in the Dr. Agent model, but it can also be used as a
    standalone layer.

    Args:
        input_dim: dynamic feature size.
        static_dim: static feature size; if 0, then no static feature is used.
        cell: rnn cell type. Default is "gru".
        use_baseline: whether to use baseline for the RL agent. Default is True.
        n_actions: number of historical visits to choose. Default is 10.
        n_units: number of hidden units in each agent. Default is 64.
        n_hidden: number of hidden units in the rnn. Default is 128.
        dropout: dropout rate. Default is 0.5.
        lamda: weight for the agent selected hidden state and the current
            hidden state. Default is 0.5.

    Examples:
        >>> from pyhealth.models import AgentLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = AgentLayer(64)
        >>> c, _ = layer(input)
        >>> c.shape
        torch.Size([3, 128])
    """

    def __init__(
        self,
        input_dim: int,
        static_dim: int = 0,
        cell: str = "gru",
        use_baseline: bool = True,
        n_actions: int = 10,
        n_units: int = 64,
        n_hidden: int = 128,
        dropout: float = 0.5,
        lamda: float = 0.5,
    ):
        super(AgentLayer, self).__init__()

        if cell not in ["gru", "lstm"]:
            raise ValueError("Only gru and lstm are supported for cell.")

        self.cell = cell
        self.use_baseline = use_baseline
        self.n_actions = n_actions
        self.n_units = n_units
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.lamda = lamda
        self.fusion_dim = n_hidden
        self.static_dim = static_dim

        # per-forward bookkeeping for the two RL agents
        self.agent1_action = []
        self.agent1_prob = []
        self.agent1_entropy = []
        self.agent1_baseline = []
        self.agent2_action = []
        self.agent2_prob = []
        self.agent2_entropy = []
        self.agent2_baseline = []

        self.agent1_fc1 = nn.Linear(self.n_hidden + self.static_dim, self.n_units)
        self.agent2_fc1 = nn.Linear(self.input_dim + self.static_dim, self.n_units)
        self.agent1_fc2 = nn.Linear(self.n_units, self.n_actions)
        self.agent2_fc2 = nn.Linear(self.n_units, self.n_actions)
        if use_baseline:
            self.agent1_value = nn.Linear(self.n_units, 1)
            self.agent2_value = nn.Linear(self.n_units, 1)

        if self.cell == "lstm":
            self.rnn = nn.LSTMCell(self.input_dim, self.n_hidden)
        else:
            self.rnn = nn.GRUCell(self.input_dim, self.n_hidden)
        for name, param in self.rnn.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0.0)
            elif "weight" in name:
                nn.init.orthogonal_(param)

        if dropout > 0.0:
            self.nn_dropout = nn.Dropout(p=dropout)
        if self.static_dim > 0:
            self.init_h = nn.Linear(self.static_dim, self.n_hidden)
            self.init_c = nn.Linear(self.static_dim, self.n_hidden)

        self.fusion = nn.Linear(self.n_hidden + self.static_dim, self.fusion_dim)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
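    # Example (a minimal sketch, not part of the library): AgentLayer with an
    # LSTM cell and static features; shapes follow the forward() docstring below.
    #
    #     layer = AgentLayer(input_dim=64, static_dim=8, cell="lstm")
    #     x = torch.randn(3, 10, 64)   # [batch size, sequence len, input_dim]
    #     s = torch.randn(3, 8)        # [batch size, static_dim]
    #     mask = torch.ones(3, 10)     # 1 = valid visit, 0 = padding
    #     last_out, h = layer(x, static=s, mask=mask)
    #     # last_out: [3, 128], h: [3, 10, 128]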
    def choose_action(self, observation, agent=1):
        # Detach the observation so policy gradients do not flow back into the
        # RNN through the agents' inputs.
        observation = observation.detach()
        if agent == 1:
            result_fc1 = self.agent1_fc1(observation)
            result_fc1 = self.tanh(result_fc1)
            result_fc2 = self.agent1_fc2(result_fc1)
            if self.use_baseline:
                result_value = self.agent1_value(result_fc1)
                self.agent1_baseline.append(result_value)
        else:
            result_fc1 = self.agent2_fc1(observation)
            result_fc1 = self.tanh(result_fc1)
            result_fc2 = self.agent2_fc2(result_fc1)
            if self.use_baseline:
                result_value = self.agent2_value(result_fc1)
                self.agent2_baseline.append(result_value)
        # Sample an action (an index into the last n_actions hidden states) and
        # record its log-probability and entropy for the RL loss.
        probs = self.softmax(result_fc2)
        m = torch.distributions.Categorical(probs)
        actions = m.sample()
        if agent == 1:
            self.agent1_entropy.append(m.entropy())
            self.agent1_action.append(actions.unsqueeze(-1))
            self.agent1_prob.append(m.log_prob(actions))
        else:
            self.agent2_entropy.append(m.entropy())
            self.agent2_action.append(actions.unsqueeze(-1))
            self.agent2_prob.append(m.log_prob(actions))
        return actions.unsqueeze(-1)
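    # A standalone sketch of the sampling step above (illustrative values only):
    # a Categorical over n_actions indices yields one history index per sample,
    # plus the log-probability and entropy that the REINFORCE loss needs.
    #
    #     probs = torch.softmax(torch.randn(3, 10), dim=1)  # [batch, n_actions]
    #     m = torch.distributions.Categorical(probs)
    #     actions = m.sample()                    # [batch], values in [0, n_actions)
    #     log_p, ent = m.log_prob(actions), m.entropy()     # both [batch]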
    def forward(
        self,
        x: torch.tensor,
        static: Optional[torch.tensor] = None,
        mask: Optional[torch.tensor] = None,
    ) -> Tuple[torch.tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input_dim].
            static: a tensor of shape [batch size, static_dim].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid.

        Returns:
            last_output: a tensor of shape [batch size, n_hidden] representing
                the patient embedding.
            output: a tensor of shape [batch size, sequence len, n_hidden]
                representing the patient embedding at each time step.
        """
        # rnn will only apply dropout between layers
        batch_size = x.size(0)
        time_step = x.size(1)

        # reset the agents' bookkeeping for this forward pass
        self.agent1_action = []
        self.agent1_prob = []
        self.agent1_entropy = []
        self.agent1_baseline = []
        self.agent2_action = []
        self.agent2_prob = []
        self.agent2_entropy = []
        self.agent2_baseline = []

        if self.static_dim > 0:
            cur_h = self.init_h(static)
            if self.cell == "lstm":
                cur_c = self.init_c(static)
        else:
            cur_h = torch.zeros(
                batch_size, self.n_hidden, dtype=torch.float32, device=x.device
            )
            if self.cell == "lstm":
                cur_c = torch.zeros(
                    batch_size, self.n_hidden, dtype=torch.float32, device=x.device
                )

        h = []
        for cur_time in range(time_step):
            cur_input = x[:, cur_time, :]
            if cur_time == 0:
                obs_1 = cur_h
                obs_2 = cur_input
                if self.static_dim > 0:
                    obs_1 = torch.cat((obs_1, static), dim=1)
                    obs_2 = torch.cat((obs_2, static), dim=1)
                # actions at the first step are recorded but not used (no history yet)
                self.choose_action(obs_1, 1).long()
                self.choose_action(obs_2, 2).long()
                # rolling buffer of the last n_actions hidden states
                observed_h = (
                    torch.zeros_like(cur_h, dtype=torch.float32)
                    .view(-1)
                    .repeat(self.n_actions)
                    .view(self.n_actions, batch_size, self.n_hidden)
                )
                action_h = cur_h
                if self.cell == "lstm":
                    observed_c = (
                        torch.zeros_like(cur_c, dtype=torch.float32)
                        .view(-1)
                        .repeat(self.n_actions)
                        .view(self.n_actions, batch_size, self.n_hidden)
                    )
                    action_c = cur_c
            else:
                observed_h = torch.cat((observed_h[1:], cur_h.unsqueeze(0)), 0)
                obs_1 = observed_h.mean(dim=0)
                obs_2 = cur_input
                if self.static_dim > 0:
                    obs_1 = torch.cat((obs_1, static), dim=1)
                    obs_2 = torch.cat((obs_2, static), dim=1)
                # each agent picks one past hidden state per batch element
                act_idx1 = self.choose_action(obs_1, 1).long()
                act_idx2 = self.choose_action(obs_2, 2).long()
                batch_idx = torch.arange(batch_size, dtype=torch.long).unsqueeze(-1)
                action_h1 = observed_h[act_idx1, batch_idx, :].squeeze(1)
                action_h2 = observed_h[act_idx2, batch_idx, :].squeeze(1)
                action_h = (action_h1 + action_h2) / 2
                if self.cell == "lstm":
                    observed_c = torch.cat((observed_c[1:], cur_c.unsqueeze(0)), 0)
                    action_c1 = observed_c[act_idx1, batch_idx, :].squeeze(1)
                    action_c2 = observed_c[act_idx2, batch_idx, :].squeeze(1)
                    action_c = (action_c1 + action_c2) / 2

            # blend the agent-selected state with the current state
            if self.cell == "lstm":
                weighted_h = self.lamda * action_h + (1 - self.lamda) * cur_h
                weighted_c = self.lamda * action_c + (1 - self.lamda) * cur_c
                rnn_state = (weighted_h, weighted_c)
                cur_h, cur_c = self.rnn(cur_input, rnn_state)
            else:
                weighted_h = self.lamda * action_h + (1 - self.lamda) * cur_h
                cur_h = self.rnn(cur_input, weighted_h)
            h.append(cur_h)

        h = torch.stack(h, dim=1)
        if self.static_dim > 0:
            static = static.unsqueeze(1).repeat(1, time_step, 1)
            h = torch.cat((h, static), dim=2)
        h = self.fusion(h)
        last_out = get_last_visit(h, mask)
        if self.dropout > 0.0:
            last_out = self.nn_dropout(last_out)
        return last_out, h
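    # A minimal sketch of the history lookup used in forward() (illustrative
    # values): observed_h is a rolling buffer of the last n_actions hidden
    # states, and each sampled action indexes one past state per batch element.
    #
    #     n_actions, batch_size, n_hidden = 10, 3, 128
    #     observed_h = torch.randn(n_actions, batch_size, n_hidden)
    #     act_idx = torch.randint(n_actions, (batch_size, 1))      # sampled actions
    #     batch_idx = torch.arange(batch_size).unsqueeze(-1)
    #     action_h = observed_h[act_idx, batch_idx, :].squeeze(1)  # [batch, n_hidden]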
class Agent(BaseModel):
    """Dr. Agent model.

    Paper: Junyi Gao et al. Dr. Agent: Clinical predictive model via
    mimicked second opinions. JAMIA.

    Note:
        We use separate Dr. Agent layers for different feature_keys.
        Currently, we automatically support different input formats:
            - code based input (need to use the embedding table later)
            - float/int based value input
        We follow the current convention for the Dr. Agent model:
            - case 1. [code1, code2, code3, ...]
                - we will assume the code follows the order; our model will encode
                each code into a vector and apply Dr. Agent on the code level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner bracket follows the order; our model first
                uses the embedding table to encode each code into a vector and then
                uses average/mean pooling to get one vector for each inner bracket;
                then applies Dr. Agent on the bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the same
                length; we assume each dimension has the same meaning; we run Dr.
                Agent directly on the inner bracket level, similar to case 1 after
                the embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the same
                length; we assume each dimension has the same meaning; we run Dr.
                Agent directly on the inner bracket level, similar to case 2 after
                the embedding table

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features, e.g.
            ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        static_key: the key in samples to use as static features, e.g.
            "demographic". Default is None. We only support numerical static
            features.
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension of the RNN in the Dr. Agent layer.
            Default is 128.
        use_baseline: whether to use the baseline value to calculate the RL
            loss. Default is True.
        **kwargs: other parameters for the Dr. Agent layer.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "demographic": [0.0, 2.0, 1.5],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ...         ],
        ...         "demographic": [0.0, 2.0, 1.5],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import Agent
        >>> model = Agent(
        ...     dataset=dataset,
        ...     feature_keys=[
        ...         "list_codes",
        ...         "list_vectors",
        ...         "list_list_codes",
        ...         "list_list_vectors",
        ...     ],
        ...     label_key="label",
        ...     static_key="demographic",
        ...     mode="binary",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(1.4059, grad_fn=<AddBackward0>),
            'y_prob': tensor([[0.4861], [0.5348]], grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[0.], [1.]]),
            'logit': tensor([[-0.0556], [0.1392]], grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        static_key: Optional[str] = None,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        use_baseline: bool = True,
        **kwargs,
    ):
        super(Agent, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # validate kwargs for Dr. Agent layer
        if "feature_size" in kwargs:
            raise ValueError("feature_size is determined by embedding_dim")

        # the key of self.feat_tokenizers only contains the code based inputs
        self.feat_tokenizers = {}
        self.static_key = static_key
        self.use_baseline = use_baseline
        self.label_tokenizer = self.get_label_tokenizer()
        # the key of self.embeddings only contains the code based inputs
        self.embeddings = nn.ModuleDict()
        # the key of self.linear_layers only contains the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        self.static_dim = 0
        if self.static_key is not None:
            self.static_dim = self.dataset.input_info[self.static_key]["len"]

        self.agent = nn.ModuleDict()
        # add feature Dr. Agent layers
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "Dr. Agent only supports str code, float and int as input types"
                )
            elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]):
                raise ValueError(
                    "Dr. Agent only supports 2-dim or 3-dim str code as input types"
                )
            elif (input_info["type"] in [float, int]) and (
                input_info["dim"] not in [2, 3]
            ):
                raise ValueError(
                    "Dr. Agent only supports 2-dim or 3-dim float and int as input types"
                )
            # for code based input, we need Type
            # for float/int based input, we need Type, input_dim
            self.add_feature_transform_layer(feature_key, input_info)
            self.agent[feature_key] = AgentLayer(
                input_dim=embedding_dim,
                static_dim=self.static_dim,
                n_hidden=hidden_dim,
                **kwargs,
            )

        output_size = self.get_output_size(self.label_tokenizer)
        self.fc = nn.Linear(len(self.feature_keys) * self.hidden_dim, output_size)
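    # A hedged sketch of the resulting module layout (hypothetical feature keys;
    # names are from the __init__ above): each feature key gets its own
    # AgentLayer, and the final linear head consumes their concatenated outputs.
    #
    #     model = Agent(dataset=dataset, feature_keys=["conditions", "procedures"],
    #                   label_key="label", mode="binary", hidden_dim=128)
    #     # model.agent["conditions"] and model.agent["procedures"] are AgentLayers;
    #     # model.fc is nn.Linear(2 * 128, 1) for a binary label.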
    def get_loss(self, model, pred, true, mask, gamma=0.9, entropy_term=0.01):
        """Computes the REINFORCE loss for the two agents of one AgentLayer.

        Args:
            model: the AgentLayer whose recorded actions are scored.
            pred: logits of shape [batch size, output_size].
            true: the true labels.
            mask: a tensor of shape [batch size, sequence len].
            gamma: discount factor for the terminal reward.
            entropy_term: weight of the entropy bonus.
        """
        # compute per-sample rewards from the task prediction quality
        if self.mode == "binary":
            pred = torch.sigmoid(pred)
            rewards = ((pred - 0.5) * 2 * true).squeeze()
        elif self.mode == "multiclass":
            pred = torch.softmax(pred, dim=-1)
            y_onehot = torch.zeros_like(pred).scatter(1, true.unsqueeze(1), 1)
            rewards = (pred * y_onehot).sum(-1).squeeze()
        elif self.mode == "multilabel":
            pred = torch.sigmoid(pred)
            rewards = (
                ((pred - 0.5) * 2 * true).sum(dim=-1) / (true.sum(dim=-1) + 1e-7)
            ).squeeze()
        elif self.mode == "regression":
            rewards = (1 / torch.abs(pred - true)).squeeze()  # b*t
            rewards = torch.clamp(rewards, min=0, max=5)
        else:
            raise ValueError(
                "mode should be binary, multiclass, multilabel or regression"
            )

        act_prob1 = model.agent1_prob
        act_prob1 = torch.stack(act_prob1).permute(1, 0).to(self.device)
        act_prob1 = act_prob1 * mask.view(act_prob1.size(0), act_prob1.size(1))

        act_entropy1 = model.agent1_entropy
        act_entropy1 = torch.stack(act_entropy1).permute(1, 0).to(self.device)
        act_entropy1 = act_entropy1 * mask.view(
            act_entropy1.size(0), act_entropy1.size(1)
        )

        if self.use_baseline:
            act_baseline1 = model.agent1_baseline
            act_baseline1 = (
                torch.stack(act_baseline1).squeeze(-1).permute(1, 0).to(self.device)
            )
            act_baseline1 = act_baseline1 * mask.view(
                act_baseline1.size(0), act_baseline1.size(1)
            )

        act_prob2 = model.agent2_prob
        act_prob2 = torch.stack(act_prob2).permute(1, 0).to(self.device)
        act_prob2 = act_prob2 * mask.view(act_prob2.size(0), act_prob2.size(1))

        act_entropy2 = model.agent2_entropy
        act_entropy2 = torch.stack(act_entropy2).permute(1, 0).to(self.device)
        act_entropy2 = act_entropy2 * mask.view(
            act_entropy2.size(0), act_entropy2.size(1)
        )

        if self.use_baseline:
            act_baseline2 = model.agent2_baseline
            act_baseline2 = (
                torch.stack(act_baseline2).squeeze(-1).permute(1, 0).to(self.device)
            )
            act_baseline2 = act_baseline2 * mask.view(
                act_baseline2.size(0), act_baseline2.size(1)
            )

        # propagate the terminal reward backwards through time with discount
        # gamma; intermediate steps receive no immediate reward
        running_rewards = []
        discounted_rewards = 0
        for i in reversed(range(act_prob1.size(1))):
            if i == act_prob1.size(1) - 1:
                discounted_rewards = rewards + gamma * discounted_rewards
            else:
                discounted_rewards = (
                    torch.zeros_like(rewards) + gamma * discounted_rewards
                )
            running_rewards.insert(0, discounted_rewards)

        rewards = torch.stack(running_rewards).permute(1, 0)
        # rewards = (rewards - rewards.mean(dim=1).unsqueeze(-1)) / (
        #     rewards.std(dim=1) + 1e-7
        # ).unsqueeze(-1)
        rewards = rewards.detach()

        if self.use_baseline:
            # value (baseline) regression loss plus REINFORCE loss with the
            # baseline-subtracted advantage and an entropy bonus
            loss_value1 = torch.sum((rewards - act_baseline1) ** 2, dim=1) / torch.sum(
                mask, dim=1
            )
            loss_value1 = torch.mean(loss_value1)
            loss_value2 = torch.sum((rewards - act_baseline2) ** 2, dim=1) / torch.sum(
                mask, dim=1
            )
            loss_value2 = torch.mean(loss_value2)
            loss_value = loss_value1 + loss_value2

            loss_RL1 = -torch.sum(
                act_prob1 * (rewards - act_baseline1) + entropy_term * act_entropy1,
                dim=1,
            ) / torch.sum(mask, dim=1)
            loss_RL1 = torch.mean(loss_RL1)
            loss_RL2 = -torch.sum(
                act_prob2 * (rewards - act_baseline2) + entropy_term * act_entropy2,
                dim=1,
            ) / torch.sum(mask, dim=1)
            loss_RL2 = torch.mean(loss_RL2)
            loss_RL = loss_RL1 + loss_RL2
            loss = loss_RL + loss_value
        else:
            loss_RL1 = -torch.sum(
                act_prob1 * rewards + entropy_term * act_entropy1, dim=1
            ) / torch.sum(mask, dim=1)
            loss_RL1 = torch.mean(loss_RL1)
            loss_RL2 = -torch.sum(
                act_prob2 * rewards + entropy_term * act_entropy2, dim=1
            ) / torch.sum(mask, dim=1)
            loss_RL2 = torch.mean(loss_RL2)
            loss_RL = loss_RL1 + loss_RL2
            loss = loss_RL
        return loss
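    # A minimal sketch of the discounted-return recursion above (assumed toy
    # numbers): only the last step receives the raw reward, and earlier steps
    # see it only through the discount factor gamma.
    #
    #     rewards, gamma = torch.tensor([1.0, -1.0]), 0.9       # [batch]
    #     returns, g = [], torch.zeros_like(rewards)
    #     for i in reversed(range(4)):                          # 4 time steps
    #         g = (rewards if i == 3 else torch.zeros_like(rewards)) + gamma * g
    #         returns.insert(0, g)
    #     # returns[t][b] == gamma ** (3 - t) * rewards[b]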
    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the final loss (task loss
                    plus RL loss).
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: a tensor representing the logits.
        """
        patient_emb = []
        mask_dict = {}
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            dim_, type_ = input_info["dim"], input_info["type"]

            # for case 1: [code1, code2, code3, ...]
            if (dim_ == 2) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_2d(
                    kwargs[feature_key]
                )
                # (patient, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, event)
                mask = torch.any(x != 0, dim=2)
                mask_dict[feature_key] = mask

            # for case 2: [[code1, code2], [code3, ...], ...]
            elif (dim_ == 3) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_3d(
                    kwargs[feature_key]
                )
                # (patient, visit, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, visit, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                # (patient, visit)
                mask = torch.any(x != 0, dim=2)
                mask_dict[feature_key] = mask

            # for case 3: [[1.5, 2.0, 0.0], ...]
            elif (dim_ == 2) and (type_ in [float, int]):
                x, mask = self.padding2d(kwargs[feature_key])
                # (patient, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, event, embedding_dim)
                x = self.linear_layers[feature_key](x)
                # (patient, event)
                mask = mask.bool().to(self.device)
                mask_dict[feature_key] = mask

            # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
            elif (dim_ == 3) and (type_ in [float, int]):
                x, mask = self.padding3d(kwargs[feature_key])
                # (patient, visit, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                x = self.linear_layers[feature_key](x)
                # (patient, visit)
                mask = mask[:, :, 0]
                mask = mask.bool().to(self.device)
                mask_dict[feature_key] = mask

            else:
                raise NotImplementedError

            if self.static_dim > 0:
                static = torch.tensor(
                    kwargs[self.static_key], dtype=torch.float, device=self.device
                )
                x, _ = self.agent[feature_key](x, static=static, mask=mask)
            else:
                x, _ = self.agent[feature_key](x, mask=mask)
            patient_emb.append(x)

        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)

        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss_task = self.get_loss_function()(logits, y_true)
        loss_rl = 0
        for feature_key in self.feature_keys:
            cur_loss = self.get_loss(
                self.agent[feature_key], logits, y_true, mask_dict[feature_key]
            )
            loss_rl += cur_loss
        loss = loss_task + loss_rl
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
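    # Example (a sketch, assuming a trained `model` and a `data_batch` from
    # get_dataloader): passing embed=True additionally returns the concatenated
    # patient embedding consumed by the final classification layer.
    #
    #     ret = model(**data_batch, embed=True)
    #     ret["embed"].shape  # [batch size, len(feature_keys) * hidden_dim]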
if __name__ == "__main__":
    from pyhealth.datasets import SampleEHRDataset

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            # "single_vector": [1, 2, 3],
            "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
            "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
            "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
            "list_list_vectors": [
                [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
                [[7.7, 8.5, 9.4]],
            ],
            "label": 1,
            "demographic": [1.0, 2.0, 1.3],
        },
        {
            "patient_id": "patient-0",
            "visit_id": "visit-1",
            # "single_vector": [1, 5, 8],
            "list_codes": [
                "55154191800",
                "551541928",
                "55154192800",
                "705182798",
                "70518279800",
            ],
            "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
            "list_list_codes": [["A04A", "B035", "C129"]],
            "list_list_vectors": [
                [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
            ],
            "label": 0,
            "demographic": [1.0, 2.0, 1.3],
        },
    ]

    # dataset
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")

    # data loader
    from pyhealth.datasets import get_dataloader

    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model
    model = Agent(
        dataset=dataset,
        feature_keys=[
            "list_codes",
            "list_vectors",
            "list_list_codes",
            # "list_list_vectors",
        ],
        static_key="demographic",
        label_key="label",
        mode="binary",
    )

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()