pyhealth.models.GAN
The GAN model. Note that the PyHealth trainer does not apply to GAN; see example/ChestXray-image-generation-GAN.ipynb for an example of using the GAN model.
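Because the PyHealth trainer is not used, a GAN is usually trained with a hand-written loop that alternates one discriminator update and one generator update. The sketch below illustrates only that alternation, on a hypothetical 1-D toy problem in pure Python. The linear generator, the logistic discriminator, the name `train_toy_gan`, and all hyperparameters are assumptions for illustration and are not part of PyHealth. Gradients are derived by hand so that no framework is needed.

```python
import math
import random
from typing import List, Tuple


def sigmoid(u: float) -> float:
    # Numerically stable logistic function.
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def train_toy_gan(steps: int = 200, batch: int = 64, lr: float = 0.05,
                  real_mean: float = 3.0, seed: int = 0) -> Tuple[float, float, List[float]]:
    """Hypothetical toy GAN: alternate a D step and a G step each iteration.

    Generator:     g(z) = a * z + b,  z ~ N(0, 1)
    Discriminator: D(x) = sigmoid(w * x + c)
    Real data:     x ~ N(real_mean, 1)
    """
    rng = random.Random(seed)
    a, b = 1.0, 0.0  # generator parameters
    w, c = 0.0, 0.0  # discriminator parameters
    history: List[float] = []
    for _ in range(steps):
        # Discriminator step: minimise -[log D(x_real) + log(1 - D(x_fake))].
        gw = gc = 0.0
        for _ in range(batch):
            x_r = rng.gauss(real_mean, 1.0)
            x_f = a * rng.gauss(0.0, 1.0) + b
            s_r = sigmoid(w * x_r + c)
            s_f = sigmoid(w * x_f + c)
            gw += -(1.0 - s_r) * x_r + s_f * x_f
            gc += -(1.0 - s_r) + s_f
        w -= lr * gw / batch
        c -= lr * gc / batch

        # Generator step: non-saturating loss -log D(g(z)), using the updated D.
        ga = gb = 0.0
        for _ in range(batch):
            z = rng.gauss(0.0, 1.0)
            s_f = sigmoid(w * (a * z + b) + c)
            ga += -(1.0 - s_f) * w * z
            gb += -(1.0 - s_f) * w
        a -= lr * ga / batch
        b -= lr * gb / batch
        history.append(b)  # track the generator's offset over training
    return a, b, history
```

In a real image setting, the two players would be the CNNs inside `pyhealth.models.GAN`, and the hand-derived gradients would be replaced by autograd and an optimizer; the notebook mentioned above remains the authoritative example.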
- class pyhealth.models.GAN(input_channel, input_size, hidden_dim=128, **kwargs)
Bases: Module
GAN model (takes 128x128, 64x64, or 32x32 images).
Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets.
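For reference, the cited paper trains D and G on the minimax value function V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))], and suggests that in practice G maximise log D(G(z)) instead, to avoid weak gradients early in training. The helper names below are hypothetical; they simply evaluate these quantities from discriminator output probabilities.

```python
import math
from typing import Sequence


def value_fn(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    # D maximises V, so the quantity it minimises is -V.
    return -value_fn(d_real, d_fake)


def generator_loss_saturating(d_fake: Sequence[float]) -> float:
    # Original minimax form: G minimises E[log(1 - D(G(z)))].
    return sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)


def generator_loss_non_saturating(d_fake: Sequence[float]) -> float:
    # Practical alternative: G maximises E[log D(G(z))], i.e. minimises its negative.
    return -sum(math.log(p) for p in d_fake) / len(d_fake)
```

When D outputs 1/2 on every sample, which the paper shows is optimal once the generator matches the data distribution, `value_fn` evaluates to -log 4.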
Note
We use CNN models as the encoder and decoder layers for now.
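The supported sizes (128, 64, 32) are consistent with a DCGAN-style CNN in which each stride-2 convolution halves the spatial size. The sketch below assumes kernel 4, stride 2, padding 1 for every layer and a 4x4 final feature map; these values, and the helper names `conv_out` and `downsample_trace`, are assumptions for illustration, so check the PyHealth source for the actual layers.

```python
from typing import List


def conv_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    # Standard output-size formula for a 2-D convolution (no dilation).
    return (size + 2 * padding - kernel) // stride + 1


def downsample_trace(size: int, final: int = 4) -> List[int]:
    """Spatial sizes after repeated (assumed) stride-2 convs until <= final."""
    trace = [size]
    while trace[-1] > final:
        trace.append(conv_out(trace[-1]))
    return trace
```

Under these assumed settings, 128, 64, and 32 all shrink cleanly to 4x4, whereas a size such as 100 passes through odd intermediate sizes and ends at 3x3.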
- Parameters:
input_channel (int) – the number of channels in the input images (e.g., 1 for grayscale).
input_size (int) – the height and width of the square input images: 128, 64, or 32.
hidden_dim (int) – the hidden dimension. Default is 128.
**kwargs – other keyword arguments.
Examples: