"""Source code for pyhealth.models.vae."""

import functools
from typing import Dict, List, Optional, Tuple
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.datasets import BaseSignalDataset
from pyhealth.models import BaseModel, ResBlock2D


class VAE(BaseModel):
    """VAE model (takes 128x128, 64x64, or 32x32 images).

    Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes."

    Note:
        We use CNN models as the encoder and decoder layers for now.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features,
            e.g. ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        input_channel: number of channels of the input image.
        input_size: spatial size of the (square) input image;
            must be one of 128, 64, or 32.
        mode: one of "binary", "multiclass", or "multilabel".
        hidden_dim: dimension of the latent code. Default is 128.
        **kwargs: other parameters (currently unused by this class).
    """

    def __init__(
        self,
        dataset: BaseSignalDataset,
        feature_keys: List[str],
        label_key: str,
        input_channel: int,
        input_size: int,
        mode: str,
        hidden_dim: int = 128,
        **kwargs,
    ):
        super(VAE, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.hidden_dim = hidden_dim

        # Each ResBlock2D below (stride-2 conv + pooling) shrinks the spatial
        # size by 4x, which fixes the flattened feature size fed to mu/log_std2:
        #   128 -> 32 -> 8 -> 2   (256 channels, 256*2*2 features)
        #    64 -> 16 -> 4 -> 1   (256 channels, 256*1*1 features)
        #    32 ->  8 -> 2        ( 64 channels,  64*2*2 features)
        # The matching decoder upsamples the (hidden_dim, 1, 1) latent back to
        # the original size via ConvTranspose2d: out = (in - 1) * stride + kernel.
        if input_size == 128:
            self.encoder1 = nn.Sequential(
                ResBlock2D(input_channel, 16, 2, True, True),
                ResBlock2D(16, 64, 2, True, True),
                ResBlock2D(64, 256, 2, True, True),
            )
            self.mu = nn.Linear(256 * 2 * 2, self.hidden_dim)  # for mu
            self.log_std2 = nn.Linear(256 * 2 * 2, self.hidden_dim)  # for log (sigma^2)
            self.decoder1 = nn.Sequential(
                nn.ConvTranspose2d(self.hidden_dim, 256, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2),
                nn.Sigmoid(),  # pixel intensities in [0, 1]
            )
        elif input_size == 64:
            self.encoder1 = nn.Sequential(
                ResBlock2D(input_channel, 16, 2, True, True),
                ResBlock2D(16, 64, 2, True, True),
                ResBlock2D(64, 256, 2, True, True),
            )
            self.mu = nn.Linear(256, self.hidden_dim)  # for mu
            self.log_std2 = nn.Linear(256, self.hidden_dim)  # for log (sigma^2)
            self.decoder1 = nn.Sequential(
                nn.ConvTranspose2d(self.hidden_dim, 128, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2),
                nn.Sigmoid(),
            )
        elif input_size == 32:
            self.encoder1 = nn.Sequential(
                ResBlock2D(input_channel, 16, 2, True, True),
                ResBlock2D(16, 64, 2, True, True),
            )
            self.mu = nn.Linear(64 * 2 * 2, self.hidden_dim)  # for mu
            self.log_std2 = nn.Linear(64 * 2 * 2, self.hidden_dim)  # for log (sigma^2)
            self.decoder1 = nn.Sequential(
                nn.ConvTranspose2d(self.hidden_dim, 64, kernel_size=5, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
                nn.ReLU(),
                nn.ConvTranspose2d(32, input_channel, kernel_size=6, stride=2),
                nn.Sigmoid(),
            )
        else:
            # Fail fast: previously an unsupported size silently skipped layer
            # construction and crashed later with an opaque AttributeError.
            raise ValueError(
                f"input_size must be one of 128, 64, or 32, got {input_size}"
            )

    def encoder(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images into the latent Gaussian parameters.

        Args:
            x: image batch of shape (batch, channel, height, width).

        Returns:
            Tuple (mu, std) of the latent distribution, each of shape
            (batch, hidden_dim).
        """
        h = self.encoder1(x)
        batch_size = h.shape[0]
        h = h.view(batch_size, -1)
        # exp(0.5 * log(sigma^2)) == sigma; equivalent to sqrt(exp(.)) but
        # avoids the intermediate exp overflow for large log-variance logits.
        return self.mu(h), torch.exp(0.5 * self.log_std2(h))

    def sampling(self, mu, std) -> torch.Tensor:
        """Reparameterization trick: z = mu + eps * std with eps ~ N(0, I)."""
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, z) -> torch.Tensor:
        """Decode latent codes of shape (batch, hidden_dim, 1, 1) to images."""
        x_hat = self.decoder1(z)
        return x_hat

    @staticmethod
    def loss_function(y, x, mu, std):
        """Standard VAE objective: reconstruction BCE + KL to N(0, I).

        Args:
            y: reconstructed batch (values in [0, 1], from the Sigmoid head).
            x: target batch (must also lie in [0, 1] for BCE).
            mu: latent means.
            std: latent standard deviations.
        """
        ERR = F.binary_cross_entropy(y, x, reduction='sum')
        # Closed-form KL( N(mu, std^2) || N(0, I) ).
        KLD = -0.5 * torch.sum(1 + torch.log(std**2) - mu**2 - std**2)
        return ERR + KLD

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Run one VAE pass and return loss, reconstruction, and target.

        Returns:
            Dict with "loss" (scalar), "y_prob" (reconstruction), and
            "y_true" (input batch).
        """
        # Concatenate the batch into a (batch, channel, height, width) tensor;
        # inputs may arrive as a list of numpy arrays or of torch tensors.
        if isinstance(kwargs[self.feature_keys[0]][0], np.ndarray):
            # NOTE(review): the float16 round-trip truncates pixel values
            # slightly; presumably a memory-saving choice — confirm before
            # changing.
            x = torch.tensor(
                np.array(kwargs[self.feature_keys[0]]).astype("float16"),
                device=self.device,
            ).float()
        else:
            x = torch.stack(kwargs[self.feature_keys[0]], dim=0).to(self.device)
        mu, std = self.encoder(x)
        z = self.sampling(mu, std)
        # The decoder expects a (batch, hidden_dim, 1, 1) feature map.
        z = z.unsqueeze(2).unsqueeze(3)
        x_rec = self.decoder(z)
        loss = self.loss_function(x_rec, x, mu, std)
        results = {
            "loss": loss,
            "y_prob": x_rec,
            "y_true": x,
        }
        return results
if __name__ == "__main__":
    # Smoke test: load a dev slice of the COVID-19 CXR dataset, run one batch
    # through the VAE, and backprop the loss once.
    from torchvision import transforms

    from pyhealth.datasets import COVID19CXRDataset, get_dataloader

    root = "/srv/local/data/COVID-19_Radiography_Dataset"
    base_dataset = COVID19CXRDataset(root, dev=True, refresh_cache=False)
    sample_dataset = base_dataset.set_task()

    # The transformation automatically normalizes the pixel intensity into
    # [0, 1]; grayscale images are replicated to 3 channels.
    transform = transforms.Compose([
        transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)),
        transforms.Resize((128, 128)),
    ])

    def encode(sample):
        sample["path"] = transform(sample["path"])
        return sample

    sample_dataset.set_transform(encode)

    # data loader (duplicate get_dataloader import removed)
    train_loader = get_dataloader(sample_dataset, batch_size=2, shuffle=True)

    # Previously hard-coded to "cuda", which crashes on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # model
    model = VAE(
        dataset=sample_dataset,
        input_channel=3,
        input_size=128,
        feature_keys=["path"],
        label_key="path",
        mode="regression",
        hidden_dim=256,
    ).to(device)

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()