pyhealth.models.GRASP

The separately callable GRASPLayer and the complete GRASP model.

class pyhealth.models.GRASPLayer(input_dim, static_dim=0, hidden_dim=128, cluster_num=2, dropout=0.5, block='ConCare')

Bases: Module

GRASP layer.

Paper: Liantao Ma et al. GRASP: generic framework for health status representation learning based on incorporating knowledge from similar patients. AAAI 2021.

This layer is used inside the GRASP model, but it can also be used as a standalone layer.

Parameters:
  • input_dim (int) – dynamic feature size.

  • static_dim (int) – static feature size; if 0, no static features are used. Default 0.

  • hidden_dim (int) – hidden dimension of the GRASP layer, default 128.

  • cluster_num (int) – number of clusters, default 2. cluster_num should be no more than the number of samples in a batch.

  • dropout (float) – dropout rate, default 0.5.

  • block (str) – the backbone model used in the GRASP layer (‘ConCare’, ‘LSTM’ or ‘GRU’), default ‘ConCare’.
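Judging from the note that the batch size bounds cluster_num, the clustering appears to run over the patients of a single batch. The sketch below is a minimal pure-Python k-means under that assumption, with centroids seeded by sampling distinct patients from the batch (as in standard k-means), which is exactly why cluster_num cannot exceed the batch size. kmeans_batch and squared_distance are hypothetical helpers for illustration, not PyHealth functions.

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_batch(
    hidden: List[Vector], cluster_num: int, iters: int = 20, seed: int = 0
) -> Tuple[List[Vector], List[int]]:
    """Cluster the patient representations of ONE batch into cluster_num groups.

    Centroids are initialised by sampling cluster_num distinct patients,
    so random.sample raises ValueError when cluster_num > len(hidden).
    """
    rng = random.Random(seed)
    centers = [list(v) for v in rng.sample(hidden, cluster_num)]
    codes = [0] * len(hidden)
    for _ in range(iters):
        # Assignment step: each patient goes to its nearest centroid.
        codes = [
            min(range(cluster_num), key=lambda k: squared_distance(h, centers[k]))
            for h in hidden
        ]
        # Update step: each centroid becomes the mean of its members
        # (an empty cluster keeps its previous centroid).
        for k in range(cluster_num):
            members = [h for h, c in zip(hidden, codes) if c == k]
            if members:
                centers[k] = [sum(col) / len(members) for col in zip(*members)]
    return centers, codes
```

For example, kmeans_batch([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]], cluster_num=2) puts the first two patients in one cluster and the third in the other, while asking for more clusters than there are patients in the batch raises ValueError.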

Examples

>>> import torch
>>> from pyhealth.models import GRASPLayer
>>> x = torch.randn(3, 128, 64)  # [batch size, sequence len, feature_size]
>>> layer = GRASPLayer(64, cluster_num=2)  # cluster_num must not exceed the batch size (3)
>>> c = layer(x)
>>> c.shape
torch.Size([3, 128])
sample_gumbel(shape, eps=1e-20)
gumbel_softmax_sample(logits, temperature, device)
gumbel_softmax(logits, temperature, device, hard=False)

Straight-through (ST) Gumbel-softmax. Input: logits of shape [*, n_class]. Returns: a flattened tensor of shape [*, n_class]; with hard=True each row is a one-hot vector.
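These three methods appear to implement the standard Gumbel-softmax trick: add Gumbel noise g = -log(-log(U)), with U uniform on (0, 1), to the logits and apply a temperature-scaled softmax. With hard=True the straight-through variant returns the one-hot argmax in the forward pass while backpropagating through the soft sample (in PyTorch this is usually written (y_hard - y).detach() + y). The pure-Python sketch below mirrors the method names but is not PyHealth's implementation; it handles a single row of logits and, having no autograd, models only the forward pass.

```python
import math
import random
from typing import List


def sample_gumbel(n: int, rng: random.Random, eps: float = 1e-20) -> List[float]:
    """Draw n samples of standard Gumbel noise: -log(-log(U + eps) + eps)."""
    return [-math.log(-math.log(rng.random() + eps) + eps) for _ in range(n)]


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax (subtracts the max before exponentiating)."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax(
    logits: List[float], temperature: float, rng: random.Random, hard: bool = False
) -> List[float]:
    """Relaxed categorical sample; if hard, the one-hot argmax of that sample."""
    noise = sample_gumbel(len(logits), rng)
    y = softmax([(logit + g) / temperature for logit, g in zip(logits, noise)])
    if not hard:
        return y
    # Forward pass of the straight-through estimator: keep only the argmax.
    idx = max(range(len(y)), key=y.__getitem__)
    return [1.0 if i == idx else 0.0 for i in range(len(y))]
```

Lower temperatures push the soft sample closer to one-hot; the hard variant always returns exactly one 1.0, placed where the soft sample (drawn with the same noise) peaks.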

grasp_encoder(input, static=None, mask=None)
forward(x, static=None, mask=None)

Forward propagation.

Parameters:
  • x (tensor) – a tensor of shape [batch size, sequence len, input_dim].

  • static (Optional[tensor]) – a tensor of shape [batch size, static_dim].

  • mask (Optional[tensor]) – an optional tensor of shape [batch size, sequence len], where 1 indicates valid and 0 indicates invalid.
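For variable-length sequences padded to a common sequence len, the mask can be derived from each sequence's true length. A minimal sketch, assuming right padding (real steps first, padding appended at the end); lengths_to_mask is a hypothetical helper, shown with plain lists so the 1/0 layout is visible.

```python
from typing import List


def lengths_to_mask(lengths: List[int], max_len: int) -> List[List[int]]:
    """Build a [batch size, sequence len] mask: 1 for real steps, 0 for padding.

    Assumes right padding, i.e. the first lengths[i] steps of row i are valid.
    """
    if any(n > max_len for n in lengths):
        raise ValueError("a sequence is longer than max_len")
    return [[1 if t < n else 0 for t in range(max_len)] for n in lengths]
```

For instance, lengths_to_mask([3, 1], 4) gives [[1, 1, 1, 0], [1, 0, 0, 0]]. With a tensor of lengths, the same mask is commonly written as (torch.arange(max_len)[None, :] < lengths[:, None]).int().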

Returns:

output: a tensor of shape [batch size, hidden_dim] representing the patient embedding.

Return type:

torch.Tensor

training: bool
class pyhealth.models.GRASP(dataset, feature_keys, label_key, mode, use_embedding, static_key=None, embedding_dim=128, hidden_dim=128, **kwargs)

Bases: BaseModel

GRASP model.

Paper: Liantao Ma et al. GRASP: generic framework for health status representation learning based on incorporating knowledge from similar patients. AAAI 2021.

Note

We use separate GRASP layers for different feature_keys. Currently, we automatically support different input formats:

  • code-based input (encoded later with an embedding table)

  • float/int-based value input

We follow the current convention for the GRASP model:
  • case 1. [code1, code2, code3, …]
    • we assume the codes are ordered; the model encodes each code into a vector and applies GRASP at the code level.

  • case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], …]
    • we assume the inner brackets are ordered; the model first uses the embedding table to encode each code into a vector, then applies mean pooling to get one vector per inner bracket, and finally applies GRASP at the bracket level.

  • case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], …]
    • this case only makes sense when every inner bracket has the same length; we assume each dimension has the same meaning and run GRASP directly at the inner-bracket level, similar to case 1 after the embedding table.

  • case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], …]
    • this case only makes sense when every inner bracket has the same length; we assume each dimension has the same meaning and run GRASP directly at the inner-bracket level, similar to case 2 after the embedding table.
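The four cases can be told apart by nesting depth and by the type of the innermost elements. A minimal sketch, assuming every element of a feature value has the same structure as its first element; input_case, mean_pool_brackets and the toy embedding table are hypothetical, not PyHealth's internal preprocessing. The second function illustrates the case 2 step: embed each code, then mean-pool within each inner bracket so that every bracket yields one vector.

```python
from typing import Dict, List, Sequence


def input_case(value: list) -> int:
    """Classify one feature value into cases 1-4 by inspecting its first element."""
    first = value[0]
    if isinstance(first, str):
        return 1  # [code1, code2, ...]
    if isinstance(first, list) and first:
        inner = first[0]
        if isinstance(inner, str):
            return 2  # [[code1, code2], [code3, ...], ...]
        if isinstance(inner, (int, float)):
            return 3  # [[1.5, 2.0, 0.0], ...]
        if isinstance(inner, list):
            return 4  # [[[1.5, 2.0, 0.0], ...], ...]
    raise ValueError("unsupported input format")


def mean_pool_brackets(
    brackets: List[List[str]], table: Dict[str, Sequence[float]]
) -> List[List[float]]:
    """Case 2: look up each code's vector, then average within each inner bracket."""
    pooled = []
    for codes in brackets:
        vectors = [table[c] for c in codes]
        pooled.append([sum(col) / len(vectors) for col in zip(*vectors)])
    return pooled
```

Applied to the sample fields in the example below, "list_codes" is case 1, "list_list_codes" case 2, "list_vectors" case 3 and "list_list_vectors" case 4.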

Parameters:
  • dataset (SampleEHRDataset) – the dataset to train the model. It is used to query certain information such as the set of all tokens.

  • feature_keys (List[str]) – list of keys in samples to use as features, e.g. [“conditions”, “procedures”].

  • label_key (str) – key in samples to use as label (e.g., “drugs”).

  • mode (str) – one of “binary”, “multiclass”, or “multilabel”.

  • static_key (Optional[str]) – the key in samples to use as static features, e.g. “demographics”. Default is None. Only numerical static features are supported.

  • use_embedding (List[bool]) – list of bools indicating whether to use embedding for each feature type, e.g. [True, False].

  • embedding_dim (int) – the embedding dimension. Default is 128.

  • hidden_dim (int) – the hidden dimension of the GRASP layer. Default is 128.

  • cluster_num (int) – the number of clusters, passed to the GRASP layer through **kwargs. Default is 10. Note that the batch size should be at least cluster_num.

  • **kwargs – other parameters for the GRASP layer.
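use_embedding is positional: its i-th flag applies to the i-th entry of feature_keys, so the two lists must be the same length, and cluster_num is bounded by the batch size. A minimal pre-flight check under those assumptions; check_grasp_config is a hypothetical helper, not part of PyHealth.

```python
from typing import Dict, List


def check_grasp_config(
    feature_keys: List[str],
    use_embedding: List[bool],
    batch_size: int,
    cluster_num: int,
) -> Dict[str, bool]:
    """Validate GRASP arguments; map each feature key to its use_embedding flag."""
    if len(use_embedding) != len(feature_keys):
        raise ValueError("use_embedding needs exactly one flag per feature key")
    if cluster_num > batch_size:
        raise ValueError("cluster_num must not exceed the batch size")
    return dict(zip(feature_keys, use_embedding))
```

With the keys from the example below, check_grasp_config(["list_codes", "list_vectors", "list_list_codes", "list_list_vectors"], [True, False, True, False], batch_size=2, cluster_num=2) shows that only the code-based features go through an embedding table.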

Examples

>>> from pyhealth.datasets import SampleEHRDataset
>>> samples = [
...         {
...             "patient_id": "patient-0",
...             "visit_id": "visit-0",
...             "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
...             "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
...             "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
...             "list_list_vectors": [
...                 [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
...                 [[7.7, 8.5, 9.4]],
...             ],
...             "demographic": [0.0, 2.0, 1.5],
...             "label": 1,
...         },
...         {
...             "patient_id": "patient-0",
...             "visit_id": "visit-1",
...             "list_codes": [
...                 "55154191800",
...                 "551541928",
...                 "55154192800",
...                 "705182798",
...                 "70518279800",
...             ],
...             "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
...             "list_list_codes": [["A04A", "B035", "C129"]],
...             "list_list_vectors": [
...                 [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
...             ],
...             "demographic": [0.0, 2.0, 1.5],
...             "label": 0,
...         },
...     ]
>>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
>>>
>>> from pyhealth.models import GRASP
>>> model = GRASP(
...         dataset=dataset,
...         feature_keys=[
...             "list_codes",
...             "list_vectors",
...             "list_list_codes",
...             "list_list_vectors",
...         ],
...         label_key="label",
...         static_key="demographic",
...         use_embedding=[True, False, True, False],
...         mode="binary",
...         cluster_num=2,  # forwarded to GRASPLayer; must not exceed batch_size below
...     )
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
    'loss': tensor(0.6896, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
    'y_prob': tensor([[0.4983],
                [0.4947]], grad_fn=<SigmoidBackward0>),
    'y_true': tensor([[1.],
                [0.]]),
    'logit': tensor([[-0.0070],
                [-0.0213]], grad_fn=<AddmmBackward0>)
}
>>>
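In binary mode the printed values are consistent with y_prob = sigmoid(logit) and with loss being the mean binary cross-entropy over the batch, as the BinaryCrossEntropyWithLogitsBackward0 grad_fn suggests. The sketch below recomputes y_prob and loss from the rounded logits and labels printed above; binary_outputs is a hypothetical helper, and agreement is only to about four decimals because the printed logits are rounded.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def binary_outputs(logits: List[float], labels: List[float]) -> Tuple[List[float], float]:
    """Recompute y_prob and the mean binary cross-entropy loss from logits."""
    probs = [sigmoid(z) for z in logits]
    losses = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p)) for p, y in zip(probs, labels)
    ]
    return probs, sum(losses) / len(losses)


probs, loss = binary_outputs([-0.0070, -0.0213], [1.0, 0.0])
print([round(p, 4) for p in probs], round(loss, 4))  # → [0.4983, 0.4947] 0.6896
```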
forward(**kwargs)

Forward propagation.

The label kwargs[self.label_key] is a list of labels for each patient.

Parameters:

**kwargs – keyword arguments for the model. The keys must contain all the feature keys and the label key.

Returns:

A dictionary with the following keys:
  • loss: a scalar tensor representing the final loss.

  • y_prob: a tensor representing the predicted probabilities.

  • y_true: a tensor representing the true labels.

  • logit: a tensor representing the raw (pre-activation) scores.

Return type:

Dict[str, torch.Tensor]

training: bool