Source code for pyhealth.models.transformer

import math
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel
from pyhealth.tokenizer import Tokenizer

# VALID_OPERATION_LEVEL = ["visit", "event"]


class Attention(nn.Module):
    """Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V."""

    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = torch.softmax(scores, dim=-1)
        if mask is not None:
            p_attn = p_attn.masked_fill(mask == 0, 0)
        if dropout is not None:
            p_attn = dropout(p_attn)
 
        return torch.matmul(p_attn, value), p_attn

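# Usage sketch (illustrative, not part of the module): scaled dot-product
# attention on tensors that are already split into heads, i.e. shaped
# [batch_size, heads, seq_len, d_k].
#
#   >>> q = k = v = torch.randn(2, 4, 10, 16)
#   >>> out, p_attn = Attention()(q, k, v)
#   >>> out.shape, p_attn.shape
#   (torch.Size([2, 4, 10, 16]), torch.Size([2, 4, 10, 10]))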

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
    
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(3)]
        )
        self.output_linear = nn.Linear(d_model, d_model, bias=False)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

        self.attn_gradients = None
        self.attn_map = None

    # helper functions for interpretability
    def get_attn_map(self):
        return self.attn_map 
    
    def get_attn_grad(self):
        return self.attn_gradients

    def save_attn_grad(self, attn_grad):
        self.attn_gradients = attn_grad 

    # The register_hook option lets us save the attention gradients during the
    # backward pass (for interpretability).
    def forward(self, query, key, value, mask=None, register_hook=False):
        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]
        
        # 2) Apply attention on all the projected vectors in batch.
        if mask is not None:
            mask = mask.unsqueeze(1)
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        
        self.attn_map = attn # save the attention map
        if register_hook:
            attn.register_hook(self.save_attn_grad)
        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
  
        return self.output_linear(x)

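# Usage sketch (illustrative, not part of the module): multi-head attention on
# [batch_size, seq_len, d_model] inputs. A mask of shape
# [batch_size, seq_len, seq_len] is broadcast over the head dimension inside
# forward(). With register_hook=True, the attention gradients are stored during
# the backward pass and can be read back via get_attn_grad().
#
#   >>> mha = MultiHeadedAttention(h=4, d_model=64)
#   >>> x = torch.randn(2, 10, 64)
#   >>> out = mha(x, x, x, mask=torch.ones(2, 10, 10), register_hook=True)
#   >>> out.shape
#   torch.Size([2, 10, 64])
#   >>> out.sum().backward()
#   >>> mha.get_attn_grad().shape
#   torch.Size([2, 4, 10, 10])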

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        x = self.w_2(self.dropout(self.activation(self.w_1(x))))
        if mask is not None:
            mask = mask.sum(dim=-1) > 0
            x[~mask] = 0
        return x


class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

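# Usage sketch (illustrative, not part of the module): a pre-norm residual
# wrapper computing x + dropout(sublayer(layer_norm(x))) for any callable
# sublayer that preserves the last dimension.
#
#   >>> residual = SublayerConnection(size=64, dropout=0.1)
#   >>> x = torch.randn(2, 10, 64)
#   >>> residual(x, nn.Linear(64, 64)).shape
#   torch.Size([2, 10, 64])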

class TransformerBlock(nn.Module):
    """Transformer block.

    MultiHeadedAttention + PositionwiseFeedForward + SublayerConnection

    Args:
        hidden: hidden size of transformer.
        attn_heads: number of attention heads.
        dropout: dropout rate.
    """

    def __init__(self, hidden, attn_heads, dropout):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=4 * hidden, dropout=dropout
        )
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None, register_hook=False):
        """Forward propagation.

        Args:
            x: [batch_size, seq_len, hidden]
            mask: [batch_size, seq_len, seq_len]

        Returns:
            A tensor of shape [batch_size, seq_len, hidden]
        """
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=mask, register_hook=register_hook))
        x = self.output_sublayer(x, lambda _x: self.feed_forward(_x, mask=mask))
        return self.dropout(x)

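# Usage sketch (illustrative, not part of the module): one block applied to a
# [batch_size, seq_len, hidden] input with a [batch_size, seq_len, seq_len]
# attention mask (1 = attend, 0 = ignore).
#
#   >>> block = TransformerBlock(hidden=64, attn_heads=4, dropout=0.1)
#   >>> x = torch.randn(2, 10, 64)
#   >>> block(x, mask=torch.ones(2, 10, 10)).shape
#   torch.Size([2, 10, 64])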

class TransformerLayer(nn.Module):
    """Transformer layer.

    Paper: Ashish Vaswani et al. Attention is all you need. NIPS 2017.

    This layer is used in the Transformer model. But it can also be used as a
    standalone layer.

    Args:
        feature_size: the hidden feature size.
        heads: the number of attention heads. Default is 1.
        dropout: dropout rate. Default is 0.5.
        num_layers: number of transformer layers. Default is 1.
        register_hook: True to save gradients of attention layer. Default is False.

    Examples:
        >>> from pyhealth.models import TransformerLayer
        >>> input = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
        >>> layer = TransformerLayer(64)
        >>> emb, cls_emb = layer(input)
        >>> emb.shape
        torch.Size([3, 128, 64])
        >>> cls_emb.shape
        torch.Size([3, 64])
    """

    def __init__(self, feature_size, heads=1, dropout=0.5, num_layers=1):
        super(TransformerLayer, self).__init__()
        self.transformer = nn.ModuleList(
            [TransformerBlock(feature_size, heads, dropout) for _ in range(num_layers)]
        )

    def forward(
        self,
        x: torch.tensor,
        mask: Optional[torch.tensor] = None,
        register_hook=False,
    ) -> Tuple[torch.tensor, torch.tensor]:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, feature_size].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid.

        Returns:
            emb: a tensor of shape [batch size, sequence len, feature_size],
                containing the output features for each time step.
            cls_emb: a tensor of shape [batch size, feature_size], containing
                the output features for the first time step.
        """
        if mask is not None:
            mask = torch.einsum("ab,ac->abc", mask, mask)
        for transformer in self.transformer:
            x = transformer(x, mask, register_hook)
        emb = x
        cls_emb = x[:, 0, :]
        return emb, cls_emb
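

# Usage sketch (illustrative, not part of the module): supplying a padding mask
# of shape [batch_size, seq_len] (1 = valid, 0 = padding). It is expanded to a
# pairwise [batch_size, seq_len, seq_len] mask via an outer product before being
# passed to the attention blocks.
#
#   >>> layer = TransformerLayer(feature_size=64, heads=4, num_layers=2)
#   >>> x = torch.randn(3, 10, 64)
#   >>> mask = torch.ones(3, 10)
#   >>> mask[:, 5:] = 0  # mark the last five positions as padding
#   >>> emb, cls_emb = layer(x, mask)
#   >>> emb.shape, cls_emb.shape
#   (torch.Size([3, 10, 64]), torch.Size([3, 64]))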


class Transformer(BaseModel):
    """Transformer model.

    This model applies a separate Transformer layer for each feature, and then
    concatenates the final hidden states of each Transformer layer. The concatenated
    hidden states are then fed into a fully connected layer to make predictions.

    Note:
        We use separate Transformer layers for different feature_keys.
        Currently, we automatically support different input formats:
            - code based input (need to use the embedding table later)
            - float/int based value input
        We follow the current convention for the transformer model:
            - case 1. [code1, code2, code3, ...]
                - we will assume the code follows the order; our model will encode
                each code into a vector and apply the transformer on the code level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner bracket follows the order; our model first
                uses the embedding table to encode each code into a vector and then
                uses average/mean pooling to get one vector for each inner bracket;
                it then applies the transformer on the bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the same
                length; we assume each dimension has the same meaning; we run the
                transformer directly on the inner bracket level, similar to case 1
                after the embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the same
                length; we assume each dimension has the same meaning; we run the
                transformer directly on the inner bracket level, similar to case 2
                after the embedding table

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features, e.g.
            ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: the embedding dimension. Default is 128.
        **kwargs: other parameters for the Transformer layer.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
        ...         ],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import Transformer
        >>> model = Transformer(
        ...     dataset=dataset,
        ...     feature_keys=[
        ...         "list_codes",
        ...         "list_vectors",
        ...         "list_list_codes",
        ...         "list_list_vectors",
        ...     ],
        ...     label_key="label",
        ...     mode="multiclass",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(4.0555, grad_fn=<NllLossBackward0>),
            'y_prob': tensor([[1.0000e+00, 1.8206e-06],
                        [9.9970e-01, 3.0020e-04]], grad_fn=<SoftmaxBackward0>),
            'y_true': tensor([0, 1]),
            'logit': tensor([[ 7.6283, -5.5881],
                        [ 1.0898, -7.0210]], grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleEHRDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        pretrained_emb: str = None,
        embedding_dim: int = 128,
        **kwargs
    ):
        super(Transformer, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
            pretrained_emb=pretrained_emb,
        )
        self.embedding_dim = embedding_dim

        # validate kwargs for Transformer layer
        if "feature_size" in kwargs:
            raise ValueError("feature_size is determined by embedding_dim")

        # the keys of self.feat_tokenizers only contain the code based inputs
        self.feat_tokenizers = {}
        self.label_tokenizer = self.get_label_tokenizer()
        # the keys of self.embeddings only contain the code based inputs
        self.embeddings = nn.ModuleDict()
        # the keys of self.linear_layers only contain the float/int based inputs
        self.linear_layers = nn.ModuleDict()

        # add feature transformation layers
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            # sanity check
            if input_info["type"] not in [str, float, int]:
                raise ValueError(
                    "Transformer only supports str code, float and int as input types"
                )
            elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]):
                raise ValueError(
                    "Transformer only supports 2-dim or 3-dim str code as input types"
                )
            elif (input_info["type"] in [float, int]) and (
                input_info["dim"] not in [2, 3]
            ):
                raise ValueError(
                    "Transformer only supports 2-dim or 3-dim float and int as input types"
                )
            # for code based input, we need Type
            # for float/int based input, we need Type, input_dim
            self.add_feature_transform_layer(feature_key, input_info)

        self.transformer = nn.ModuleDict()
        for feature_key in feature_keys:
            self.transformer[feature_key] = TransformerLayer(
                feature_size=embedding_dim, **kwargs
            )
        output_size = self.get_output_size(self.label_tokenizer)
        # the transformer's output feature size is still embedding_dim
        self.fc = nn.Linear(len(self.feature_keys) * self.embedding_dim, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain all
                the feature keys and the label key.

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
        """
        patient_emb = []
        for feature_key in self.feature_keys:
            input_info = self.dataset.input_info[feature_key]
            dim_, type_ = input_info["dim"], input_info["type"]

            # for case 1: [code1, code2, code3, ...]
            if (dim_ == 2) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_2d(
                    kwargs[feature_key]
                )
                # (patient, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, event)
                mask = torch.any(x != 0, dim=2)

            # for case 2: [[code1, code2], [code3, ...], ...]
            elif (dim_ == 3) and (type_ == str):
                x = self.feat_tokenizers[feature_key].batch_encode_3d(
                    kwargs[feature_key]
                )
                # (patient, visit, event)
                x = torch.tensor(x, dtype=torch.long, device=self.device)
                # (patient, visit, event, embedding_dim)
                x = self.embeddings[feature_key](x)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                # (patient, visit)
                mask = torch.any(x != 0, dim=2)

            # for case 3: [[1.5, 2.0, 0.0], ...]
            elif (dim_ == 2) and (type_ in [float, int]):
                x, mask = self.padding2d(kwargs[feature_key])
                # (patient, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, event, embedding_dim)
                x = self.linear_layers[feature_key](x)
                # (patient, event)
                mask = mask.bool().to(self.device)

            # for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
            elif (dim_ == 3) and (type_ in [float, int]):
                x, mask = self.padding3d(kwargs[feature_key])
                # (patient, visit, event, values)
                x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, visit, embedding_dim)
                x = torch.sum(x, dim=2)
                x = self.linear_layers[feature_key](x)
                mask = mask[:, :, 0]
                mask = mask.bool().to(self.device)

            else:
                raise NotImplementedError

            # transform x to (patient, event, embedding_dim)
            if self.pretrained_emb is not None:
                x = self.linear_layers[feature_key](x)

            _, x = self.transformer[feature_key](x, mask, kwargs.get("register_hook"))
            patient_emb.append(x)

        patient_emb = torch.cat(patient_emb, dim=1)
        logits = self.fc(patient_emb)

        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {"loss": loss, "y_prob": y_prob, "y_true": y_true, "logit": logits}
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
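

# Usage sketch (illustrative, not part of the library API), continuing the
# docstring example above: forward() also recognizes two optional keyword
# arguments. embed=True adds the concatenated per-feature embeddings to the
# result dict under "embed", and register_hook=True asks each TransformerLayer
# to retain attention gradients for interpretability.
#
#   >>> ret = model(**data_batch, embed=True, register_hook=True)
#   >>> ret["embed"].shape  # len(feature_keys) * embedding_dim = 4 * 128
#   torch.Size([2, 512])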


if __name__ == "__main__":
    from pyhealth.datasets import SampleEHRDataset

    samples = [
        {
            "patient_id": "patient-0",
            "visit_id": "visit-0",
            "single_vector": [1, 2, 3],
            "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
            "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
            "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
            "list_list_vectors": [
                [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
                [[7.7, 8.5, 9.4]],
            ],
            "label": 1,
        },
        {
            "patient_id": "patient-0",
            "visit_id": "visit-1",
            "single_vector": [1, 5, 8],
            "list_codes": [
                "55154191800",
                "551541928",
                "55154192800",
                "705182798",
                "70518279800",
            ],
            "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
            "list_list_codes": [["A04A", "B035", "C129"]],
            "list_list_vectors": [
                [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
            ],
            "label": 0,
        },
    ]

    # dataset
    dataset = SampleEHRDataset(samples=samples, dataset_name="test")

    # data loader
    from pyhealth.datasets import get_dataloader

    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model
    model = Transformer(
        dataset=dataset,
        feature_keys=[
            "list_codes",
            "list_vectors",
            "list_list_codes",
            "list_list_vectors",
        ],
        label_key="label",
        mode="multiclass",
    )

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()