pyhealth.models.GAN

The GAN model. The PyHealth Trainer does not apply to GAN; refer to example/ChestXray-image-generation-GAN.ipynb for an example of using the GAN model.

class pyhealth.models.GAN(input_channel, input_size, hidden_dim=128, **kwargs)

Bases: Module

GAN model (takes 128x128, 64x64, or 32x32 images).

Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial nets. In Advances in Neural Information Processing Systems 27 (NIPS 2014).

Note

We use CNN models as the encoder and decoder layers for now.
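Presumably, in this GAN the encoder side corresponds to the discriminator and the decoder side to the generator. The accepted sizes (128, 64, 32) are all powers of two, which fits CNN stacks that halve or double the spatial resolution at each layer. The helper below is a generic sketch of that shape arithmetic using the standard Conv2d and ConvTranspose2d output-size formulas (dilation 1); the kernel, stride, and padding defaults are illustrative assumptions, not the layer settings PyHealth actually uses.

```python
from typing import List


def conv_out_size(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    # Conv2d output size with dilation 1: floor((size + 2*padding - kernel) / stride) + 1.
    # Defaults are illustrative, not PyHealth's settings.
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out_size(size: int, kernel: int = 4, stride: int = 2,
                            padding: int = 1, output_padding: int = 0) -> int:
    # ConvTranspose2d output size with dilation 1:
    # (size - 1)*stride - 2*padding + kernel + output_padding.
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def downsample_trace(size: int, n_layers: int) -> List[int]:
    # Spatial sizes through a hypothetical stack of identical stride-2 convolutions.
    sizes = [size]
    for _ in range(n_layers):
        sizes.append(conv_out_size(sizes[-1]))
    return sizes


def upsample_trace(size: int, n_layers: int) -> List[int]:
    # Spatial sizes through a hypothetical stack of identical stride-2 transposed convolutions.
    sizes = [size]
    for _ in range(n_layers):
        sizes.append(conv_transpose_out_size(sizes[-1]))
    return sizes


print(downsample_trace(128, 4))  # [128, 64, 32, 16, 8]
print(upsample_trace(8, 4))      # [8, 16, 32, 64, 128]
```

With these settings every layer exactly halves (or doubles) an even input size, so a power-of-two image maps cleanly onto a small feature map and back.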

Parameters:
  • input_channel – the number of channels in the input images (e.g., 1 for grayscale, 3 for RGB).

  • input_size – the height and width of the square input images; one of 128, 64, or 32.

  • hidden_dim (int) – the hidden dimension. Default is 128.

  • **kwargs – other keyword arguments.

Examples:
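Because the PyHealth Trainer does not apply, training a GAN means writing the alternating update loop yourself. The dependency-free sketch below shows that loop on a toy 1-D problem with hand-derived gradients: each step first updates the discriminator to separate real from fake samples, then updates the generator to fool the freshly updated discriminator. The linear generator, logistic discriminator, N(3, 1) data, and all hyperparameters are hypothetical stand-ins chosen for illustration; the sketch does not call pyhealth.models.GAN.

```python
import math
import random
from typing import Dict, List, Tuple


def sigmoid(s: float) -> float:
    # Numerically stable logistic function (no overflow for large |s|).
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def train_toy_gan(steps: int = 200, batch: int = 64, lr: float = 0.05,
                  seed: int = 0) -> Tuple[Dict[str, float], List[float], List[float]]:
    """Alternating GAN updates on 1-D data with hand-derived gradients.

    Hypothetical stand-ins, not PyHealth code:
      generator      G(z) = a * z + b,  z ~ N(0, 1)
      discriminator  D(x) = sigmoid(w * x + c)
      real data      x ~ N(3, 1)
    """
    rng = random.Random(seed)
    a, b = 1.0, 0.0  # generator parameters
    w, c = 0.0, 0.0  # discriminator parameters
    d_real_hist: List[float] = []
    d_fake_hist: List[float] = []
    for _ in range(steps):
        real = [rng.gauss(3.0, 1.0) for _ in range(batch)]
        z = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        fake = [a * zi + b for zi in z]

        # 1) Discriminator step on L_D = -[log D(real) + log(1 - D(fake))].
        #    d/ds of -log sigmoid(s) is sigmoid(s) - 1; of -log(1 - sigmoid(s)) is sigmoid(s).
        d_r = [sigmoid(w * x + c) for x in real]
        d_f = [sigmoid(w * x + c) for x in fake]
        grad_w = (sum((p - 1.0) * x for p, x in zip(d_r, real))
                  + sum(p * x for p, x in zip(d_f, fake)))
        grad_c = sum(p - 1.0 for p in d_r) + sum(d_f)
        w -= lr * grad_w / batch
        c -= lr * grad_c / batch

        # 2) Generator step on the non-saturating loss L_G = -log D(G(z)),
        #    scored by the freshly updated discriminator.
        d_g = [sigmoid(w * (a * zi + b) + c) for zi in z]
        grad_a = sum((p - 1.0) * w * zi for p, zi in zip(d_g, z))
        grad_b = sum((p - 1.0) * w for p in d_g)
        a -= lr * grad_a / batch
        b -= lr * grad_b / batch

        # Mean discriminator scores from before this step's updates, for monitoring.
        d_real_hist.append(sum(d_r) / batch)
        d_fake_hist.append(sum(d_f) / batch)
    return {"a": a, "b": b, "w": w, "c": c}, d_real_hist, d_fake_hist


params, d_real, d_fake = train_toy_gan()
print(params)
```

With the real model, generate_fake(n_samples, device) and discriminate(x) would presumably fill the roles of G and D, with gradients coming from autograd and a torch optimizer instead of the hand-written updates in this sketch.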

discriminate(x)
Return type:

Tensor
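Discriminator scores typically feed the two losses from Goodfellow et al.: the discriminator minimizes -[log D(x) + log(1 - D(G(z)))], and the generator minimizes the non-saturating loss -log D(G(z)) suggested in the same paper. The sketch below computes both on plain Python lists, assuming the scores are probabilities in (0, 1); if discriminate(x) returns raw logits instead, apply a sigmoid first. EPS is an illustrative guard against log(0), not a PyHealth constant.

```python
import math
from typing import Sequence

EPS = 1e-12  # illustrative guard against log(0) when a score saturates at 0 or 1


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Batch mean of -[log D(x) + log(1 - D(G(z)))], given probabilities."""
    real_term = sum(-math.log(max(p, EPS)) for p in d_real) / len(d_real)
    fake_term = sum(-math.log(max(1.0 - p, EPS)) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake: Sequence[float]) -> float:
    """Non-saturating generator loss: batch mean of -log D(G(z))."""
    return sum(-math.log(max(p, EPS)) for p in d_fake) / len(d_fake)


# An uninformed discriminator that outputs 0.5 everywhere:
print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))  # 2*log(2), about 1.386
print(generator_loss([0.5, 0.5]))                   # log(2), about 0.693
```

The non-saturating form gives the generator strong gradients early in training, when the discriminator confidently rejects fake samples and log(1 - D(G(z))) would be nearly flat.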

sampling(n_samples, device)
Return type:

Tensor

generate_fake(n_samples, device)
Return type:

Tensor
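Judging by their names and signatures, sampling(n_samples, device) draws latent noise and generate_fake(n_samples, device) turns that noise into fake images. The sketch below mirrors that composition with plain lists, assuming standard-normal latent codes of length hidden_dim. The real methods return tensors on device (possibly with extra singleton dimensions for a convolutional generator), and toy_generator is a hypothetical placeholder for the CNN decoder.

```python
import math
import random
from typing import Callable, List

Vector = List[float]


def sampling(n_samples: int, hidden_dim: int, rng: random.Random) -> List[Vector]:
    """Hypothetical stand-in: n_samples latent codes drawn from N(0, 1)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(hidden_dim)] for _ in range(n_samples)]


def generate_fake(n_samples: int, hidden_dim: int,
                  generator: Callable[[Vector], Vector],
                  rng: random.Random) -> List[Vector]:
    """Hypothetical stand-in: sample latent codes, then decode each one."""
    return [generator(z) for z in sampling(n_samples, hidden_dim, rng)]


def toy_generator(z: Vector) -> Vector:
    # Illustrative decoder only: squashes each latent value into (-1, 1),
    # the way a tanh output layer bounds generated pixel values.
    return [math.tanh(v) for v in z]


fakes = generate_fake(3, hidden_dim=5, generator=toy_generator, rng=random.Random(0))
print(len(fakes), len(fakes[0]))  # 3 5
```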

training: bool