pyhealth.models.RNN
The RNNLayer, which can be called on its own, and the complete RNN model.
- class pyhealth.models.RNNLayer(input_size, hidden_size, rnn_type='GRU', num_layers=1, dropout=0.5, bidirectional=False)
Bases: Module
Recurrent neural network layer.
This layer wraps the PyTorch RNN layer with masking and dropout support. It is used in the RNN model, but it can also be used as a standalone layer.
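Masking tells the layer which time steps are padding when sequences in a batch have different lengths. The sketch below shows one way to build a right-padded batch and its matching mask (1 = valid, 0 = padding) in plain Python. `pad_and_mask` is a hypothetical helper, not a PyHealth function, and right padding is an assumption of this sketch.

```python
# Hypothetical helper (not PyHealth code): right-pad variable-length sequences
# of feature vectors and build the matching mask (1 = valid step, 0 = padding).
from typing import List, Tuple

Vector = List[float]


def pad_and_mask(seqs: List[List[Vector]], pad_value: float = 0.0
                 ) -> Tuple[List[List[Vector]], List[List[int]]]:
    """Return (padded batch [batch, max_len, input_size], mask [batch, max_len])."""
    if not seqs or any(len(s) == 0 for s in seqs):
        raise ValueError("every sequence needs at least one time step")
    max_len = max(len(s) for s in seqs)
    input_size = len(seqs[0][0])
    padded, mask = [], []
    for s in seqs:
        n_pad = max_len - len(s)
        # append padding vectors after the real time steps
        padded.append(s + [[pad_value] * input_size for _ in range(n_pad)])
        mask.append([1] * len(s) + [0] * n_pad)
    return padded, mask


x, mask = pad_and_mask([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]])
print(x)     # [[[1.0, 2.0], [0.0, 0.0]], [[3.0, 4.0], [5.0, 6.0]]]
print(mask)  # [[1, 0], [1, 1]]
```

With real tensors, the same nested lists could be turned into the `x` and `mask` arguments of `forward`.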
- Parameters:
input_size (int) – input feature size.
hidden_size (int) – hidden feature size.
rnn_type (str) – type of RNN, one of “RNN”, “LSTM”, “GRU”. Default is “GRU”.
num_layers (int) – number of recurrent layers. Default is 1.
dropout (float) – dropout rate. If non-zero, introduces a Dropout layer before each RNN layer. Default is 0.5.
bidirectional (bool) – whether to use bidirectional recurrent layers. If True, a fully-connected layer is applied to the concatenation of the forward and backward hidden states to reduce the dimension to hidden_size. Default is False.
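The bidirectional reduction described above can be pictured with a small, framework-free sketch: the forward and backward states (each of size hidden_size) are concatenated into a vector of size 2 × hidden_size, and a linear layer maps it back to hidden_size. `linear` and `reduce_bidirectional` are hypothetical helpers, and the weights are arbitrary values chosen so the arithmetic is easy to follow.

```python
# Hypothetical sketch: reduce concatenated forward/backward states
# (2 * hidden_size) back to hidden_size with a fully-connected layer.
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Compute weight @ x + bias, where weight has shape [out, in]."""
    if any(len(row) != len(x) for row in weight):
        raise ValueError("weight rows must match the input size")
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def reduce_bidirectional(forward_last: Vector, backward_last: Vector,
                         weight: Matrix, bias: Vector) -> Vector:
    concat = forward_last + backward_last  # list concatenation, length 2H
    return linear(concat, weight, bias)    # projected back to length H


# hidden_size = 2; weight has shape [hidden_size, 2 * hidden_size]
w = [[0.5, 0.0, 0.5, 0.0],
     [0.0, 0.5, 0.0, 0.5]]  # averages the two directions
print(reduce_bidirectional([1.0, 2.0], [3.0, 4.0], w, [0.0, 0.0]))  # [2.0, 3.0]
```

In a trained layer the projection weights are learned rather than fixed to an average.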
Examples
>>> import torch
>>> from pyhealth.models import RNNLayer
>>> input = torch.randn(3, 128, 5)  # [batch size, sequence len, input_size]
>>> layer = RNNLayer(5, 64)
>>> outputs, last_outputs = layer(input)
>>> outputs.shape
torch.Size([3, 128, 64])
>>> last_outputs.shape
torch.Size([3, 64])
- forward(x, mask=None)
Forward propagation.
- Parameters:
x (tensor) – a tensor of shape [batch size, sequence len, input size].
mask (Optional[tensor]) – an optional tensor of shape [batch size, sequence len], where 1 indicates valid and 0 indicates invalid.
- Returns:
- outputs: a tensor of shape [batch size, sequence len, hidden size], containing the output features for each time step.
- last_outputs: a tensor of shape [batch size, hidden size], containing the output features for the last time step.
- Return type:
tuple of (outputs, last_outputs)
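When a mask is given, a natural reading of "the last time step" is each sequence's last valid step rather than the last padded position. The plain-Python sketch below assumes that reading and right padding; `last_valid_outputs` is a hypothetical helper, not the layer's actual implementation.

```python
# Hypothetical sketch: pick each sequence's output at its last valid time step,
# assuming mask entries are 1 for valid, 0 for padding, with valid steps first.
from typing import List, Optional

Vector = List[float]


def last_valid_outputs(outputs: List[List[Vector]],
                       mask: Optional[List[List[int]]] = None) -> List[Vector]:
    """outputs: [batch size, sequence len, hidden size] as nested lists.
    mask: [batch size, sequence len], or None if every step is valid.
    Returns a [batch size, hidden size] list."""
    result = []
    for i, seq in enumerate(outputs):
        length = len(seq) if mask is None else sum(mask[i])
        if length == 0:
            raise ValueError(f"sequence {i} has no valid time steps")
        result.append(seq[length - 1])  # index of the last valid step
    return result


# batch of 2, sequence len 3, hidden size 2
outputs = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 1.1], [1.2, 1.3], [0.0, 0.0]],  # last step is padding
]
print(last_valid_outputs(outputs, [[1, 1, 1], [1, 1, 0]]))  # [[0.5, 0.6], [1.2, 1.3]]
print(last_valid_outputs(outputs))                          # [[0.5, 0.6], [0.0, 0.0]]
```

Without a mask, the padded position is returned for the second sequence, which is why passing a mask matters for variable-length batches.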
- class pyhealth.models.RNN(dataset, feature_keys, label_key, mode, pretrained_emb=None, embedding_dim=128, hidden_dim=128, **kwargs)
Bases: BaseModel
Recurrent neural network model.
This model applies a separate RNN layer for each feature, and then concatenates the final hidden states of each RNN layer. The concatenated hidden states are then fed into a fully connected layer to make predictions.
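The head described above can be sketched without a framework: one final hidden state per feature key is concatenated in feature_keys order, giving a vector of length len(feature_keys) × hidden_dim, and a fully connected layer maps it to the output. `concat_features` and `fc` are hypothetical helpers; the hidden states and weights are made-up values.

```python
# Hypothetical sketch of the model head: concatenate per-feature final hidden
# states in feature_keys order, then apply a fully connected layer.
from typing import Dict, List

Vector = List[float]


def concat_features(last_hidden: Dict[str, Vector], feature_keys: List[str]) -> Vector:
    patient_emb: Vector = []
    for key in feature_keys:  # a fixed order keeps the layout deterministic
        patient_emb.extend(last_hidden[key])
    return patient_emb


def fc(x: Vector, weight: List[Vector], bias: Vector) -> Vector:
    """weight has shape [output size, len(x)]."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


feature_keys = ["conditions", "procedures"]
last_hidden = {"conditions": [0.5, -1.0], "procedures": [2.0, 0.0]}  # hidden_dim = 2
patient_emb = concat_features(last_hidden, feature_keys)
print(patient_emb)  # [0.5, -1.0, 2.0, 0.0]
# a single output unit, as in binary mode; weights are arbitrary
print(fc(patient_emb, weight=[[1.0, 1.0, 1.0, 1.0]], bias=[0.25]))  # [1.75]
```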
Note
We use a separate RNN layer for each feature key. Currently, the following input formats are supported automatically:
- code-based input (codes are mapped to vectors through an embedding table)
- float/int value input
- We follow the current convention for the RNN model:
- case 1. [code1, code2, code3, …]
we assume the codes are ordered; the model encodes each code into a vector and applies the RNN at the code level.
- case 2. [[code1, code2]] or [[code1, code2], [code3, code4, code5], …]
we assume the inner brackets are ordered; the model first uses the embedding table to encode each code into a vector, then applies average (mean) pooling to get one vector per inner bracket, and finally applies the RNN at the bracket level.
- case 3. [[1.5, 2.0, 0.0]] or [[1.5, 2.0, 0.0], [8, 1.2, 4.5], …]
this case only makes sense when every inner bracket has the same length, and we assume each dimension has the same meaning. The RNN runs directly at the inner-bracket level, similar to case 1 after the embedding lookup.
- case 4. [[[1.5, 2.0, 0.0]]] or [[[1.5, 2.0, 0.0], [8, 1.2, 4.5]], …]
this case only makes sense when every innermost bracket has the same length, and we assume each dimension has the same meaning. The RNN runs directly at the inner-bracket level, similar to case 2 after the embedding lookup.
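The four cases differ only in nesting depth and in whether the innermost elements are codes (strings) or numbers, so they can be told apart mechanically. The sketch below shows one such check, plus the case 2 "embed, then mean-pool each inner bracket" step. `detect_case`, `encode_case2`, and the toy embedding table are hypothetical; a real embedding table is learned, and this is not PyHealth's internal code.

```python
# Hypothetical helpers (not PyHealth's API): infer a feature's input case and
# sketch the case 2 embed-then-mean-pool step.
from typing import Any, Dict, List, Tuple

Vector = List[float]


def depth_and_leaf(value: Any) -> Tuple[int, type]:
    depth = 0
    while isinstance(value, list):
        if not value:
            raise ValueError("cannot infer the format of an empty list")
        depth += 1
        value = value[0]  # assume siblings share the first element's structure
    return depth, type(value)


def detect_case(value: Any) -> int:
    depth, leaf = depth_and_leaf(value)
    if leaf is str:                # code-based input: needs an embedding table
        cases = {1: 1, 2: 2}
    elif leaf in (int, float):     # numeric value input
        cases = {2: 3, 3: 4}
    else:
        raise TypeError(f"unsupported leaf type: {leaf.__name__}")
    if depth not in cases:
        raise ValueError(f"unsupported nesting depth {depth} for {leaf.__name__}")
    return cases[depth]


def mean_pool(vectors: List[Vector]) -> Vector:
    if not vectors:
        raise ValueError("cannot pool an empty bracket")
    n = len(vectors)
    return [sum(dim) / n for dim in zip(*vectors)]


def encode_case2(brackets: List[List[str]], table: Dict[str, Vector]) -> List[Vector]:
    # one pooled vector per inner bracket: the sequence the RNN runs over
    return [mean_pool([table[code] for code in bracket]) for bracket in brackets]


print(detect_case(["505800458", "50580045810"]))             # 1
print(detect_case([["A05B", "A05C"], ["A11D"]]))              # 2
print(detect_case([[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]]))       # 3
print(detect_case([[[1.8, 2.25, 3.41]], [[7.7, 8.5, 9.4]]]))  # 4

# toy 2-d embedding table with made-up values
table = {"A05B": [1.0, 0.0], "A05C": [0.0, 1.0], "A06A": [2.0, 2.0],
         "A11D": [4.0, 0.0], "A11E": [0.0, 4.0]}
print(encode_case2([["A05B", "A05C", "A06A"], ["A11D", "A11E"]], table))
# [[1.0, 1.0], [2.0, 2.0]]
```

Following the note's "similar to case 2 after the embedding lookup", case 4 would apply the same pooling directly to the given vectors, skipping the table lookup.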
- Parameters:
dataset (SampleEHRDataset) – the dataset to train the model. It is used to query certain information such as the set of all tokens.
feature_keys (List[str]) – list of keys in samples to use as features, e.g. [“conditions”, “procedures”].
label_key (str) – key in samples to use as label (e.g., “drugs”).
mode (str) – one of “binary”, “multiclass”, or “multilabel”.
embedding_dim (int) – the embedding dimension. Default is 128.
hidden_dim (int) – the hidden dimension. Default is 128.
**kwargs – other parameters for the RNN layer.
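"The set of all tokens" for a code-based feature is simply every distinct code that appears anywhere in that feature across the samples; its size determines how many rows an embedding table needs. The sketch below collects it in plain Python. `all_tokens` and `flatten_codes` are hypothetical helpers, not the dataset's own methods, and numeric values are assumed not to count as tokens.

```python
# Hypothetical sketch: collect the set of all tokens for one feature key.
from typing import Any, Dict, List, Set


def flatten_codes(value: Any) -> List[str]:
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        out: List[str] = []
        for item in value:
            out.extend(flatten_codes(item))
        return out
    return []  # numeric values are not treated as tokens


def all_tokens(samples: List[Dict[str, Any]], feature_key: str) -> Set[str]:
    tokens: Set[str] = set()
    for sample in samples:
        tokens.update(flatten_codes(sample[feature_key]))
    return tokens


samples = [
    {"list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]]},
    {"list_list_codes": [["A04A", "B035", "C129"]]},
]
vocab = sorted(all_tokens(samples, "list_list_codes"))
print(len(vocab), vocab[:3])  # 8 ['A04A', 'A05B', 'A05C']
```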
Examples
>>> from pyhealth.datasets import SampleEHRDataset
>>> samples = [
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-0",
...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
...         "list_list_vectors": [
...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
...             [[7.7, 8.5, 9.4]],
...         ],
...         "label": 1,
...     },
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-1",
...         "list_codes": [
...             "55154191800",
...             "551541928",
...             "55154192800",
...             "705182798",
...             "70518279800",
...         ],
...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7], [4.5, 5.9, 1.7]],
...         "list_list_codes": [["A04A", "B035", "C129"]],
...         "list_list_vectors": [
...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6], [7.7, 8.4, 1.3], [7.7, 8.4, 1.3]],
...         ],
...         "label": 0,
...     },
... ]
>>> dataset = SampleEHRDataset(samples=samples, dataset_name="test")
>>>
>>> from pyhealth.models import RNN
>>> model = RNN(
...     dataset=dataset,
...     feature_keys=[
...         "list_codes",
...         "list_vectors",
...         "list_list_codes",
...         "list_list_vectors",
...     ],
...     label_key="label",
...     mode="binary",
... )
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
    'loss': tensor(0.8056, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
    'y_prob': tensor([[0.5906], [0.6620]], grad_fn=<SigmoidBackward0>),
    'y_true': tensor([[1.], [0.]]),
    'logit': tensor([[0.3666], [0.6721]], grad_fn=<AddmmBackward0>)
}
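The printed values in the example are internally consistent and can be checked by hand. Consistent with the grad_fn names in the output, the sketch below assumes that in binary mode y_prob = sigmoid(logit) and loss is the mean binary cross-entropy computed from the logits. `sigmoid` and `bce_with_logits` are hypothetical plain-Python helpers; the exact numbers in a real run depend on random initialization and shuffling.

```python
# Reproduce the example's binary-mode numbers from its logits, assuming
# y_prob = sigmoid(logit) and loss = mean binary cross-entropy on logits.
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logits: List[float], targets: List[float]) -> float:
    # log(1 + exp(-x)) for y = 1 and log(1 + exp(x)) for y = 0, averaged
    losses = [
        y * math.log1p(math.exp(-x)) + (1.0 - y) * math.log1p(math.exp(x))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


logits = [0.3666, 0.6721]
y_true = [1.0, 0.0]
print([round(sigmoid(x), 4) for x in logits])     # [0.5906, 0.662]
print(round(bce_with_logits(logits, y_true), 4))  # 0.8056
```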
- forward(**kwargs)
Forward propagation.
The label kwargs[self.label_key] is a list of labels for each patient.
- Parameters:
**kwargs – keyword arguments for the model. The keys must contain all the feature keys and the label key.
- Returns:
A dictionary with the following keys:
- loss: a scalar tensor representing the loss.
- y_prob: a tensor representing the predicted probabilities.
- y_true: a tensor representing the true labels.
- logit: a tensor of raw model outputs, before the activation that produces y_prob.
- Return type:
dict
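How y_prob relates to the logits depends on mode. The sketch below assumes the common convention of a sigmoid for "binary" and "multilabel" (the example's SigmoidBackward0 confirms this for binary) and a softmax for "multiclass"; the multiclass and multilabel choices are assumptions here, and `logits_to_prob` is a hypothetical helper.

```python
# Hypothetical sketch: turn logits into y_prob for each mode, assuming sigmoid
# for "binary"/"multilabel" and softmax for "multiclass".
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_prob(logits: List[List[float]], mode: str) -> List[List[float]]:
    if mode in ("binary", "multilabel"):
        return [[1.0 / (1.0 + math.exp(-v)) for v in row] for row in logits]
    if mode == "multiclass":
        return [softmax(row) for row in logits]
    raise ValueError(f"unknown mode: {mode}")


print(logits_to_prob([[0.0]], "binary"))                    # [[0.5]]
print(logits_to_prob([[1.0, 1.0, 1.0, 1.0]], "multiclass"))  # [[0.25, 0.25, 0.25, 0.25]]
```

Under this convention, multilabel probabilities are independent per label, while multiclass probabilities in each row sum to 1.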