"""Source code for pyhealth.models.stagenet."""

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel
from pyhealth.models.utils import get_last_visit

# VALID_OPERATION_LEVEL = ["visit", "event"]


class StageNetLayer(nn.Module):
    """StageNet layer.

    Paper: Stagenet: Stage-aware neural networks for health risk prediction.
    WWW 2020.

    This layer is used in the StageNet model. But it can also be used as a
    standalone layer.

    Args:
        input_dim: dynamic feature size.
        chunk_size: the chunk size for the StageNet layer. Default is 128.
        conv_size: the size of the convolutional kernel. Default is 10.
        levels: the number of levels for the StageNet layer.
            levels * chunk_size = hidden_dim in the RNN. Smaller chunk size
            and more levels can capture more detailed patient status
            variations. Default is 3.
        dropconnect: the dropout rate for the dropconnect. Default is 0.3.
        dropout: the dropout rate for the dropout. Default is 0.3.
        dropres: the dropout rate for the residual connection. Default is 0.3.

    Examples:
        >>> from pyhealth.models import StageNetLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = StageNetLayer(64)
        >>> c, _, _ = layer(input)
        >>> c.shape
        torch.Size([3, 384])
    """

    def __init__(
        self,
        input_dim: int,
        chunk_size: int = 128,
        conv_size: int = 10,
        levels: int = 3,
        dropconnect: float = 0.3,
        dropout: float = 0.3,
        dropres: float = 0.3,
    ):
        super(StageNetLayer, self).__init__()

        self.dropout = dropout
        self.dropconnect = dropconnect
        self.dropres = dropres
        self.input_dim = input_dim
        self.hidden_dim = chunk_size * levels
        self.conv_dim = self.hidden_dim
        self.conv_size = conv_size
        self.levels = levels
        self.chunk_size = chunk_size

        # Both kernels emit, per sample: 2 * levels master-gate logits
        # followed by 4 * hidden_dim values (forget / input / output gates
        # and the candidate cell). The "+ 1" input column is the time interval.
        gate_dim = int(self.hidden_dim * 4 + levels * 2)

        self.kernel = nn.Linear(int(input_dim + 1), gate_dim)
        nn.init.xavier_uniform_(self.kernel.weight)
        nn.init.zeros_(self.kernel.bias)

        self.recurrent_kernel = nn.Linear(int(self.hidden_dim + 1), gate_dim)
        nn.init.orthogonal_(self.recurrent_kernel.weight)
        nn.init.zeros_(self.recurrent_kernel.bias)

        # Bottleneck (hidden_dim -> hidden_dim // 6 -> hidden_dim) used to
        # re-calibrate the convolved features in ``forward``.
        self.nn_scale = nn.Linear(int(self.hidden_dim), int(self.hidden_dim // 6))
        self.nn_rescale = nn.Linear(int(self.hidden_dim // 6), int(self.hidden_dim))
        self.nn_conv = nn.Conv1d(
            int(self.hidden_dim), int(self.conv_dim), int(conv_size), 1
        )

        # Dropout modules are only created for non-zero rates; ``step`` and
        # ``forward`` guard their use with the same rates.
        if self.dropconnect:
            self.nn_dropconnect = nn.Dropout(p=dropconnect)
            self.nn_dropconnect_r = nn.Dropout(p=dropconnect)
        if self.dropout:
            self.nn_dropout = nn.Dropout(p=dropout)
            self.nn_dropres = nn.Dropout(p=dropres)

    def cumax(self, x, mode="l2r"):
        """Cumulative softmax over the last dimension.

        Args:
            x: input tensor; the operation is applied along ``dim=-1``.
            mode: ``"l2r"`` accumulates from the first to the last position,
                ``"r2l"`` accumulates from the last to the first position.
                Any other value returns ``x`` unchanged.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if mode == "l2r":
            return torch.cumsum(torch.softmax(x, dim=-1), dim=-1)
        if mode == "r2l":
            flipped = torch.flip(x, [-1])
            flipped = torch.cumsum(torch.softmax(flipped, dim=-1), dim=-1)
            return torch.flip(flipped, [-1])
        return x

    def step(self, inputs, c_last, h_last, interval, device):
        """Run one recurrent step.

        Args:
            inputs: features of the current time step, [batch size, input_dim].
            c_last: previous cell state, reshapeable to
                [batch size, levels, chunk_size].
            h_last: previous hidden state, [batch size, hidden_dim].
            interval: time interval for the current step, [batch size].
            device: device on which the step is computed; the four tensor
                arguments are moved there.

        Returns:
            out: [batch size, hidden_dim + 2 * levels]; the hidden state
                concatenated with the forget and input master gates.
            c_out: new cell state, [batch size, hidden_dim].
            h_out: new hidden state, [batch size, hidden_dim].
        """
        x_in = inputs.to(device=device)
        h_last = h_last.to(device=device)
        c_last = c_last.to(device=device)

        # Integrate inter-visit time intervals
        interval = interval.unsqueeze(-1).to(device=device)
        x_out1 = self.kernel(torch.cat((x_in, interval), dim=-1))
        x_out2 = self.recurrent_kernel(torch.cat((h_last, interval), dim=-1))

        if self.dropconnect:
            x_out1 = self.nn_dropconnect(x_out1)
            x_out2 = self.nn_dropconnect_r(x_out2)
        x_out = x_out1 + x_out2

        # Master gates: one value per level, broadcast over the chunk dimension.
        f_master_gate = self.cumax(x_out[:, : self.levels], "l2r").unsqueeze(2)
        i_master_gate = self.cumax(
            x_out[:, self.levels : self.levels * 2], "r2l"
        ).unsqueeze(2)

        x_out = x_out[:, self.levels * 2 :]
        x_out = x_out.reshape(-1, self.levels * 4, self.chunk_size)
        f_gate = torch.sigmoid(x_out[:, : self.levels])
        i_gate = torch.sigmoid(x_out[:, self.levels : self.levels * 2])
        o_gate = torch.sigmoid(x_out[:, self.levels * 2 : self.levels * 3])
        c_in = torch.tanh(x_out[:, self.levels * 3 :])
        c_last = c_last.reshape(-1, self.levels, self.chunk_size)

        # Where both master gates are active use the regular gated update;
        # elsewhere copy the old cell (forget side) or the candidate
        # (input side) directly.
        overlap = f_master_gate * i_master_gate
        c_out = (
            overlap * (f_gate * c_last + i_gate * c_in)
            + (f_master_gate - overlap) * c_last
            + (i_master_gate - overlap) * c_in
        )
        h_out = o_gate * torch.tanh(c_out)
        c_out = c_out.reshape(-1, self.hidden_dim)
        h_out = h_out.reshape(-1, self.hidden_dim)
        out = torch.cat([h_out, f_master_gate[..., 0], i_master_gate[..., 0]], 1)
        return out, c_out, h_out

    def forward(
        self,
        x: torch.Tensor,
        time: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input_dim].
            time: an optional tensor of time intervals that can be reshaped
                to [batch size, sequence len]. If None, every interval is 1.
            mask: an optional tensor of shape [batch size, sequence len],
                where 1 indicates valid and 0 indicates invalid.

        Returns:
            last_output: a tensor of shape [batch size, chunk_size*levels]
                representing the patient embedding.
            outputs: a tensor of shape [batch size, sequence len,
                chunk_size*levels] representing the patient at each time
                step.
            distance: a tensor of shape [sequence len, batch size] holding,
                for each step, 1 minus the mean of the forget master gate.
        """
        batch_size, time_step, _ = x.size()
        device = x.device

        if time is None:
            time = torch.ones(batch_size, time_step, device=device)
        time = time.reshape(batch_size, time_step)

        # Allocate state and sliding-window buffers directly on the target
        # device instead of creating them on CPU and moving them every step.
        c_out = torch.zeros(batch_size, self.hidden_dim, device=device)
        h_out = torch.zeros(batch_size, self.hidden_dim, device=device)
        # Windows over the last ``conv_size`` hidden states / distances.
        tmp_h = torch.zeros(
            self.conv_size,
            batch_size,
            self.hidden_dim,
            dtype=torch.float32,
            device=device,
        )
        tmp_dis = torch.zeros((self.conv_size, batch_size), device=device)

        h = []
        origin_h = []
        distance = []
        for t in range(time_step):
            out, c_out, h_out = self.step(x[:, t, :], c_out, h_out, time[:, t], device)
            cur_h = out[..., : self.hidden_dim]
            cur_distance = 1 - torch.mean(
                out[..., self.hidden_dim : self.hidden_dim + self.levels], -1
            )
            origin_h.append(cur_h)
            distance.append(cur_distance)

            # Slide both windows: drop the oldest entry, append the newest.
            tmp_h = torch.cat((tmp_h[1:], cur_h.unsqueeze(0)), 0)
            tmp_dis = torch.cat((tmp_dis[1:], cur_distance.unsqueeze(0)), 0)

            # Re-weighted convolution operation
            local_dis = tmp_dis.permute(1, 0)
            local_dis = torch.cumsum(local_dis, dim=1)
            local_dis = torch.softmax(local_dis, dim=1)
            local_h = tmp_h.permute(1, 2, 0)
            local_h = local_h * local_dis.unsqueeze(1)

            # Re-calibrate Progression patterns
            local_theme = torch.mean(local_h, dim=-1)
            local_theme = torch.relu(self.nn_scale(local_theme))
            local_theme = torch.sigmoid(self.nn_rescale(local_theme))

            # The window length equals the kernel size, so the convolution
            # yields a single position that is squeezed away.
            local_h = self.nn_conv(local_h).squeeze(-1)
            h.append(local_theme * local_h)

        origin_h = torch.stack(origin_h).permute(1, 0, 2)
        rnn_outputs = torch.stack(h).permute(1, 0, 2)
        if self.dropres > 0.0:
            origin_h = self.nn_dropres(origin_h)
        rnn_outputs = rnn_outputs + origin_h
        rnn_outputs = rnn_outputs.contiguous().view(-1, rnn_outputs.size(-1))
        if self.dropout > 0.0:
            rnn_outputs = self.nn_dropout(rnn_outputs)

        output = rnn_outputs.contiguous().view(batch_size, time_step, self.hidden_dim)
        last_output = get_last_visit(output, mask)

        return last_output, output, torch.stack(distance)
class StageNet(BaseModel):
    """StageNet model.

    Paper: Junyi Gao et al. Stagenet: Stage-aware neural networks for health
    risk prediction. WWW 2020.

    Note:
        We use separate StageNet layers for different feature_keys.
        Currently, we automatically support different input formats:
            - code based input (need to use the embedding table later)
            - float/int based value input
        We follow the current convention for the StageNet model:
            - case 1. [code1, code2, code3, ...]
                - we will assume the code follows the order; our model will
                  encode each code into a vector and apply StageNet on the
                  code level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner bracket follows the order; our
                  model first use the embedding table to encode each code
                  into a vector and then use average/mean pooling to get one
                  vector for one inner bracket; then use StageNet on the
                  bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the
                  same length; we assume each dimension has the same
                  meaning; we run StageNet directly on the inner bracket
                  level, similar to case 1 after embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the
                  same length; we assume each dimension has the same
                  meaning; we run StageNet directly on the inner bracket
                  level, similar to case 2 after embedding table
        The time interval information specified by time_keys will be used to
        calculate the memory decay between each visit. If time_keys is None,
        all visits are treated as the same time interval. For each feature,
        the time interval should be a two-dimensional float array with shape
        (time_step, 1).

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features,
            e.g. ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        time_keys: list of keys in samples to use as time interval
            information for each feature, Default is None. If none, all
            visits are treated as the same time interval.
        embedding_dim: the embedding dimension. Default is 128.
        chunk_size: the chunk size for the StageNet layer. Default is 128.
        levels: the number of levels for the StageNet layer.
            levels * chunk_size = hidden_dim in the RNN. Smaller chunk size
            and more levels can capture more detailed patient status
            variations. Default is 3.
        **kwargs: other parameters for the StageNet layer.

    Raises:
        ValueError: if ``kwargs`` contains ``feature_size`` or ``input_dim``,
            if ``time_keys`` and ``feature_keys`` differ in length, or if a
            feature has an unsupported type or dimension.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         # "single_vector": [1, 2, 3],
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "label": 1,
        ...         "list_vectors_time": [[0.0], [1.3]],
        ...         "list_codes_time": [[0.0], [2.0], [1.3]],
        ...         "list_list_codes_time": [[0.0], [1.5]],
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         # "single_vector": [1, 5, 8],
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ...         ],
        ...         "label": 0,
        ...         "list_vectors_time": [[0.0], [2.0], [1.0]],
        ...         "list_codes_time": [[0.0], [2.0], [1.3], [1.0], [2.0]],
        ...         "list_list_codes_time": [[0.0]],
        ...     },
        ... ]
        >>>
        >>> # dataset
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> # data loader
        >>> from pyhealth.datasets import get_dataloader
        >>>
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> # model
        >>> model = StageNet(
        ...     dataset=dataset,
        ...     feature_keys=[
        ...         "list_codes",
        ...         "list_vectors",
        ...         "list_list_codes",
        ...         # "list_list_vectors",
        ...     ],
        ...     time_keys=["list_codes_time", "list_vectors_time", "list_list_codes_time"],
        ...     label_key="label",
        ...     mode="binary",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.7111, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
            'y_prob': tensor([[0.4815],
                        [0.4991]], grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[1.],
                        [0.]]),
            'logit': tensor([[-0.0742],
                        [-0.0038]], grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        time_keys: Optional[List[str]] = None,
        embedding_dim: int = 128,
        chunk_size: int = 128,
        levels: int = 3,
        **kwargs,
    ):
        super(StageNet, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.levels = levels

        # validate kwargs for StageNet layer
        if "feature_size" in kwargs:
            raise ValueError("feature_size is determined by embedding_dim")
        # ``input_dim`` is the keyword StageNetLayer actually takes and it is
        # passed explicitly below; reject it here with a clear message
        # instead of a duplicate-keyword TypeError.
        if "input_dim" in kwargs:
            raise ValueError("input_dim is determined by embedding_dim")
        if time_keys is not None and len(time_keys) != len(feature_keys):
            raise ValueError("time_keys should have the same length as feature_keys")

        # the key of self.feat_tokenizers only contains the code based inputs
        self.feat_tokenizers = {}
        self.time_keys = time_keys
        self.label_tokenizer = self.get_label_tokenizer()
        # the key of self.embeddings only contains the code based inputs
        self.embeddings = nn.ModuleDict()
        # the key of self.linear_layers only contains the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        self.stagenet = nn.ModuleDict()
        # add feature StageNet layers
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "StageNet only supports str code, float and int as input types"
                )
            if input_info["dim"] not in [2, 3]:
                if input_info["type"] == str:
                    raise ValueError(
                        "StageNet only supports 2-dim or 3-dim str code as input types"
                    )
                raise ValueError(
                    "StageNet only supports 2-dim or 3-dim float and int as input types"
                )

            # for code based input, we need Type
            # for float/int based input, we need Type, input_dim
            self.add_feature_transform_layer(feature_key, input_info)

            self.stagenet[feature_key] = StageNetLayer(
                input_dim=embedding_dim,
                chunk_size=self.chunk_size,
                levels=self.levels,
                **kwargs,
            )

        output_size = self.get_output_size(self.label_tokenizer)
        # One patient embedding of size chunk_size * levels per feature key.
        self.fc = nn.Linear(
            len(self.feature_keys) * self.chunk_size * self.levels, output_size
        )

    def _embed_feature(self, feature_key: str, raw) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one raw feature batch into a dense sequence and its mask.

        Args:
            feature_key: the feature to encode; selects the tokenizer,
                embedding table or linear layer.
            raw: the batch of raw values for this feature.

        Returns:
            x: a tensor whose last dimension is the embedding dimension.
            mask: a boolean tensor over the first two dimensions of ``x``.

        Raises:
            NotImplementedError: for an unsupported (dim, type) combination.
        """
        input_info = self.dataset.input_info[feature_key]
        dim_, type_ = input_info["dim"], input_info["type"]

        # for case 1: [code1, code2, code3, ...]
        if (dim_ == 2) and (type_ == str):
            x = self.feat_tokenizers[feature_key].batch_encode_2d(raw)
            # (patient, event)
            x = torch.tensor(x, dtype=torch.long, device=self.device)
            # (patient, event, embedding_dim)
            x = self.embeddings[feature_key](x)
            # (patient, event); a position counts as valid when its embedding
            # has any non-zero entry
            mask = torch.any(x != 0, dim=2)

        # for case 2: [[code1, code2], [code3, ...], ...]
        elif (dim_ == 3) and (type_ == str):
            x = self.feat_tokenizers[feature_key].batch_encode_3d(raw)
            # (patient, visit, event)
            x = torch.tensor(x, dtype=torch.long, device=self.device)
            # (patient, visit, event, embedding_dim)
            x = self.embeddings[feature_key](x)
            # (patient, visit, embedding_dim)
            x = torch.sum(x, dim=2)
            # (patient, visit)
            mask = torch.any(x != 0, dim=2)

        # for case 3: [[1.5, 2.0, 0.0], ...]
        elif (dim_ == 2) and (type_ in [float, int]):
            x, mask = self.padding2d(raw)
            # (patient, event, values)
            x = torch.tensor(x, dtype=torch.float, device=self.device)
            # (patient, event, embedding_dim)
            x = self.linear_layers[feature_key](x)
            # (patient, event)
            mask = mask.bool().to(self.device)

        # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
        elif (dim_ == 3) and (type_ in [float, int]):
            x, mask = self.padding3d(raw)
            # (patient, visit, event, values)
            x = torch.tensor(x, dtype=torch.float, device=self.device)
            # (patient, visit, values)
            x = torch.sum(x, dim=2)
            # (patient, visit, embedding_dim)
            x = self.linear_layers[feature_key](x)
            # (patient, visit); only the first event slot of each visit is
            # consulted
            mask = mask[:, :, 0]
            mask = mask.bool().to(self.device)

        else:
            raise NotImplementedError

        return x, mask

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each
        patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key, plus the time keys
                when ``time_keys`` was given. An optional truthy ``embed``
                entry requests the patient embedding in the output.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the final loss.
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: a tensor of the raw outputs of the final linear layer.
                embed: the concatenated patient embedding; only present when
                    ``kwargs["embed"]`` is truthy.

        Raises:
            ValueError: if a time key is not a 2-dim float or int input.
            NotImplementedError: if a feature has an unsupported format.
        """
        patient_emb = []
        for idx, feature_key in enumerate(self.feature_keys):
            x, mask = self._embed_feature(feature_key, kwargs[feature_key])

            time = None
            if self.time_keys is not None:
                time_key = self.time_keys[idx]
                input_info = self.dataset.input_info[time_key]
                dim_, type_ = input_info["dim"], input_info["type"]
                if (dim_ != 2) or (type_ not in [float, int]):
                    raise ValueError("Time interval must be 2-dim float or int.")
                time, _ = self.padding2d(kwargs[time_key])
                time = torch.tensor(time, dtype=torch.float, device=self.device)

            # The layer also returns per-step outputs and stage distances;
            # only the patient embedding is used for prediction.
            x, _, _ = self.stagenet[feature_key](x, time=time, mask=mask)
            patient_emb.append(x)

        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
if __name__ == "__main__":
    # Smoke test: build a two-sample toy dataset, run one forward pass
    # through StageNet and back-propagate the loss.
    from pyhealth.datasets import SampleEHRDataset
    from pyhealth.datasets import get_dataloader

    first_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        # "single_vector": [1, 2, 3],
        "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        "list_list_vectors": [
            [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
            [[7.7, 8.5, 9.4]],
        ],
        "label": 1,
        "list_vectors_time": [[0.0], [1.3]],
        "list_codes_time": [[0.0], [2.0], [1.3]],
        "list_list_codes_time": [[0.0], [1.5]],
    }
    second_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-1",
        # "single_vector": [1, 5, 8],
        "list_codes": [
            "55154191800",
            "551541928",
            "55154192800",
            "705182798",
            "70518279800",
        ],
        "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        "list_list_codes": [["A04A", "B035", "C129"]],
        "list_list_vectors": [
            [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ],
        "label": 0,
        "list_vectors_time": [[0.0], [2.0], [1.0]],
        "list_codes_time": [[0.0], [2.0], [1.3], [1.0], [2.0]],
        "list_list_codes_time": [[0.0]],
    }
    samples = [first_visit, second_visit]

    # Dataset and loader; one batch holds both samples.
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")
    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # One time key per feature key, in matching order.
    # ("list_list_vectors" is intentionally left out.)
    demo_feature_keys = ["list_codes", "list_vectors", "list_list_codes"]
    demo_time_keys = ["list_codes_time", "list_vectors_time", "list_list_codes_time"]
    model = StageNet(
        dataset=dataset,
        feature_keys=demo_feature_keys,
        time_keys=demo_time_keys,
        label_key="label",
        mode="binary",
    )

    # Forward pass on a single batch, then check that the loss back-propagates.
    data_batch = next(iter(train_loader))
    ret = model(**data_batch)
    print(ret)
    ret["loss"].backward()