"""GRASP model for health status representation learning.
This module implements the GRASP (Generic fRAmework for health Status
representation learning based on incorporating knowledge from Similar
Patients) model from Ma et al., AAAI 2021.
The model clusters patient representations via k-means, refines
cluster-level knowledge with a graph convolutional network, and blends
it back into individual patient embeddings through a learned gate.
"""
import copy
import math
import random
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from sklearn.neighbors import kneighbors_graph
from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.models.concare import ConCareLayer
from pyhealth.models.embedding import EmbeddingModel
from pyhealth.models.rnn import RNNLayer
def random_init(
    dataset: torch.Tensor, num_centers: int, device: torch.device
) -> torch.Tensor:
    """Pick initial k-means centers by sampling rows of ``dataset``.

    Row indices are drawn with ``random.sample`` (no repeats), so the result
    follows the state of Python's global ``random`` generator. If fewer
    points than requested centers exist, every point is selected once.

    Args:
        dataset: tensor of shape [num_points, dimension].
        num_centers: number of cluster centers requested.
        device: device the gather index is moved to before the lookup.

    Returns:
        Tensor of shape [min(num_centers, num_points), dimension] holding
        the sampled rows.
    """
    total = dataset.size(0)
    width = dataset.size(1)
    wanted = min(num_centers, total)

    picked = random.sample(range(total), k=wanted)
    row_ids = torch.tensor(np.array(picked), dtype=torch.long)

    # gather needs an index of the same rank as the data, so each row id is
    # repeated across the feature dimension.
    gather_index = row_ids.view(-1, 1).expand(-1, width).to(device=device)
    return torch.gather(dataset, 0, gather_index)
def compute_codes(
    dataset: torch.Tensor, centers: torch.Tensor
) -> torch.Tensor:
    """Assign each data point to its closest cluster center.

    Squared Euclidean distances are evaluated chunk by chunk as
    ``|x|^2 - 2 x.c + |c|^2`` so that a single distance matrix never holds
    more than roughly 5e8 entries. The assignment is not differentiable, so
    the whole computation runs without autograd tracking.

    Args:
        dataset: tensor of shape [num_points, dimension].
        centers: tensor of shape [num_centers, dimension].

    Returns:
        Long tensor of shape [num_points] with cluster assignments. It is
        allocated without an explicit device, independent of ``dataset``.
    """
    num_points = dataset.size(0)
    num_centers = centers.size(0)
    # 5e8 should vary depending on the free memory on the GPU.
    # Clamp to 1 so a huge number of centers cannot yield a zero range step.
    chunk_size = max(1, int(5e8 / num_centers))
    codes = torch.zeros(num_points, dtype=torch.long)

    # Only integer indices are produced, so recording an autograd graph for
    # the distance matrices would just waste memory when dataset carries grad.
    with torch.no_grad():
        centers_t = torch.transpose(centers, 0, 1)
        centers_norms = torch.sum(centers**2, dim=1).view(1, -1)
        for begin in range(0, num_points, chunk_size):
            end = min(begin + chunk_size, num_points)
            dataset_piece = dataset[begin:end, :]
            dataset_norms = torch.sum(dataset_piece**2, dim=1).view(-1, 1)
            distances = torch.mm(dataset_piece, centers_t)
            distances *= -2.0
            distances += dataset_norms
            distances += centers_norms
            _, min_ind = torch.min(distances, dim=1)
            codes[begin:end] = min_ind
    return codes
def update_centers(
    dataset: torch.Tensor,
    codes: torch.Tensor,
    num_centers: int,
    device: torch.device,
) -> torch.Tensor:
    """Average the points of every cluster to obtain new centers.

    Args:
        dataset: tensor of shape [num_points, dimension].
        codes: long tensor of shape [num_points] with cluster assignments.
        num_centers: number of clusters.
        device: device on which the centers are accumulated and returned.

    Returns:
        Tensor of shape [num_centers, dimension] with updated centers. A
        cluster with no assigned point keeps an all-zero center.
    """
    num_points, dimension = dataset.size(0), dataset.size(1)

    # Per-cluster sums: every point is added into the row named by its code.
    scatter_index = codes.view(-1, 1).expand(-1, dimension).to(device=device)
    sums = torch.zeros(num_centers, dimension, dtype=torch.float).to(device=device)
    sums.scatter_add_(0, scatter_index, dataset)

    # Per-cluster member counts, accumulated with the codes as given.
    members = torch.zeros(num_centers, dtype=torch.float)
    members.scatter_add_(0, codes, torch.ones(num_points, dtype=torch.float))

    # Empty clusters divide by one instead of zero, leaving their sum (zero)
    # untouched rather than producing NaN.
    divisor = torch.where(
        members > 0.5, members, torch.ones(num_centers, dtype=torch.float)
    )
    sums /= divisor.view(-1, 1).to(device=device)
    return sums
def cluster(
    dataset: torch.Tensor, num_centers: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run k-means on ``dataset`` starting from randomly sampled centers.

    Each pass recomputes the centers from the current assignments and then
    reassigns the points. The loop ends as soon as a pass leaves the
    assignments unchanged, or once the pass counter exceeds 1000.

    Args:
        dataset: tensor of shape [num_points, dimension].
        num_centers: number of clusters.
        device: target device for computation.

    Returns:
        Tuple of (centers, codes) where centers has shape
        [num_centers, dimension] and codes has shape [num_points]. The
        returned centers are the ones computed from the returned codes.
    """
    codes = compute_codes(dataset, random_init(dataset, num_centers, device))
    passes = 0
    while True:
        passes += 1
        centers = update_centers(dataset, codes, num_centers, device)
        reassigned = compute_codes(dataset, centers)
        # Exact equality of assignments is a strict convergence criterion;
        # the pass cap bounds the run time when it is never met.
        if torch.equal(codes, reassigned) or passes > 1000:
            return centers, codes
        codes = reassigned
class GraphConvolution(nn.Module):
    """Single-layer graph convolution (Kipf & Welling, ICLR 2017).

    Computes ``adj @ (x @ weight)`` and adds the bias when one is present.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if ``True``, adds a learnable bias. Default: ``True``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is uninitialised here; initialize_parameters() fills it.
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features).float())
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_features).float())
        self.initialize_parameters()

    def initialize_parameters(self):
        """Draw weight and bias uniformly from ``[-b, b]``, ``b = 1/sqrt(out_features)``."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, adj, x, device):
        """Apply the graph convolution.

        Args:
            adj: adjacency matrix multiplied from the left; cast to float.
            x: node features multiplied by the weight; cast to float.
            device: device the bias is moved to before it is added.

        Returns:
            ``adj @ (x @ weight)``, plus the bias if the layer has one.
        """
        projected = torch.mm(x.float(), self.weight.float())
        aggregated = torch.mm(adj.float(), projected.float())
        if self.bias is None:
            return aggregated
        return aggregated + self.bias.float().to(device=device)
def clones(module, N):
    """Produce N identical layers.

    Args:
        module: the module to replicate.
        N: number of independent deep copies to create.

    Returns:
        An ``nn.ModuleList`` holding ``N`` deep copies of ``module``.
    """
    replicas = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(replicas)
class GRASPLayer(nn.Module):
    """GRASP layer.

    Paper: Liantao Ma et al. GRASP: generic framework for health status
    representation learning based on incorporating knowledge from similar
    patients. AAAI 2021.

    This layer is used in the GRASP model. But it can also be used as a
    standalone layer. It encodes every sequence with a backbone, clusters
    the batch of encodings with k-means, passes the cluster centers through
    two graph convolutions, and mixes the resulting cluster knowledge back
    into each patient encoding through a learned gate.

    Args:
        input_dim: dynamic feature size.
        static_dim: static feature size, if 0, then no static feature is used.
        hidden_dim: hidden dimension of the GRASP layer, default 128.
        cluster_num: number of clusters, default 2. The cluster_num should be
            no more than the number of samples.
        dropout: dropout rate, default 0.5.
        block: the backbone model used in the GRASP layer
            ('ConCare', 'LSTM' or 'GRU'), default 'ConCare'.

    Raises:
        ValueError: if ``block`` is not 'ConCare', 'LSTM' or 'GRU'.

    Examples:
        >>> from pyhealth.models import GRASPLayer
        >>> x = torch.randn(3, 128, 64)  # [batch, seq_len, feature_size]
        >>> layer = GRASPLayer(64, cluster_num=2)
        >>> c = layer(x)
        >>> c.shape
        torch.Size([3, 128])
    """

    def __init__(
        self,
        input_dim: int,
        static_dim: int = 0,
        hidden_dim: int = 128,
        cluster_num: int = 2,
        dropout: float = 0.5,
        block: str = "ConCare",
    ):
        super(GRASPLayer, self).__init__()
        # Fail at construction time: an unknown block would otherwise leave
        # self.backbone undefined and only surface on the first forward pass.
        if block not in ("ConCare", "GRU", "LSTM"):
            raise ValueError(
                f"Unsupported block {block!r}; expected 'ConCare', 'GRU' or 'LSTM'."
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.cluster_num = cluster_num
        self.block = block

        if self.block == "ConCare":
            self.backbone = ConCareLayer(
                input_dim, static_dim, hidden_dim, hidden_dim, dropout=0
            )
        elif self.block == "GRU":
            self.backbone = RNNLayer(input_dim, hidden_dim, rnn_type="GRU", dropout=0)
        else:  # "LSTM"
            self.backbone = RNNLayer(input_dim, hidden_dim, rnn_type="LSTM", dropout=0)

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)

        # One scalar gate logit each for the cluster knowledge and the
        # patient's own encoding.
        self.weight1 = nn.Linear(self.hidden_dim, 1)
        self.weight2 = nn.Linear(self.hidden_dim, 1)

        # GraphConvolution already initialises itself on construction; the
        # extra calls are kept so the sequence of random draws (and with it
        # seeded initialisation) stays the same as before.
        self.GCN = GraphConvolution(self.hidden_dim, self.hidden_dim, bias=True)
        self.GCN.initialize_parameters()
        self.GCN_2 = GraphConvolution(self.hidden_dim, self.hidden_dim, bias=True)
        self.GCN_2.initialize_parameters()

        # Never assigned anywhere else in this class; while it is None the
        # encoder uses an identity adjacency between clusters.
        self.A_mat = None

        # Not used by the forward pass in this class; kept so existing
        # checkpoints keep the same state_dict keys.
        self.bn = nn.BatchNorm1d(self.hidden_dim)

    def sample_gumbel(self, shape, eps=1e-20):
        """Draw Gumbel(0, 1) noise of the given shape.

        ``eps`` keeps both logarithms away from ``log(0)``.
        """
        U = torch.rand(shape)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature, device):
        """Return ``softmax((logits + gumbel_noise) / temperature)`` over the last dim."""
        y = logits + self.sample_gumbel(logits.size()).to(device=device)
        return torch.softmax(y / temperature, dim=-1)

    def gumbel_softmax(self, logits, temperature, device, hard=False):
        """Straight-through Gumbel-softmax.

        Args:
            logits: tensor of shape [*, n_class].
            temperature: softmax temperature applied after adding the noise.
            device: device the Gumbel noise is moved to.
            hard: if ``True``, return one-hot vectors whose gradient is that
                of the soft sample; otherwise return the soft sample.

        Returns:
            If ``hard`` is ``False``, the soft sample reshaped to
            [-1, cluster_num]. If ``hard`` is ``True``, a one-hot tensor with
            the same shape as ``logits``.
        """
        y = self.gumbel_softmax_sample(logits, temperature, device)
        if not hard:
            return y.view(-1, self.cluster_num)

        shape = y.size()
        _, ind = y.max(dim=-1)
        y_hard = torch.zeros_like(y).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(*shape)
        # Forward value is y_hard; the detached difference makes the
        # gradient flow through the soft sample y instead.
        y_hard = (y_hard - y).detach() + y
        return y_hard

    def grasp_encoder(
        self,
        input: torch.Tensor,
        static: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode patient sequences with backbone + cluster-aware GCN.

        Args:
            input: tensor of shape [batch_size, seq_len, input_dim].
            static: optional static features [batch_size, static_dim]. Only
                forwarded to the backbone when ``block`` is 'ConCare'.
            mask: optional mask [batch_size, seq_len].

        Returns:
            Tensor of shape [batch_size, hidden_dim].
        """
        # The ConCare backbone's first return value is used as the patient
        # encoding; for the RNN backbones it is the second one.
        if self.block == "ConCare":
            hidden_t, _ = self.backbone(input, mask=mask, static=static)
        else:
            _, hidden_t = self.backbone(input, mask)

        centers, codes = cluster(hidden_t, self.cluster_num, input.device)

        if self.A_mat is None:
            A_mat = np.eye(self.cluster_num)
        else:
            # With include_self=False a point can have at most
            # cluster_num - 1 neighbours, and kneighbors_graph rejects larger
            # requests, so the historical 20 is clamped.
            n_neighbors = min(20, self.cluster_num - 1)
            if n_neighbors < 1:
                # A single cluster has no neighbour to connect to; fall back
                # to the same identity adjacency as the default path.
                A_mat = np.eye(self.cluster_num)
            else:
                A_mat = kneighbors_graph(
                    np.array(centers.detach().cpu().numpy()),
                    n_neighbors,
                    mode="connectivity",
                    include_self=False,
                ).toarray()
        adj_mat = torch.tensor(A_mat).to(device=input.device)

        # Similarity of every patient to every center -> one-hot cluster pick.
        e = self.relu(torch.matmul(hidden_t, centers.transpose(0, 1)))  # b clu_num
        scores = self.gumbel_softmax(e, temperature=1, device=input.device, hard=True)

        # Refine the centers with two graph convolutions.
        h_prime = self.relu(self.GCN(adj_mat, centers, input.device))
        h_prime = self.relu(self.GCN_2(adj_mat, h_prime, input.device))

        # Refined representation of the cluster each patient selected.
        clu_appendix = torch.matmul(scores, h_prime)

        # Normalised gate: weight1 + weight2 == 1 for every patient.
        weight1 = torch.sigmoid(self.weight1(clu_appendix))
        weight2 = torch.sigmoid(self.weight2(hidden_t))
        weight1 = weight1 / (weight1 + weight2)
        weight2 = 1 - weight1

        return weight1 * clu_appendix + weight2 * hidden_t

    def forward(
        self,
        x: torch.Tensor,
        static: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward propagation.

        Args:
            x: a tensor of shape [batch size, sequence len, input_dim].
            static: a tensor of shape [batch size, static_dim].
            mask: an optional tensor of shape [batch size, sequence len], where
                1 indicates valid and 0 indicates invalid.

        Returns:
            output: a tensor of shape [batch size, hidden_dim] representing
                the patient embedding, after dropout.
        """
        # The backbones are built with dropout=0, so dropout is applied here
        # on the final representation.
        out = self.grasp_encoder(x, static, mask)
        out = self.dropout(out)
        return out
class GRASP(BaseModel):
    """GRASP model.

    Paper: Liantao Ma et al. GRASP: generic framework for health status
    representation learning based on incorporating knowledge from
    similar patients. AAAI 2021.

    One GRASP layer is created per dynamic feature. Their outputs are
    concatenated and passed through a single fully connected layer to
    produce the prediction.

    Each GRASP layer encodes patient sequences with a backbone (ConCare,
    GRU, or LSTM), clusters patients via k-means, refines the cluster
    representations with a 2-layer GCN, and blends the cluster-level
    knowledge back into the individual patient representations through a
    learned gate.

    Args:
        dataset (SampleDataset): the dataset to train the model. It is used
            to query certain information such as the set of all tokens.
        static_key (str): optional key in samples to use as static features,
            e.g. "demographics". Only numerical static features are supported.
            Default is None.
        embedding_dim (int): the embedding dimension. Default is 128.
        hidden_dim (int): the hidden dimension. Default is 128.
        **kwargs: other parameters for the GRASPLayer
            (e.g., cluster_num, dropout, block).

    Examples:
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "conditions": ["cond-33", "cond-86", "cond-80"],
        ...         "procedures": ["proc-12", "proc-45"],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-1",
        ...         "visit_id": "visit-1",
        ...         "conditions": ["cond-12", "cond-52"],
        ...         "procedures": ["proc-23"],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={
        ...         "conditions": "sequence",
        ...         "procedures": "sequence",
        ...     },
        ...     output_schema={"label": "binary"},
        ...     dataset_name="test",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>>
        >>> model = GRASP(
        ...     dataset=dataset,
        ...     embedding_dim=128,
        ...     hidden_dim=64,
        ...     cluster_num=2,
        ... )
        >>>
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(...),
            'y_prob': tensor(...),
            'y_true': tensor(...),
            'logit': tensor(...)
        }
    """

    def __init__(
        self,
        dataset: SampleDataset,
        static_key: Optional[str] = None,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        **kwargs
    ):
        super().__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.static_key = static_key

        # The layer's input size is fixed by the embedding, so callers may
        # not override it through kwargs.
        if "input_dim" in kwargs:
            raise ValueError("input_dim is determined by embedding_dim")

        assert len(self.label_keys) == 1, "Only one label key is supported"
        self.label_key = self.label_keys[0]
        self.mode = self.dataset.output_schema[self.label_key]

        self.embedding_model = EmbeddingModel(dataset, embedding_dim)

        # Infer the static feature width from the first sample: last-axis
        # size for tensors (1 for 0-d tensors), length for lists/tuples,
        # 1 for any other value. It stays 0 without a usable static key.
        inferred_static_dim = 0
        if self.static_key is not None:
            probe = dataset[0]
            if self.static_key in probe:
                probe_value = probe[self.static_key]
                if isinstance(probe_value, torch.Tensor):
                    if probe_value.dim() > 0:
                        inferred_static_dim = probe_value.shape[-1]
                    else:
                        inferred_static_dim = 1
                elif isinstance(probe_value, (list, tuple)):
                    inferred_static_dim = len(probe_value)
                else:
                    inferred_static_dim = 1
        self.static_dim = inferred_static_dim

        # Every feature except the static one gets its own GRASP layer.
        self.dynamic_feature_keys = [
            key for key in self.feature_keys if key != self.static_key
        ]

        self.grasp = nn.ModuleDict()
        for key in self.dynamic_feature_keys:
            self.grasp[key] = GRASPLayer(
                input_dim=embedding_dim,
                static_dim=self.static_dim,
                hidden_dim=hidden_dim,
                **kwargs,
            )

        output_size = self.get_output_size()
        concat_width = len(self.dynamic_feature_keys) * self.hidden_dim
        self.fc = nn.Linear(concat_width, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        The label `kwargs[self.label_key]` is a list of labels for each
        patient.

        Args:
            **kwargs: keyword arguments for the model. The keys must contain
                all the feature keys and the label key.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the following keys:
                - loss: a scalar tensor representing the loss.
                - y_prob: a tensor representing the predicted probabilities.
                - y_true: a tensor representing the true labels.
                - logit: a tensor representing the logits.
                - embed (optional): a tensor representing the patient
                    embeddings if requested.
        """
        embedded = self.embedding_model(kwargs)

        # Static features are optional; list/tuple inputs become a float
        # tensor before being moved to the model's device.
        static = None
        if self.static_key is not None and self.static_key in kwargs:
            static = kwargs[self.static_key]
            if isinstance(static, (list, tuple)):
                static = torch.tensor(static, dtype=torch.float)
            static = static.to(self.device)

        per_feature = []
        for key in self.dynamic_feature_keys:
            feature_emb = embedded[key]
            # A time step counts as valid when its embedding is not all zero.
            valid = (torch.abs(feature_emb).sum(dim=-1) != 0).int()
            per_feature.append(
                self.grasp[key](feature_emb, static=static, mask=valid)
            )

        patient_emb = torch.cat(per_feature, dim=1)
        # (patient, label_size)
        logits = self.fc(patient_emb)

        y_true = kwargs[self.label_key].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = patient_emb
        return results
if __name__ == "__main__":
    # Smoke test: build a two-sample dataset, run one forward pass through
    # GRASP and back-propagate the loss.
    from pyhealth.datasets import create_sample_dataset, get_dataloader

    first_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "conditions": ["cond-33", "cond-86", "cond-80"],
        "procedures": ["proc-12", "proc-45"],
        "label": 1,
    }
    second_visit = {
        "patient_id": "patient-1",
        "visit_id": "visit-1",
        "conditions": ["cond-12", "cond-52"],
        "procedures": ["proc-23"],
        "label": 0,
    }
    samples = [first_visit, second_visit]

    input_schema = {
        "conditions": "sequence",
        "procedures": "sequence",
    }
    dataset = create_sample_dataset(
        samples=samples,
        input_schema=input_schema,
        output_schema={"label": "binary"},
        dataset_name="test",
    )
    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    model = GRASP(
        dataset=dataset,
        embedding_dim=32,
        hidden_dim=32,
        cluster_num=2,
    )

    data_batch = next(iter(train_loader))
    ret = model(**data_batch)
    print(ret)

    # Confirms gradients flow through the clustering / GCN path.
    ret["loss"].backward()