Source code for pyhealth.models.deepr

import functools
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from pyhealth.datasets import BaseEHRDataset
from pyhealth.models import BaseModel


class DeeprLayer(nn.Module):
    """Deepr layer.

    Paper: P. Nguyen, T. Tran, N. Wickramasinghe and S. Venkatesh, "Deepr: A
    Convolutional Net for Medical Records," in IEEE Journal of Biomedical and
    Health Informatics, vol. 21, no. 1, pp. 22-30, Jan. 2017,
    doi: 10.1109/JBHI.2016.2633963.

    This layer is used in the Deepr model.

    Args:
        feature_size: embedding dim of codes (m in the original paper).
        window: size of the sliding window (d in the original paper).
        hidden_size: number of conv filters (motif size, p, in the original
            paper).

    Examples:
        >>> from pyhealth.models import DeeprLayer
        >>> input = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
        >>> layer = DeeprLayer(5, window=4, hidden_size=7)  # window does not affect the output shape
        >>> outputs = layer(input)
        >>> outputs.shape
        torch.Size([3, 7])
    """

    def __init__(
        self,
        feature_size: int = 100,
        window: int = 1,
        hidden_size: int = 3,
    ):
        super(DeeprLayer, self).__init__()
        self.conv = torch.nn.Conv1d(
            feature_size, hidden_size, kernel_size=2 * window + 1
        )
    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward propagation.

        Args:
            x: a Tensor of shape [batch size, sequence len, input size].
            mask: an optional Tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid.

        Returns:
            c: a Tensor of shape [batch size, hidden_size] representing the
                summarized vector.
        """
        if mask is not None:
            x = x * mask.unsqueeze(-1)
        x = x.permute(0, 2, 1)  # [batch size, input size, sequence len]
        x = torch.relu(self.conv(x))
        x = x.max(-1)[0]
        return x
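# A minimal sanity-check sketch (ours, not part of the library) of the shape
# arithmetic above: Conv1d with kernel_size=2*window+1 and no padding shortens
# the sequence from L to L - 2*window, and the max over the time axis then
# collapses it to one vector per sample, which is why `window` never affects
# the output shape (assuming L > 2*window):
#
#     >>> layer = DeeprLayer(feature_size=5, window=4, hidden_size=7)
#     >>> x = torch.randn(3, 128, 5)
#     >>> mask = torch.ones(3, 128, dtype=torch.long)
#     >>> layer(x, mask).shape  # the mask only zeroes out padded positions
#     torch.Size([3, 7])
#     >>> torch.relu(layer.conv(x.permute(0, 2, 1))).shape  # pre-pooling: 128 - 2*4 = 120
#     torch.Size([3, 7, 120])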
def _flatten_and_fill_gap(gap_embedding, batch, device):
    """Helper function to fill the <gap> embedding into a batch of data."""
    embed_dim = gap_embedding.shape[-1]
    batch = [
        [
            [torch.tensor(v, device=device, dtype=torch.float) for v in visit_x]
            for visit_x in pat_x
        ]
        for pat_x in batch
    ]
    batch = [
        torch.stack(functools.reduce(lambda a, b: a + [gap_embedding] + b, pat_x), 0)
        for pat_x in batch
    ]
    batch_max_length = max(map(len, batch))
    mask = torch.tensor(
        [[1] * len(x) + [0] * (batch_max_length - len(x)) for x in batch],
        dtype=torch.long,
        device=device,
    )
    out = torch.zeros(
        [len(batch), batch_max_length, embed_dim], device=device, dtype=torch.float
    )
    for i, x in enumerate(batch):
        out[i, : len(x)] = x
    return out, mask
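# A small worked example (ours) of what the helper produces: for one patient
# with visits [[1., 2.]] and [[3., 4.]] and a 2-dim gap embedding g, the
# flattened event sequence is [[1., 2.], g, [3., 4.]], so a patient with v
# visits contributes v - 1 gap rows; shorter patients in the batch are
# zero-padded and masked out:
#
#     >>> g = torch.zeros(2)
#     >>> out, mask = _flatten_and_fill_gap(g, [[[[1., 2.]], [[3., 4.]]]], "cpu")
#     >>> out.shape, mask.tolist()
#     (torch.Size([1, 3, 2]), [[1, 1, 1]])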
class Deepr(BaseModel):
    """Deepr model.

    Paper: P. Nguyen, T. Tran, N. Wickramasinghe and S. Venkatesh, "Deepr: A
    Convolutional Net for Medical Records," in IEEE Journal of Biomedical and
    Health Informatics, vol. 21, no. 1, pp. 22-30, Jan. 2017,
    doi: 10.1109/JBHI.2016.2633963.

    Note:
        We use separate Deepr layers for different feature_keys.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features, e.g.
            ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension. Default is 128.
        **kwargs: other parameters for the Deepr layer.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ...         ],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import Deepr
        >>> model = Deepr(
        ...     dataset=dataset,
        ...     feature_keys=[
        ...         "list_list_codes",
        ...         "list_list_vectors",
        ...     ],
        ...     label_key="label",
        ...     mode="binary",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.8908, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
            'y_prob': tensor([[0.2295], [0.2665]], device='cuda:0', grad_fn=<SigmoidBackward0>),
            'y_true': tensor([[1.], [0.]], device='cuda:0'),
            'logit': tensor([[-1.2110], [-1.0126]], device='cuda:0', grad_fn=<AddmmBackward0>)
        }
    """

    def __init__(
        self,
        dataset: BaseEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        **kwargs,
    ):
        super(Deepr, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # TODO: use more tokens for <gap> for different gap lengths once the
        #  input carries such information
        self.feat_tokenizers = {}
        self.label_tokenizer = self.get_label_tokenizer()
        # TODO: pretrain these embeddings with word2vec?
        self.embeddings = nn.ModuleDict()
        # the keys of self.linear_layers only contain the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        # add feature-specific Deepr layers
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "Deepr only supports str code, float and int as input types"
                )
            if (input_info["type"] == str) and (input_info["dim"] != 3):
                raise ValueError(
                    "Deepr only supports 2-level str code as input types"
                )
            if (input_info["type"] in [float, int]) and (input_info["dim"] != 3):
                raise ValueError(
                    "Deepr only supports 3-level float and int as input types"
                )
            # for code based input, we need Type
            # for float/int based input, we need Type, input_dim
            self.add_feature_transform_layer(
                feature_key, input_info, special_tokens=["<pad>", "<unk>", "<gap>"]
            )
            if input_info["type"] != str:
                self.embeddings[feature_key] = torch.nn.Embedding(
                    1, input_info["len"]
                )

        self.cnn = nn.ModuleDict()
        for feature_key in feature_keys:
            self.cnn[feature_key] = DeeprLayer(
                feature_size=embedding_dim, hidden_size=hidden_dim, **kwargs
            )

        output_size = self.get_output_size(self.label_tokenizer)
        self.fc = nn.Linear(len(self.feature_keys) * self.hidden_dim, output_size)
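# Architecture note (a sketch of the wiring above, not library output): each
# feature key gets its own DeeprLayer, the per-feature [patient, hidden_dim]
# summaries are concatenated, and a single linear head maps them to the label
# space. For the two-feature binary example in the class docstring, with the
# default hidden_dim=128, the head therefore sees 2 * 128 = 256-dim vectors:
#
#     >>> model.fc
#     Linear(in_features=256, out_features=1, bias=True)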
    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation."""
        patient_emb = []
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            dim_, type_ = input_info["dim"], input_info["type"]

            # for case 2: [[code1, code2], [code3, ...], ...]
            if (dim_ == 3) and (type_ == str):
                feature_vals = [
                    functools.reduce(lambda a, b: a + ["<gap>"] + b, visits)
                    for visits in kwargs[feature_key]
                ]
                x = self.feat_tokenizers[feature_key].batch_encode_2d(
                    feature_vals, padding=True, truncation=False
                )
                pad_idx = self.feat_tokenizers[feature_key].vocabulary("<pad>")
                mask = torch.tensor(
                    [[code != pad_idx for code in pat] for pat in x],
                    dtype=torch.long,
                    device=self.device,
                )
                # (patient, code)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, event, embedding_dim)
                x = self.embeddings[feature_key](x)
            # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
            elif (dim_ == 3) and (type_ in [float, int]):
                gap_embedding = self.embeddings[feature_key](
                    torch.zeros(1, dtype=torch.long, device=self.device)
                )[0]
                x, mask = _flatten_and_fill_gap(
                    gap_embedding, kwargs[feature_key], self.device
                )
                # (patient, event, embedding_dim)
                x = self.linear_layers[feature_key](x)
            else:
                raise NotImplementedError(
                    f"Deepr does not support this input format "
                    f"(dim={dim_}, type={type_})."
                )
            # (patient, hidden_dim)
            x = self.cnn[feature_key](x, mask)
            patient_emb.append(x)

        # (patient, features * hidden_dim)
        patient_emb = torch.cat(patient_emb, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
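# Illustration (ours) of the "<gap>" flattening used for case 2 above: visits
# are concatenated into one event sequence with a "<gap>" token between
# consecutive visits, mirroring the time-gap marker in the Deepr paper:
#
#     >>> functools.reduce(lambda a, b: a + ["<gap>"] + b, [["A05B", "A05C"], ["A11D"]])
#     ['A05B', 'A05C', '<gap>', 'A11D']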
if __name__ == "__main__":
    from pyhealth.datasets import SampleEHRDataset

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            "single_vector": [1, 2, 3],
            "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
            "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
            "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
            "list_list_vectors": [
                [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
                [[7.7, 8.5, 9.4]],
            ],
            "label": 1,
        },
        {
            "patient_id": "patient-0",
            "visit_id": "visit-1",
            "single_vector": [1, 5, 8],
            "list_codes": [
                "55154191800",
                "551541928",
                "55154192800",
                "705182798",
                "70518279800",
            ],
            "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
            "list_list_codes": [["A04A", "B035", "C129"]],
            "list_list_vectors": [
                [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
            ],
            "label": 0,
        },
    ]

    # dataset
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")

    # data loader
    from pyhealth.datasets import get_dataloader

    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model
    model = Deepr(
        dataset=dataset,
        # feature_keys=["procedures"],
        feature_keys=["list_list_codes", "list_list_vectors"],
        label_key="label",
        mode="binary",
    ).to("cuda:0")

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()