# Source code for pyhealth.models.contrawr

import functools
import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from pyhealth.datasets import BaseSignalDataset
from pyhealth.models import BaseModel


class ResBlock2D(nn.Module):
    """Convolutional Residual Block 2D

    Main path: a (possibly strided) 3x3 convolution -> batch norm -> ELU ->
    3x3 convolution -> batch norm. When ``downsample`` is True, a strided 3x3
    convolution + batch norm projection of the input is added to the main
    path (note: the projection layers are always constructed, but they are
    only applied when ``downsample`` is True; otherwise no skip path is
    added). Afterwards an optional 2x2 max pooling and a dropout (p=0.5) are
    applied.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the first convolution and of the projection
            convolution. Default is 2.
        downsample: whether to add the projected input (residual connection).
        pooling: whether to apply 2x2 max pooling after the residual sum.

    Example:
        >>> import torch
        >>> from pyhealth.models import ResBlock2D
        >>>
        >>> model = ResBlock2D(6, 16, 1, True, True)
        >>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
        >>> output = model(input_)
        >>> output.shape
        torch.Size([16, 16, 14, 75])
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        downsample: bool = True,
        pooling: bool = True,
    ):
        super().__init__()

        def conv3x3(c_in, c_out, step=1):
            # every convolution in this block is 3x3 with "same"-style padding
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1)

        # Submodules are created in a fixed order so parameter initialisation
        # consumes the RNG identically from run to run.
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ELU()
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.downsample = nn.Sequential(
            conv3x3(in_channels, out_channels, stride),
            nn.BatchNorm2d(out_channels),
        )
        self.downsampleOrNot = downsample
        self.pooling = pooling
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Forward propagation.

        Args:
            x: input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            out: output tensor of shape (batch_size, out_channels, *, *).
        """
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        if self.downsampleOrNot:
            hidden += self.downsample(x)
        if self.pooling:
            hidden = self.maxpool(hidden)
        return self.dropout(hidden)
class ContraWR(BaseModel):
    """The encoder model of ContraWR (a supervised model, STFT + 2D CNN layers)

    Paper: Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun.
    "Self-supervised eeg representation learning for automatic sleep staging."
    arXiv preprint arXiv:2110.15278 (2021).

    Note:
        We use one encoder to handle multiple channels together.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all labels and the input signal
            statistics (``dataset.input_info["signal"]``).
        feature_keys: list of keys in samples to use as features,
            e.g. ["signal"]. Only the first key is used by ``forward``.
        label_key: key in samples to use as label (e.g., "label").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: the embedding dimension. Default is 128. Stored on the
            model; not used by the current encoder.
        hidden_dim: the hidden dimension. Default is 128. Stored on the
            model; not used by the current encoder.
        n_fft: STFT window size; the hop length is ``n_fft // 4``.
            Default is 128.
        **kwargs: other parameters (accepted but currently unused).

    Examples:
        >>> from pyhealth.datasets import SampleSignalDataset
        >>> samples = [
        ...         {
        ...             "record_id": "SC4001-0",
        ...             "patient_id": "SC4001",
        ...             "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
        ...             "label": "W",
        ...         },
        ...         {
        ...             "record_id": "SC4001-1",
        ...             "patient_id": "SC4001",
        ...             "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
        ...             "label": "R",
        ...         }
        ...     ]
        >>> dataset = SampleSignalDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import ContraWR
        >>> model = ContraWR(
        ...         dataset=dataset,
        ...         feature_keys=["signal"], # dataloader will load the signal from "epoch_path" and put it in "signal"
        ...         label_key="label",
        ...         mode="multiclass",
        ...     )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(2.8425, device='cuda:0', grad_fn=<NllLossBackward0>),
            'y_prob': tensor([[0.9345, 0.0655],
                            [0.9482, 0.0518]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
            'y_true': tensor([1, 1], device='cuda:0'),
            'logit': tensor([[ 0.1472, -2.5104],
                            [2.1584, -0.7481]], device='cuda:0', grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: BaseSignalDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        n_fft: int = 128,
        **kwargs,
    ):
        super(ContraWR, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_fft = n_fft

        self.label_tokenizer = self.get_label_tokenizer()

        # the ContraWR encoder: one ResBlock2D per channel transition
        channels, emb_size = self.cal_encoder_stat()
        self.encoder = nn.Sequential(
            *[
                ResBlock2D(channels[i], channels[i + 1], 2, True, True)
                for i in range(len(channels) - 1)
            ]
        )
        output_size = self.get_output_size(self.label_tokenizer)
        # the fully connected layer
        self.fc = nn.Linear(emb_size, output_size)

    def cal_encoder_stat(self):
        """Obtain the convolution encoder initialization statistics.

        Each ResBlock2D (stride 2 conv, then 2x2 max pooling) maps a spatial
        size ``d`` to ``(d + 1) // 4``. Blocks are added, each doubling the
        channel count up to the next power of two, while both spatial dims
        are >= 4 and the flattened size is still > 256 (at least one block is
        always added).

        Note:
            We show an example to illustrate the encoder statistics.
            input x:
                - torch.Size([5, 7, 3000])
            after stft transform
                - torch.Size([5, 7, 65, 90])
            we design the first CNN (out_channels = 8)
                - torch.Size([5, 8, 16, 22])
                - here: 8 * 16 * 22 > 256, we continue the convolution
            we design the second CNN (out_channels = 16)
                - torch.Size([5, 16, 4, 5])
                - here: 16 * 4 * 5 > 256, we continue the convolution
            we design the third CNN (out_channels = 32)
                - torch.Size([5, 32, 1, 1])
                - here: 32 * 1 * 1 <= 256, we stop the convolution
            output:
                - channels = [7, 8, 16, 32]
                - emb_size = 32 * 1 * 1 = 32

        Returns:
            channels: list of channel counts, starting with the input channels.
            emb_size: flattened size of the encoder output per sample.

        Raises:
            ValueError: if the spectrogram is smaller than 4 in either dim.
        """
        print("\n=== Input data statistics ===")
        # obtain input signal size
        # NOTE(review): the key is hard-coded to "signal" while forward() reads
        # feature_keys[0]; these agree in the documented usage.
        signal_info = self.dataset.input_info["signal"]
        in_channels, length = signal_info["n_channels"], signal_info["length"]
        # input signal size (batch, n_channels, length)
        print(f"n_channels: {in_channels}")
        print(f"length: {length}")

        # after stft transform (batch, n_channels, freq, time_steps);
        # matches torch.stft with center=False and hop_length = n_fft // 4
        freq = self.n_fft // 2 + 1
        time_steps = (length - self.n_fft) // (self.n_fft // 4) + 1
        print("=== Spectrogram statistics ===")
        print(f"n_channels: {in_channels}")
        print(f"freq_dim: {freq}")
        print(f"time_steps: {time_steps}")

        if freq < 4 or time_steps < 4:
            raise ValueError("The input signal is too short or n_fft is too small.")

        # obtain stats at each cnn layer
        channels = [in_channels]
        cur_freq_dim = freq
        cur_time_dim = time_steps
        print("=== Convolution Statistics ===")
        while (cur_freq_dim >= 4 and cur_time_dim >= 4) and (
            len(channels) == 1 or cur_freq_dim * cur_time_dim * channels[-1] > 256
        ):
            # next power of two strictly above the current channel count
            # (exact integer equivalent of 2 ** (floor(log2(c)) + 1))
            channels.append(2 ** int(channels[-1]).bit_length())
            # stride-2 conv gives (d + 1) // 2, then 2x2 pooling halves again
            cur_freq_dim = (cur_freq_dim + 1) // 4
            cur_time_dim = (cur_time_dim + 1) // 4
            print(
                f"in_channels: {channels[-2]}, out_channels: {channels[-1]}, freq_dim: {cur_freq_dim}, time_steps: {cur_time_dim}"
            )
        print()

        emb_size = cur_freq_dim * cur_time_dim * channels[-1]
        return channels, emb_size

    def torch_stft(self, X):
        """Torch short time fourier transform (STFT) magnitude.

        All channels are transformed in a single batched ``torch.stft`` call.

        Args:
            X: (batch, n_channels, length)

        Returns:
            signal: (batch, n_channels, freq, time_steps), the magnitude
                of the one-sided STFT with hop length ``n_fft // 4``.
        """
        batch, n_channels, length = X.shape
        # fold channels into the batch dimension: torch.stft accepts 2-D input
        spectral = torch.stft(
            X.reshape(batch * n_channels, length),
            n_fft=self.n_fft,
            hop_length=self.n_fft // 4,
            center=False,
            onesided=True,
            return_complex=True,
        )
        # complex (batch * n_channels, freq, time_steps) -> magnitude
        magnitude = spectral.abs()
        return magnitude.reshape(batch, n_channels, *magnitude.shape[-2:])

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: a data batch; must contain ``self.feature_keys[0]``
                (array-like convertible to (batch, channel, length)) and
                ``self.label_key``. If ``embed`` is truthy, the encoder
                embedding is also returned.

        Returns:
            dict with keys "loss", "y_prob", "y_true", "logit" and optionally
            "embed".
        """
        # concat the info within one batch (batch, channel, length)
        x = torch.tensor(
            np.array(kwargs[self.feature_keys[0]]),
            dtype=torch.float32,
            device=self.device,
        )

        # obtain the stft spectrogram (batch, channel, freq, time step)
        x_spectrogram = self.torch_stft(x)

        # final layer embedding (batch, embedding)
        emb = self.encoder(x_spectrogram).view(x.shape[0], -1)

        # (batch, label_size)
        logits = self.fc(emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = emb
        return results
if __name__ == "__main__":
    # Manual smoke tests for this module. They rely on a local pyhealth cache
    # (and, for the commented-out recipe, the Sleep-EDF dataset).

    # --- test the ResBlock2D ---
    # import torch
    # input_ = torch.randn((16, 6, 3, 75))  # (batch, channel, height, width)
    # model = ResBlock2D(6, 16, 1, True, True)
    # output = model(input_)
    # print("input shape: ", input_.shape)
    # print("output.shape:", output.shape)

    # --- test ContraWR on Sleep-EDF ---
    # from pyhealth.datasets import split_by_patient, get_dataloader
    # from pyhealth.trainer import Trainer
    # from pyhealth.datasets import SleepEDFDataset
    # from pyhealth.tasks import sleep_staging_sleepedf_fn
    #
    # # step 1: load signal data
    # dataset = SleepEDFDataset(
    #     root="/srv/local/data/SLEEPEDF/sleep-edf-database-expanded-1.0.0/sleep-cassette",
    #     dev=True,
    #     refresh_cache=False,
    # )
    #
    # # step 2: set task
    # sleep_staging_ds = dataset.set_task(sleep_staging_sleepedf_fn)
    # sleep_staging_ds.stat()
    # print(sleep_staging_ds.input_info)
    #
    # # split dataset
    # train_dataset, val_dataset, test_dataset = split_by_patient(
    #     sleep_staging_ds, [0.6, 0.2, 0.2]
    # )
    # train_dataloader = get_dataloader(train_dataset, batch_size=5, shuffle=True)
    # val_dataloader = get_dataloader(val_dataset, batch_size=5, shuffle=False)
    # test_dataloader = get_dataloader(test_dataset, batch_size=5, shuffle=False)
    # print(
    #     "loader size: train/val/test",
    #     len(train_dataset),
    #     len(val_dataset),
    #     len(test_dataset),
    # )
    # batch = next(iter(train_dataloader))
    #
    # # step 3: define model
    # model = ContraWR(
    #     sleep_staging_ds,
    #     feature_keys=["signal"],
    #     label_key="label",
    #     mode="multiclass",
    #     n_fft=128,
    # )
    # result = model(**batch)
    # print(result)

    # --- test ContraWR 2 ---
    from pyhealth.datasets import SampleSignalDataset, get_dataloader

    samples = [
        {
            "record_id": "SC4001-0",
            "patient_id": "SC4001",
            "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
            "label": "W",
        },
        {
            # distinct record id for the second epoch (was a duplicate "SC4001-0")
            "record_id": "SC4001-1",
            "patient_id": "SC4001",
            "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
            "label": "R",
        },
    ]

    # dataset
    dataset = SampleSignalDataset(samples=samples, dataset_name="test")

    # data loader
    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model; fall back to CPU so the smoke test also runs without a GPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = ContraWR(
        dataset=dataset,
        feature_keys=["signal"],
        label_key="label",
        mode="multiclass",
    ).to(device)

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()