pyhealth.models.GAMENet
The separate callable GAMENetLayer and the complete GAMENet model.
- class pyhealth.models.GAMENetLayer(hidden_size, ehr_adj, ddi_adj, dropout=0.5)
Bases: Module
GAMENet layer.
Paper: Junyuan Shang et al. GAMENet: Graph Augmented MEmory Networks for Recommending Medication Combination AAAI 2019.
This layer is used in the GAMENet model, but it can also be used as a standalone layer.
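- Parameters:
  - hidden_size (int) – the hidden feature size.
  - ehr_adj (tensor) – adjacency tensor of the EHR drug graph, of shape [num_drugs, num_drugs].
  - ddi_adj (tensor) – adjacency tensor of the drug-drug interaction (DDI) graph, of shape [num_drugs, num_drugs].
  - dropout (float) – the dropout rate. Default is 0.5.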
Examples
>>> import torch
>>> from pyhealth.models import GAMENetLayer
>>> queries = torch.randn(3, 5, 32)  # [patient, visit, hidden_size]
>>> prev_drugs = torch.randint(0, 2, (3, 4, 50)).float()  # [patient, visit - 1, num_drugs]
>>> curr_drugs = torch.randint(0, 2, (3, 50)).float()  # [patient, num_drugs]
>>> ehr_adj = torch.randint(0, 2, (50, 50)).float()
>>> ddi_adj = torch.randint(0, 2, (50, 50)).float()
>>> layer = GAMENetLayer(32, ehr_adj, ddi_adj)
>>> loss, y_prob = layer(queries, prev_drugs, curr_drugs)
>>> loss.shape
torch.Size([])
>>> y_prob.shape
torch.Size([3, 50])
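The example fills ehr_adj and ddi_adj with random 0/1 entries. In the GAMENet paper, the EHR graph links drugs that are co-prescribed in the same visit, and the DDI graph links drugs that are known to interact. The following is a minimal sketch of how both could be built as nested lists. It assumes drugs are already mapped to integer indices 0..num_drugs-1. build_ehr_adj and build_ddi_adj are hypothetical helpers, not pyhealth functions. Their output could be converted with torch.tensor(adj, dtype=torch.float) before being passed to GAMENetLayer.

```python
from itertools import combinations
from typing import Iterable, List, Tuple


def build_ehr_adj(visits: Iterable[List[int]], num_drugs: int) -> List[List[float]]:
    """Hypothetical helper: symmetric 0/1 matrix with a 1 wherever two
    different drugs were prescribed together in at least one visit."""
    adj = [[0.0] * num_drugs for _ in range(num_drugs)]
    for drug_ids in visits:
        # Deduplicate so a drug listed twice in a visit never pairs with itself.
        for i, j in combinations(sorted(set(drug_ids)), 2):
            adj[i][j] = adj[j][i] = 1.0
    return adj


def build_ddi_adj(pairs: Iterable[Tuple[int, int]], num_drugs: int) -> List[List[float]]:
    """Hypothetical helper: symmetric 0/1 matrix from known interacting pairs."""
    adj = [[0.0] * num_drugs for _ in range(num_drugs)]
    for i, j in pairs:
        if i != j:
            adj[i][j] = adj[j][i] = 1.0
    return adj


# Illustrative data: three visits over four drugs, one known interaction.
visits = [[0, 1], [1, 2, 2], [3]]
ehr_adj = build_ehr_adj(visits, num_drugs=4)
ddi_adj = build_ddi_adj([(0, 2)], num_drugs=4)
```

Keeping both matrices symmetric reflects that co-prescription and interaction are undirected relations between two drugs.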
- forward(queries, prev_drugs, curr_drugs, mask=None)
Forward propagation.
- Parameters:
  - queries (tensor) – query tensor of shape [patient, visit, hidden_size].
  - prev_drugs (tensor) – multihot tensor indicating drug usage in all previous visits, of shape [patient, visit - 1, num_drugs].
  - curr_drugs (tensor) – multihot tensor indicating drug usage in the current visit, of shape [patient, num_drugs].
  - mask (Optional[tensor]) – an optional mask tensor of shape [patient, visit], where 1 indicates valid visits and 0 indicates invalid visits.
- Returns:
  - loss: a scalar tensor representing the loss.
  - y_prob: a tensor of shape [patient, num_labels] representing the probability of each drug.
- Return type:
  Tuple[torch.Tensor, torch.Tensor], unpacked as (loss, y_prob).
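For one patient with visits 1..N, prev_drugs holds the multihot rows of visits 1..N-1 and curr_drugs holds the row of visit N. When patients in a batch have different numbers of visits, the mask marks real visits with 1. The sketch below builds these pieces with plain nested lists. to_multihot, split_history, and visit_mask are hypothetical helpers. Padding at the end of each history is an assumed convention, not documented pyhealth behavior.

```python
from typing import List, Tuple


def to_multihot(drug_ids: List[int], num_drugs: int) -> List[float]:
    """Hypothetical helper: 0/1 vector with a 1 at every prescribed drug index."""
    vec = [0.0] * num_drugs
    for d in drug_ids:
        vec[d] = 1.0
    return vec


def split_history(
    visits: List[List[int]], num_drugs: int
) -> Tuple[List[List[float]], List[float]]:
    """Split one patient's visits into (prev_drugs, curr_drugs):
    visits 1..N-1 as multihot rows, and visit N as a single multihot row."""
    rows = [to_multihot(v, num_drugs) for v in visits]
    return rows[:-1], rows[-1]


def visit_mask(num_visits: List[int], max_visits: int) -> List[List[int]]:
    """[patient, visit] mask: 1 for real visits, 0 for padding appended at the end."""
    return [[1] * n + [0] * (max_visits - n) for n in num_visits]


# One patient with three visits over four drugs.
prev_drugs, curr_drugs = split_history([[0, 2], [1], [2, 3]], num_drugs=4)
# A batch of two patients with 3 and 1 visits, padded to 3.
mask = visit_mask([3, 1], max_visits=3)
```

Wrapping these lists with torch.tensor(...) gives one [visit - 1, num_drugs] block per patient, which would then be padded and stacked to the documented [patient, visit - 1, num_drugs] shape.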
- class pyhealth.models.GAMENet(dataset, embedding_dim=128, hidden_dim=128, num_layers=1, dropout=0.5, **kwargs)
Bases: BaseModel
GAMENet model.
Paper: Junyuan Shang et al. GAMENet: Graph Augmented MEmory Networks for Recommending Medication Combination AAAI 2019.
Note
This model is only for medication prediction. It takes conditions and procedures as feature_keys and drugs as label_key, and it operates only at the visit level. Thus, we have disabled the feature_keys, label_key, and mode arguments.
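One way to picture visit-level operation is that each visit N of a patient becomes one sample. That sample holds the conditions and procedures of visits 1..N, the drug history of visits 1..N-1, and the drugs of visit N as the label. The sketch below illustrates this split with plain dictionaries. The per-visit record layout, the illustrative ICD-9 and ATC codes, and the helper name visits_to_samples are all assumptions. Only the output keys are taken from the arguments of GAMENet.forward.

```python
from typing import Dict, List

# Assumed per-visit record layout (illustrative, not a pyhealth structure).
Visit = Dict[str, List[str]]


def visits_to_samples(visits: List[Visit]) -> List[Dict[str, object]]:
    """Hypothetical helper: one sample per visit N, using visits 1..N as
    features and the drugs of visit N as the label."""
    samples = []
    for n in range(len(visits)):
        history = visits[: n + 1]
        samples.append({
            "conditions": [v["conditions"] for v in history],
            "procedures": [v["procedures"] for v in history],
            # The drug history stops one visit short of the label visit.
            "drugs_hist": [v["drugs"] for v in history[:-1]],
            "drugs": visits[n]["drugs"],
        })
    return samples


patient = [
    {"conditions": ["4019"], "procedures": ["3893"], "drugs": ["A10B"]},
    {"conditions": ["25000"], "procedures": ["9904"], "drugs": ["A10B", "C09A"]},
]
samples = visits_to_samples(patient)
```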
Note
This model only accepts ATC level 3 as medication codes.
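ATC codes are hierarchical, and code length encodes the level. Level 1 is the first character, level 2 the first three characters, and level 3 the first four. For example, A10BA02 (metformin) truncates to A10B at level 3. If prescriptions already carry full ATC codes, they could be reduced to level 3 as in the minimal sketch below. to_atc3 is a hypothetical helper. Mapping other vocabularies (such as NDC) to ATC is outside its scope.

```python
from typing import List


def to_atc3(codes: List[str]) -> List[str]:
    """Hypothetical helper: truncate ATC codes to level 3 (first four
    characters), dropping duplicates while keeping first-seen order."""
    result: List[str] = []
    for code in codes:
        atc3 = code.strip().upper()[:4]
        # Codes shorter than four characters cannot be valid level-3 codes.
        if len(atc3) == 4 and atc3 not in result:
            result.append(atc3)
    return result


# Metformin and glibenclamide both map to A10B; ramipril maps to C09A.
atc3 = to_atc3(["A10BA02", "A10BB01", "C09AA05"])  # ['A10B', 'C09A']
```

Deduplication matters here because several level-5 codes in one visit often collapse to the same level-3 code.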
- Parameters:
  - dataset (SampleEHRDataset) – the dataset to train the model. It is used to query certain information such as the set of all tokens.
  - embedding_dim (int) – the embedding dimension. Default is 128.
  - hidden_dim (int) – the hidden dimension. Default is 128.
  - num_layers (int) – the number of layers used in the RNN. Default is 1.
  - dropout (float) – the dropout rate. Default is 0.5.
  - **kwargs – other parameters for the GAMENet layer.
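GAMENet.forward receives every argument batched across patients, with patients as the outermost list level. A list of per-visit samples therefore has to be transposed into one list per key before the call. The sketch below shows that step. The helper name collate is hypothetical, and a data loader's collate function would normally play this role. The codes in the samples are illustrative. The resulting dictionary would be passed to the model as keyword arguments.

```python
from typing import Any, Dict, List


def collate(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical helper: turn a list of per-sample dicts into one dict
    of per-key lists, which adds the leading [patient] level."""
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


samples = [
    {"conditions": [["4019"], ["25000"]], "procedures": [["3893"], ["9904"]],
     "drugs_hist": [["A10B"]], "drugs": ["A10B", "C09A"]},
    {"conditions": [["42731"]], "procedures": [["8856"]],
     "drugs_hist": [], "drugs": ["B01A"]},
]
batch = collate(samples)
# batch["conditions"] is now [patient, visit, condition];
# batch["drugs"] is [patient, drug].
```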
- forward(conditions, procedures, drugs_hist, drugs, **kwargs)
Forward propagation.
- Parameters:
  - conditions (List[List[List[str]]]) – a nested list in three levels [patient, visit, condition].
  - procedures (List[List[List[str]]]) – a nested list in three levels [patient, visit, procedure].
  - drugs_hist (List[List[List[str]]]) – a nested list in three levels [patient, visit, drug], covering visits up to N-1.
  - drugs (List[List[str]]) – a nested list in two levels [patient, drug], for visit N.
- Returns:
  A dictionary with the following keys:
  - loss: a scalar tensor representing the loss.
  - y_prob: a tensor of shape [patient, num_labels] representing the probability of each drug.
  - y_true: a tensor of shape [patient, num_labels] representing the ground truth of each drug.
- Return type:
  Dict[str, torch.Tensor]
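The GAMENet paper evaluates recommended combinations with several metrics. Two of them are the Jaccard similarity to the ground-truth drug set and the DDI rate, which is the share of co-recommended drug pairs that interact. The sketch below computes both from the returned y_prob and y_true. It converts them to drug index sets with an assumed 0.5 cutoff and checks pairs against a 0/1 ddi_adj. All helper names are hypothetical, and pyhealth's own metric functions may differ in details such as averaging or empty-set handling.

```python
from itertools import combinations
from typing import List, Sequence


def predict_drug_sets(y_prob: List[List[float]], threshold: float = 0.5) -> List[List[int]]:
    """Per patient, the indices of drugs whose probability is >= threshold."""
    return [[d for d, p in enumerate(row) if p >= threshold] for row in y_prob]


def true_drug_sets(y_true: List[List[float]]) -> List[List[int]]:
    """Per patient, the indices where the multihot ground truth is nonzero."""
    return [[d for d, v in enumerate(row) if v > 0] for row in y_true]


def jaccard(pred: Sequence[int], true: Sequence[int]) -> float:
    """|pred & true| / |pred | true|; returns 0.0 when both sets are empty
    (a convention choice)."""
    p, t = set(pred), set(true)
    union = p | t
    return len(p & t) / len(union) if union else 0.0


def ddi_rate(pred_sets: List[Sequence[int]], ddi_adj: List[List[float]]) -> float:
    """Fraction of all co-recommended drug pairs marked as interacting in
    ddi_adj, pooled over every prediction."""
    interacting = total = 0
    for drugs in pred_sets:
        for i, j in combinations(sorted(set(drugs)), 2):
            total += 1
            if ddi_adj[i][j] or ddi_adj[j][i]:
                interacting += 1
    return interacting / total if total else 0.0


y_prob = [[0.9, 0.2, 0.5, 0.1],
          [0.3, 0.7, 0.4, 0.8]]
y_true = [[1, 1, 0, 0],
          [0, 1, 0, 1]]
ddi_adj = [[0, 0, 1, 0],
           [0, 0, 0, 0],
           [1, 0, 0, 0],
           [0, 0, 0, 0]]

pred_sets = predict_drug_sets(y_prob)  # [[0, 2], [1, 3]]
true_sets = true_drug_sets(y_true)     # [[0, 1], [1, 3]]
scores = [jaccard(p, t) for p, t in zip(pred_sets, true_sets)]
rate = ddi_rate(pred_sets, ddi_adj)    # 0.5
```

With real model output, calling y_prob.tolist() and y_true.tolist() on the returned tensors gives the nested lists these helpers expect.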