import functools
from typing import Dict, List, Optional, Tuple
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyhealth.datasets import BaseSignalDataset
from pyhealth.models import BaseModel, ResBlock2D
class Flatten(nn.Module):
    """Collapse every non-batch dimension of a tensor into a single one."""

    def forward(self, input):
        """Flatten ``input`` to two dimensions, keeping dimension 0.

        Args:
            input: tensor whose first dimension (``input.size(0)``) is
                usually the batch size.

        Returns:
            Tensor of shape ``(input.size(0), nb_elements)``, where
            ``nb_elements`` is the product of all remaining dimensions.
        """
        batch_size = input.size(0)
        # ``reshape`` rather than ``view``: it yields the same view for
        # contiguous input, but also accepts non-contiguous tensors (e.g.
        # transposed or permuted ones), for which ``view`` raises.
        return input.reshape(batch_size, -1)  # (batch_size, nb_elements)
class GAN(nn.Module):
    """GAN model (takes 128x128, 64x64 or 32x32 images).

    Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David
    Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio.
    Generative adversarial nets.

    Note:
        We use CNN models for the discriminator and the generator for now.
        The discriminator is a stack of ``ResBlock2D`` layers followed by
        two linear layers and a sigmoid; the generator is a stack of
        ``ConvTranspose2d`` layers ending in a sigmoid, so generated pixel
        values lie in (0, 1).

    Args:
        input_channel: number of channels of the real / generated images.
        input_size: height and width of the (square) images. Must be one
            of 128, 64 or 32.
        hidden_dim: size of the noise vector fed to the generator and of
            the discriminator's hidden linear layer. Default is 128.
        **kwargs: accepted for interface compatibility; currently unused.

    Raises:
        ValueError: if ``input_size`` is not one of the supported sizes.

    Examples:
        >>> model = GAN(input_channel=3, input_size=128, hidden_dim=256)
        >>> fake_images = model.generate_fake(10, "cpu")
        >>> fake_images.shape
        torch.Size([10, 3, 128, 128])
        >>> model.discriminate(fake_images).shape
        torch.Size([10, 1])
    """

    # Image sizes for which a discriminator/generator pair is defined.
    SUPPORTED_INPUT_SIZES = (128, 64, 32)

    def __init__(
        self,
        input_channel: int,
        input_size: int,
        hidden_dim: int = 128,
        **kwargs,
    ):
        super().__init__()

        # Fail early: without this check an unsupported size would silently
        # produce a model with no ``discriminator`` / ``generator`` at all.
        if input_size not in self.SUPPORTED_INPUT_SIZES:
            raise ValueError(
                f"input_size must be one of {self.SUPPORTED_INPUT_SIZES}, "
                f"got {input_size}"
            )

        self.hidden_dim = hidden_dim

        # NOTE(review): the in-features of the first Linear layer in each
        # discriminator presumably rely on every ResBlock2D configured this
        # way shrinking the spatial size by 4x -- confirm against ResBlock2D.
        # The generator comments give the spatial size after each
        # ConvTranspose2d (no padding): out = (in - 1) * stride + kernel.
        if input_size == 128:
            self.discriminator = nn.Sequential(
                ResBlock2D(input_channel, 16, 2, True, True),
                ResBlock2D(16, 64, 2, True, True),
                ResBlock2D(64, 256, 2, True, True),
                Flatten(),
                nn.Linear(256 * 2 * 2, self.hidden_dim),
                nn.ReLU(),
                nn.Linear(self.hidden_dim, 1),
                nn.Sigmoid(),
            )
            # 1 -> 5 -> 13 -> 29 -> 62 -> 128
            self.generator = nn.Sequential(
                nn.ConvTranspose2d(self.hidden_dim, 256, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2),
                nn.Sigmoid(),
            )
        elif input_size == 64:
            self.discriminator = nn.Sequential(
                ResBlock2D(input_channel, 16, 2, True, True),
                ResBlock2D(16, 64, 2, True, True),
                ResBlock2D(64, 256, 2, True, True),
                Flatten(),
                nn.Linear(256, self.hidden_dim),
                nn.ReLU(),
                nn.Linear(self.hidden_dim, 1),
                nn.Sigmoid(),
            )
            # 1 -> 5 -> 13 -> 30 -> 64
            self.generator = nn.Sequential(
                nn.ConvTranspose2d(self.hidden_dim, 128, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2),
                nn.Sigmoid(),
            )
        else:  # input_size == 32 (guaranteed by the check above)
            self.discriminator = nn.Sequential(
                ResBlock2D(input_channel, 16, 2, True, True),
                ResBlock2D(16, 64, 2, True, True),
                Flatten(),
                nn.Linear(64 * 2 * 2, self.hidden_dim),
                nn.ReLU(),
                nn.Linear(self.hidden_dim, 1),
                nn.Sigmoid(),
            )
            # 1 -> 5 -> 14 -> 32
            self.generator = nn.Sequential(
                nn.ConvTranspose2d(self.hidden_dim, 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2),
                nn.Sigmoid(),
            )

    def discriminate(self, x: torch.Tensor) -> torch.Tensor:
        """Run the discriminator on a batch of images.

        Args:
            x: batch of images passed as-is to ``self.discriminator``.

        Returns:
            The discriminator output; its last layers are
            ``Linear(hidden_dim, 1)`` and ``Sigmoid``, i.e. one score in
            (0, 1) per sample.
        """
        y = self.discriminator(x)
        return y

    def sampling(self, n_samples: int, device) -> torch.Tensor:
        """Draw standard-normal noise to feed the generator.

        Args:
            n_samples: number of noise vectors to draw.
            device: device the noise tensor is moved to (anything accepted
                by ``Tensor.to``).

        Returns:
            Tensor of shape ``(n_samples, hidden_dim, 1, 1)``.
        """
        eps = torch.randn(n_samples, self.hidden_dim, 1, 1).to(device)
        return eps

    def generate_fake(self, n_samples: int, device) -> torch.Tensor:
        """Generate a batch of fake images from freshly sampled noise.

        Args:
            n_samples: number of images to generate.
            device: device on which the noise is placed; presumably the
                device the model lives on -- the caller must ensure this.

        Returns:
            The output of ``self.generator`` applied to the sampled noise.
        """
        eps = self.sampling(n_samples, device)
        fake_images = self.generator(eps)
        return fake_images
if __name__ == "__main__":
    # Simple smoke test: build a model for 128x128 RGB images, then print
    # the shapes produced by the generator and by the discriminator.
    gan = GAN(input_channel=3, input_size=128, hidden_dim=256)

    # test generation
    batch = 10
    target_device = "cpu"
    generated = gan.generate_fake(batch, target_device)
    print(generated.shape)

    # test discriminate
    scores = gan.discriminate(generated)
    print(scores.shape)