pyhealth.models.VAE
The VAE model (treated as a regression task: the model reconstructs its input images).
- class pyhealth.models.VAE(dataset, feature_keys, label_key, input_channel, input_size, mode, hidden_dim=128, **kwargs)
Bases: BaseModel
VAE model (takes 128x128, 64x64, or 32x32 images).
Kingma, Diederik P., and Max Welling. “Auto-Encoding Variational Bayes.”
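For reference, the cited paper trains a VAE by maximizing the evidence lower bound (ELBO). With a diagonal Gaussian encoder q_φ(z | x) that outputs a mean μ and standard deviation σ, and a standard normal prior p(z), the sample is drawn via the reparameterization trick and the KL term has a closed form (J is the latent dimension). This is the standard formulation from the paper, not necessarily the exact loss this class computes:

```latex
\begin{aligned}
\mathcal{L}(\theta, \phi; x) &= \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right) \\
z &= \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\right) &= -\tfrac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\end{aligned}
```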
Note
We use CNN models as the encoder and decoder layers for now.
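Between the encoder and the decoder, a VAE turns the encoder's mean and log-variance into a latent sample and adds a KL penalty to the reconstruction loss. The plain-Python sketch below shows that middle step; it is not PyHealth's code. The function names are hypothetical, the encoder outputs are made-up numbers standing in for the CNN encoder, and the loss (mean squared reconstruction error plus the closed-form KL term) is assumed here as one common choice for a Gaussian decoder.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], log_var: Sequence[float],
                   rng: random.Random) -> List[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(log_var / 2).

    Given eps, z is a deterministic function of mu and log_var, which is what
    lets gradients flow back into the encoder in a real (autograd) model.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), with log_var = log(sigma^2)."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, log_var))


def vae_loss(x: Sequence[float], x_rec: Sequence[float],
             mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Mean squared reconstruction error plus the KL term (an assumed loss choice)."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_rec)) / len(x)
    return mse + kl_to_standard_normal(mu, log_var)


if __name__ == "__main__":
    rng = random.Random(0)
    mu, log_var = [0.5, -0.5], [0.0, 0.0]  # pretend encoder outputs
    z = reparameterize(mu, log_var, rng)   # latent sample fed to the decoder
    print("z =", z)
    print("KL =", kl_to_standard_normal(mu, log_var))  # prints KL = 0.25
```

Passing an explicit `random.Random` keeps the sampling reproducible, which makes the sketch easy to check by hand.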
- Parameters:
dataset (BaseSignalDataset) – the dataset to train the model. It is used to query certain information about the samples.
feature_keys (List[str]) – list of keys in samples to use as features, e.g., the key that holds the input images.
label_key (str) – key in samples to use as label.
input_channel (int) – number of channels in each input image (e.g., 1 for grayscale, 3 for RGB).
input_size (int) – height and width of the square input images; one of 128, 64, or 32.
mode (str) – one of “binary”, “multiclass”, or “multilabel”.
hidden_dim (int) – the hidden dimension. Default is 128.
**kwargs – other keyword arguments.
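How `input_channel` and `input_size` relate to the data can be illustrated with a hypothetical helper (not part of PyHealth) that checks whether one image, given as nested lists of shape (channels, height, width), matches those two settings. It assumes the sizes listed in the class description (128, 64, 32) are the only supported ones.

```python
from typing import List, Sequence

# Taken from the class description: 128x128, 64x64, or 32x32 images.
SUPPORTED_SIZES = (32, 64, 128)

Image = Sequence[Sequence[Sequence[float]]]  # channels x height x width


def matches_input_spec(image: Image, input_channel: int, input_size: int) -> bool:
    """Return True if `image` has shape (input_channel, input_size, input_size).

    Hypothetical helper for illustration; it is not part of PyHealth.
    """
    if input_size not in SUPPORTED_SIZES:
        return False
    if len(image) != input_channel:
        return False
    # Every channel must be a square input_size x input_size grid.
    return all(
        len(channel) == input_size and all(len(row) == input_size for row in channel)
        for channel in image
    )


def blank_image(input_channel: int, input_size: int) -> List[List[List[float]]]:
    """Build an all-zero image of shape (input_channel, input_size, input_size)."""
    return [[[0.0] * input_size for _ in range(input_size)] for _ in range(input_channel)]


print(matches_input_spec(blank_image(3, 64), input_channel=3, input_size=64))  # True
print(matches_input_spec(blank_image(3, 64), input_channel=1, input_size=64))  # False
print(matches_input_spec(blank_image(1, 48), input_channel=1, input_size=48))  # False: unsupported size
```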
- forward(**kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
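The difference between calling the instance and calling `forward()` directly can be shown with a toy class that imitates how `torch.nn.Module.__call__` wraps `forward()` and then runs forward hooks. This is a simplified, hypothetical stand-in in plain Python (`ToyModule` and `Doubler` are illustrative names), not PyTorch's implementation; the hook signature `hook(module, inputs, output)` mirrors the one PyTorch uses for forward hooks.

```python
from typing import Any, Callable, List

Hook = Callable[[Any, tuple, Any], Any]  # hook(module, inputs, output)


class ToyModule:
    """Toy imitation of torch.nn.Module's forward-hook handling (simplified)."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        # Real PyTorch returns a removable handle; this toy version does not.
        self._forward_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError("subclasses must override forward()")

    def __call__(self, *args: Any) -> Any:
        output = self.forward(*args)
        # Hooks run only on this path; a hook may replace the output by returning a value.
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output


class Doubler(ToyModule):
    def forward(self, x: float) -> float:
        return 2 * x


calls: List[float] = []
model = Doubler()
model.register_forward_hook(lambda module, inputs, output: calls.append(output))

model(3.0)          # runs forward() and then the hook
model.forward(3.0)  # runs forward() only; the hook is silently skipped
print(calls)        # [6.0]
```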