from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from pyhealth.datasets import SampleEHRDataset
from pyhealth.models import BaseModel
# VALID_OPERATION_LEVEL = ["visit", "event"]
class RETAINLayer(nn.Module):
"""RETAIN layer.
Paper: Edward Choi et al. RETAIN: An Interpretable Predictive Model for
Healthcare using Reverse Time Attention Mechanism. NIPS 2016.
    This layer is used in the RETAIN model, but it can also be used as a
    standalone layer.
Args:
feature_size: the hidden feature size.
dropout: dropout rate. Default is 0.5.
Examples:
>>> from pyhealth.models import RETAINLayer
>>> input = torch.randn(3, 128, 64) # [batch size, sequence len, feature_size]
>>> layer = RETAINLayer(64)
>>> c = layer(input)
>>> c.shape
torch.Size([3, 64])
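        >>> # with an explicit mask (a minimal sketch; here every position is
        >>> # assumed valid, so the result matches the unmasked call)
        >>> mask = torch.ones(3, 128, dtype=torch.bool)
        >>> layer(input, mask).shape
        torch.Size([3, 64])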
"""
def __init__(
self,
feature_size: int,
dropout: float = 0.5,
):
super(RETAINLayer, self).__init__()
self.feature_size = feature_size
self.dropout = dropout
self.dropout_layer = nn.Dropout(p=self.dropout)
self.alpha_gru = nn.GRU(feature_size, feature_size, batch_first=True)
self.beta_gru = nn.GRU(feature_size, feature_size, batch_first=True)
self.alpha_li = nn.Linear(feature_size, 1)
self.beta_li = nn.Linear(feature_size, feature_size)
    @staticmethod
def reverse_x(input, lengths):
"""Reverses the input."""
reversed_input = input.new(input.size())
for i, length in enumerate(lengths):
reversed_input[i, :length] = input[i, :length].flip(dims=[0])
return reversed_input
    def compute_alpha(self, rx, lengths):
        """Computes the visit-level (alpha) attention from the time-reversed input."""
rx = rnn_utils.pack_padded_sequence(
rx, lengths, batch_first=True, enforce_sorted=False
)
g, _ = self.alpha_gru(rx)
g, _ = rnn_utils.pad_packed_sequence(g, batch_first=True)
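        # self.alpha_li(g) has shape (patient, sequence len, 1); softmax over
        # dim=1 yields one scalar attention weight per time step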
attn_alpha = torch.softmax(self.alpha_li(g), dim=1)
return attn_alpha
    def compute_beta(self, rx, lengths):
        """Computes the feature-level (beta) attention from the time-reversed input."""
rx = rnn_utils.pack_padded_sequence(
rx, lengths, batch_first=True, enforce_sorted=False
)
h, _ = self.beta_gru(rx)
h, _ = rnn_utils.pad_packed_sequence(h, batch_first=True)
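        # self.beta_li(h) has shape (patient, sequence len, feature_size);
        # tanh gates each feature dimension into [-1, 1]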
attn_beta = torch.tanh(self.beta_li(h))
return attn_beta
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
"""Forward propagation.
Args:
x: a tensor of shape [batch size, sequence len, feature_size].
mask: an optional tensor of shape [batch size, sequence len], where
1 indicates valid and 0 indicates invalid.
Returns:
c: a tensor of shape [batch size, feature_size] representing the
context vector.
"""
        # the GRU only applies dropout between stacked layers, so we apply it
        # to the input here
x = self.dropout_layer(x)
batch_size = x.size(0)
if mask is None:
lengths = torch.full(
size=(batch_size,), fill_value=x.size(1), dtype=torch.int64
)
else:
lengths = torch.sum(mask.int(), dim=-1).cpu()
rx = self.reverse_x(x, lengths)
attn_alpha = self.compute_alpha(rx, lengths)
attn_beta = self.compute_beta(rx, lengths)
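        # context vector: scale each step by its scalar alpha weight and each
        # feature by its beta gate, then sum over time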
c = attn_alpha * attn_beta * x # (patient, sequence len, feature_size)
c = torch.sum(c, dim=1) # (patient, feature_size)
return c
class RETAIN(BaseModel):
"""RETAIN model.
Paper: Edward Choi et al. RETAIN: An Interpretable Predictive Model for
Healthcare using Reverse Time Attention Mechanism. NIPS 2016.
    Note:
        We use separate RETAIN layers for different feature_keys.
        Currently, we automatically support different input formats:
            - code based input (needs the embedding table later)
            - float/int based value input
        We follow the current convention for the RETAIN model:
            - case 1. [code1, code2, code3, ...]
                - we will assume the codes are time-ordered; our model will
                  encode each code into a vector and apply RETAIN on the code level
            - case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], ...]
                - we will assume the inner brackets are time-ordered; our model
                  first uses the embedding table to encode each code into a
                  vector, then uses sum pooling to get one vector per inner
                  bracket, and finally applies RETAIN on the bracket level
            - case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], ...]
                - this case only makes sense when each inner bracket has the same
                  length; we assume each dimension has the same meaning; we run
                  RETAIN directly on the inner bracket level, similar to case 1
                  after the embedding table
            - case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], ...]
                - this case only makes sense when each inner bracket has the same
                  length; we assume each dimension has the same meaning; we run
                  RETAIN directly on the inner bracket level, similar to case 2
                  after the embedding table
Args:
dataset: the dataset to train the model. It is used to query certain
information such as the set of all tokens.
feature_keys: list of keys in samples to use as features,
e.g. ["conditions", "procedures"].
label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        pretrained_emb: optional pretrained embedding identifier forwarded to
            BaseModel. Default is None.
        embedding_dim: the embedding dimension. Default is 128.
        **kwargs: other parameters for the RETAIN layer.
Examples:
>>> from pyhealth.datasets import SampleEHRDataset
>>> samples = [
... {
... "patient_id": "patient-0",
... "visit_id": "visit-0",
... "list_codes": ["505800458", "50580045810", "50580045811"], # NDC
... "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
... "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]], # ATC-4
... "list_list_vectors": [
... [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
... [[7.7, 8.5, 9.4]],
... ],
... "label": 1,
... },
... {
... "patient_id": "patient-0",
... "visit_id": "visit-1",
... "list_codes": [
... "55154191800",
... "551541928",
... "55154192800",
... "705182798",
... "70518279800",
... ],
... "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
... "list_list_codes": [["A04A", "B035", "C129"]],
... "list_list_vectors": [
... [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
... ],
... "label": 0,
... },
... ]
>>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
>>>
>>> from pyhealth.models import RETAIN
>>> model = RETAIN(
... dataset=dataset,
... feature_keys=[
... "list_codes",
... "list_vectors",
... "list_list_codes",
... "list_list_vectors",
... ],
... label_key="label",
... mode="binary",
... )
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
'loss': tensor(0.5640, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
'y_prob': tensor([[0.5325],
[0.3922]], grad_fn=<SigmoidBackward0>),
'y_true': tensor([[1.],
[0.]]),
'logit': tensor([[ 0.1303],
[-0.4382]], grad_fn=<AddmmBackward0>)
}
>>>
"""
def __init__(
self,
dataset: SampleEHRDataset,
feature_keys: List[str],
label_key: str,
mode: str,
        pretrained_emb: Optional[str] = None,
embedding_dim: int = 128,
**kwargs,
):
super(RETAIN, self).__init__(
dataset=dataset,
feature_keys=feature_keys,
label_key=label_key,
mode=mode,
pretrained_emb=pretrained_emb,
)
self.embedding_dim = embedding_dim
# validate kwargs for RETAIN layer
if "feature_size" in kwargs:
raise ValueError("feature_size is determined by embedding_dim")
# the key of self.feat_tokenizers only contains the code based inputs
self.feat_tokenizers = {}
self.label_tokenizer = self.get_label_tokenizer()
# the key of self.embeddings only contains the code based inputs
self.embeddings = nn.ModuleDict()
# the key of self.linear_layers only contains the float/int based inputs
self.linear_layers = nn.ModuleDict()
# add feature RETAIN layers
for feature_key in self.feature_keys:
input_info = self.dataset.input_info[feature_key]
# sanity check
if input_info["type"] not in [str, float, int]:
raise ValueError(
"RETAIN only supports str code, float and int as input types"
)
elif (input_info["type"] == str) and (input_info["dim"] not in [2, 3]):
raise ValueError(
"RETAIN only supports 2-dim or 3-dim str code as input types"
)
elif (input_info["type"] in [float, int]) and (
input_info["dim"] not in [2, 3]
):
raise ValueError(
"RETAIN only supports 2-dim or 3-dim float and int as input types"
)
            # for code based input, the transform layer needs the type;
            # for float/int based input, it also needs the input dimension
self.add_feature_transform_layer(feature_key, input_info)
self.retain = nn.ModuleDict()
for feature_key in feature_keys:
self.retain[feature_key] = RETAINLayer(feature_size=embedding_dim, **kwargs)
output_size = self.get_output_size(self.label_tokenizer)
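        # final classifier over the concatenation of per-feature patient embeddings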
self.fc = nn.Linear(len(self.feature_keys) * self.embedding_dim, output_size)
    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
"""Forward propagation.
The label `kwargs[self.label_key]` is a list of labels for each patient.
Args:
**kwargs: keyword arguments for the model. The keys must contain
all the feature keys and the label key.
        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor representing the predicted probabilities.
                y_true: a tensor representing the true labels.
                logit: a tensor representing the raw logits.
                embed (optional): the patient embeddings, returned only when
                    ``embed=True`` is passed in kwargs.
        """
patient_emb = []
for feature_key in self.feature_keys:
input_info = self.dataset.input_info[feature_key]
dim_, type_ = input_info["dim"], input_info["type"]
# for case 1: [code1, code2, code3, ...]
if (dim_ == 2) and (type_ == str):
x = self.feat_tokenizers[feature_key].batch_encode_2d(
kwargs[feature_key]
)
# (patient, event)
x = torch.tensor(x, dtype=torch.long, device=self.device)
# (patient, event, embedding_dim)
x = self.embeddings[feature_key](x)
# (patient, event)
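                # padding tokens are embedded as all-zero vectors (via
                # padding_idx), so a nonzero sum marks a real event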
mask = torch.sum(x, dim=2) != 0
# for case 2: [[code1, code2], [code3, ...], ...]
elif (dim_ == 3) and (type_ == str):
x = self.feat_tokenizers[feature_key].batch_encode_3d(
kwargs[feature_key]
)
# (patient, visit, event)
x = torch.tensor(x, dtype=torch.long, device=self.device)
# (patient, visit, event, embedding_dim)
x = self.embeddings[feature_key](x)
# (patient, visit, embedding_dim)
x = torch.sum(x, dim=2)
# (patient, visit)
mask = torch.sum(x, dim=2) != 0
# for case 3: [[1.5, 2.0, 0.0], ...]
elif (dim_ == 2) and (type_ in [float, int]):
x, mask = self.padding2d(kwargs[feature_key])
# (patient, event, values)
x = torch.tensor(x, dtype=torch.float, device=self.device)
# (patient, event, embedding_dim)
x = self.linear_layers[feature_key](x)
# (patient, event)
mask = mask.bool().to(self.device)
# for case 4: [[[1.5, 2.0, 0.0], [1.8, 2.4, 6.0]], ...]
elif (dim_ == 3) and (type_ in [float, int]):
x, mask = self.padding3d(kwargs[feature_key])
# (patient, visit, event, values)
x = torch.tensor(x, dtype=torch.float, device=self.device)
                # (patient, visit, values)
                x = torch.sum(x, dim=2)
                # (patient, visit, embedding_dim)
                x = self.linear_layers[feature_key](x)
                # (patient, visit)
                mask = mask[:, :, 0]
mask = mask.bool().to(self.device)
else:
raise NotImplementedError
            # transform x to (patient, event, embedding_dim)
            if self.pretrained_emb is not None:
x = self.linear_layers[feature_key](x)
x = self.retain[feature_key](x, mask)
patient_emb.append(x)
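        # (patient, len(feature_keys) * embedding_dim)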
patient_emb = torch.cat(patient_emb, dim=1)
# (patient, label_size)
logits = self.fc(patient_emb)
# obtain y_true, loss, y_prob
y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
loss = self.get_loss_function()(logits, y_true)
y_prob = self.prepare_y_prob(logits)
results = {
"loss": loss,
"y_prob": y_prob,
"y_true": y_true,
"logit": logits,
}
if kwargs.get("embed", False):
results["embed"] = patient_emb
return results
if __name__ == "__main__":
from pyhealth.datasets import SampleEHRDataset
samples = [
{
"patient_id": "patient-0",
"visit_id": "visit-0",
# "single_vector": [1, 2, 3],
"list_codes": ["505800458", "50580045810", "50580045811"], # NDC
"list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
"list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]], # ATC-4
"list_list_vectors": [
[[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
[[7.7, 8.5, 9.4]],
],
"label": 1,
},
{
"patient_id": "patient-0",
"visit_id": "visit-1",
# "single_vector": [1, 5, 8],
"list_codes": [
"55154191800",
"551541928",
"55154192800",
"705182798",
"70518279800",
],
"list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
"list_list_codes": [["A04A", "B035", "C129"]],
"list_list_vectors": [
[[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
],
"label": 0,
},
]
# dataset
dataset = SampleEHRDataset(samples=samples, dataset_name="test")
# data loader
from pyhealth.datasets import get_dataloader
train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
# model
model = RETAIN(
dataset=dataset,
feature_keys=[
"list_codes",
"list_vectors",
"list_list_codes",
# "list_list_vectors",
],
label_key="label",
mode="binary",
)
# data batch
data_batch = next(iter(train_loader))
# try the model
ret = model(**data_batch)
print(ret)
# try loss backward
ret["loss"].backward()