Source code for pyhealth.models.contrawr

import math
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


class ResBlock2D(nn.Module):
    """2D convolutional block with an optional projected shortcut.

    The main path is ``conv -> batch norm -> ELU -> conv -> batch norm``.
    When ``downsample`` is enabled, the block input is sent through a
    separate ``conv -> batch norm`` branch and added to the main path.
    The result is optionally max-pooled (2x2, stride 2) and always passed
    through dropout with p=0.5.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the first convolution and of the shortcut
            convolution. Default is 2.
        downsample: whether to add the projected input to the main path.
            When False, no shortcut is added at all.
        pooling: whether to apply max pooling before dropout.

    Example:
        >>> import torch
        >>> from pyhealth.models import ResBlock2D
        >>>
        >>> model = ResBlock2D(6, 16, 1, True, True)
        >>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
        >>> output = model(input_)
        >>> output.shape
        torch.Size([16, 16, 14, 75])
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        downsample: bool = True,
        pooling: bool = True,
    ):
        super().__init__()
        # The first main-path convolution and the shortcut convolution share
        # the same geometry so that their outputs can be summed.
        strided_conv = {"kernel_size": 3, "stride": stride, "padding": 1}

        # Main path. NOTE: the attribute is called ``relu`` but holds an ELU.
        self.conv1 = nn.Conv2d(in_channels, out_channels, **strided_conv)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ELU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Shortcut branch; built unconditionally, used only if ``downsample``.
        shortcut_conv = nn.Conv2d(in_channels, out_channels, **strided_conv)
        shortcut_norm = nn.BatchNorm2d(out_channels)
        self.downsampler = nn.Sequential(shortcut_conv, shortcut_norm)

        self.downsample = downsample
        self.pooling = pooling
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Forward propagation.

        Args:
            x: input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            out: output tensor of shape (batch_size, out_channels, *, *).
        """
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))

        if self.downsample:
            hidden = hidden + self.downsampler(x)
        if self.pooling:
            hidden = self.maxpool(hidden)

        return self.dropout(hidden)
class ContraWR(BaseModel):
    """The encoder model of ContraWR (a supervised model, STFT + 2D CNN layers).

    Paper: Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun.
        "Self-supervised eeg representation learning for automatic sleep
        staging." arXiv preprint arXiv:2110.15278 (2021).

    Note:
        We use one encoder to handle multiple channels together.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128. Stored on the
            instance; not otherwise used by the code in this class.
        hidden_dim: the hidden dimension. Default is 128. Stored on the
            instance; not otherwise used by the code in this class.
        n_fft: the number of FFT points for STFT. Default is 128. The STFT
            hop length is ``n_fft // 4``.

    Examples:
        >>> from pyhealth.datasets import SampleSignalDataset
        >>> samples = [
        ...     {
        ...         "record_id": "SC4001-0",
        ...         "patient_id": "SC4001",
        ...         "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
        ...         "label": "W",
        ...     },
        ...     {
        ...         "record_id": "SC4001-1",
        ...         "patient_id": "SC4001",
        ...         "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
        ...         "label": "R",
        ...     }
        ... ]
        >>> dataset = SampleSignalDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import ContraWR
        >>> model = ContraWR(
        ...     dataset=dataset,
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(2.8425, device='cuda:0', grad_fn=<NllLossBackward0>),
            'y_prob': tensor([[0.9345, 0.0655],
                              [0.9482, 0.0518]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
            'y_true': tensor([1, 1], device='cuda:0'),
            'logit': tensor([[ 0.1472, -2.5104],
                             [ 2.1584, -0.7481]], device='cuda:0', grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        n_fft: int = 128,
    ):
        super(ContraWR, self).__init__(dataset=dataset)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_fft = n_fft

        assert len(self.label_keys) == 1, (
            "Only one label key is supported if ContraWR is initialized"
        )
        # ContraWR computes its own STFT internally — drop any precomputed
        # auxiliary keys (e.g. "stft") that TFMTokenizer adds to input_schema.
        self.feature_keys = [k for k in self.feature_keys if k != "stft"]
        assert len(self.feature_keys) == 1, (
            "Only one feature key is supported if ContraWR is initialized"
        )

        # The ContraWR encoder: one ResBlock2D per consecutive channel pair.
        channels, emb_size = self.determine_encoder_params()
        self.encoder = nn.Sequential(
            *[
                ResBlock2D(n_in, n_out, 2, True, True)
                for n_in, n_out in zip(channels, channels[1:])
            ]
        )

        # The fully connected classification layer.
        output_size = self.get_output_size()
        self.fc = nn.Linear(emb_size, output_size)

    def _determine_input_channels_length(self) -> Tuple[int, int]:
        """Infer ``(n_channels, length)`` from the first usable dataset sample.

        Samples that do not contain the feature key are skipped. A 1-D
        feature is treated as a single-channel signal.

        Returns:
            A ``(n_channels, length)`` tuple.

        Raises:
            ValueError: if the feature is neither 1-D nor 2-D, or if no
                sample in the dataset contains the feature key.
        """
        key = self.feature_keys[0]
        for sample in self.dataset:
            if key not in sample:
                continue
            shape = sample[key].shape
            if len(shape) == 1:
                return 1, shape[0]
            if len(shape) == 2:
                return shape[0], shape[1]
            raise ValueError(f"Invalid shape for feature key {key}: {shape}")
        raise ValueError(
            f"Unable to infer input channels and length from dataset for feature key {key}"
        )

    def determine_encoder_params(self) -> Tuple[list, int]:
        """Obtain the convolution encoder parameters based on input signal size.

        Note:
            We show an example to illustrate the design process here. Assume
            the input signal size is (batch = 5, n_channels = 7, length = 3000).

            input x:
                - torch.Size([5, 7, 3000])
            after stft transform
                - torch.Size([5, 7, 65, 90])
            we design the first CNN (out_channels = 8)
                - torch.Size([5, 8, 16, 22])
                - here: 8 * 16 * 22 > 256, we continue the convolution
            we design the second CNN (out_channels = 16)
                - torch.Size([5, 16, 4, 5])
                - here: 16 * 4 * 5 > 256, we continue the convolution
            we design the third CNN (out_channels = 32)
                - torch.Size([5, 32, 1, 1])
                - here: 32 * 1 * 1, we stop the convolution

            output:
                - channels = [7, 8, 16, 32]
                - emb_size = 32 * 1 * 1 = 32

        Returns:
            A ``(channels, emb_size)`` tuple: the channel count before the
            first block followed by the output channels of every block, and
            the flattened size of the final feature map.

        Raises:
            ValueError: if ``n_fft`` is too small to give a positive hop
                length, or if the spectrogram has fewer than 4 frequency
                bins or fewer than 4 time steps.
        """
        print("\n=== Input data dimensions ===")
        in_channels, length = self._determine_input_channels_length()
        print(f"n_channels: {in_channels}")
        print(f"length: {length}")

        # The hop length is n_fft // 4; guard it so that a tiny n_fft yields
        # the intended ValueError rather than a ZeroDivisionError below.
        hop_length = self.n_fft // 4
        if hop_length < 1:
            raise ValueError("The input signal is too short or n_fft is too small.")

        # One-sided STFT without centering (mirrors torch_stft).
        freq = self.n_fft // 2 + 1
        time_steps = (length - self.n_fft) // hop_length + 1
        print("=== Spectrogram parameters ===")
        print(f"n_channels: {in_channels}")
        print(f"freq_dim: {freq}")
        print(f"time_steps: {time_steps}")

        if freq < 4 or time_steps < 4:
            raise ValueError("The input signal is too short or n_fft is too small.")

        # Obtain stats at each CNN layer.
        channels = [in_channels]
        cur_freq_dim = freq
        cur_time_dim = time_steps
        print("=== Convolution Parameters ===")
        # Always add at least one block; keep adding while both spatial dims
        # can still be reduced and the flattened feature map exceeds 256.
        while (cur_freq_dim >= 4 and cur_time_dim >= 4) and (
            len(channels) == 1 or cur_freq_dim * cur_time_dim * channels[-1] > 256
        ):
            # Next power of two strictly greater than the current channel count.
            channels.append(2 ** (math.floor(np.log2(channels[-1])) + 1))
            # A stride-2 3x3 conv (padding 1) followed by 2x2 max pooling maps
            # a dimension n to (n + 1) // 4.
            cur_freq_dim = (cur_freq_dim + 1) // 4
            cur_time_dim = (cur_time_dim + 1) // 4

            print(
                f"in_channels: {channels[-2]}, out_channels: {channels[-1]}, freq_dim: {cur_freq_dim}, time_steps: {cur_time_dim}"
            )
        print()

        emb_size = cur_freq_dim * cur_time_dim * channels[-1]
        return channels, emb_size

    def torch_stft(self, X: torch.Tensor) -> torch.Tensor:
        """Torch short time Fourier transform (STFT) magnitude.

        Args:
            X: (batch, n_channels, length)

        Returns:
            signal: (batch, n_channels, freq, time_steps)
        """
        # The window is identical for every channel, so build it once.
        window = torch.hann_window(self.n_fft).to(X.device)
        hop_length = self.n_fft // 4

        magnitudes = []
        for s in range(X.shape[1]):
            spectral = torch.stft(
                X[:, s, :],
                n_fft=self.n_fft,
                hop_length=hop_length,
                center=False,
                onesided=True,
                return_complex=True,
                window=window,
            )
            # Magnitude of the complex spectrum: sqrt(real^2 + imag^2).
            magnitudes.append((spectral.real**2 + spectral.imag**2) ** 0.5)

        # Stack the per-channel (batch, freq, time) maps along a channel axis.
        return torch.stack(magnitudes, dim=1)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: a batch keyed by name. It must contain the single
                feature key (the raw signal tensor) and the single label key.
                If ``embed`` is present and truthy, the flattened encoder
                output is included in the result.

        Returns:
            A dictionary with keys ``loss``, ``y_prob``, ``y_true`` and
            ``logit``, plus ``embed`` when requested.
        """
        # Concat the info within one batch (batch, channel, length).
        x = kwargs[self.feature_keys[0]].to(self.device)
        # 1-D samples are accepted at construction time as single-channel
        # signals; add the channel axis so that the STFT sees a 3-D batch.
        if x.dim() == 2:
            x = x.unsqueeze(1)

        # Obtain the stft spectrogram (batch, channel, freq, time step).
        x_spectrogram = self.torch_stft(x)
        # Final layer embedding (batch, embedding).
        emb = self.encoder(x_spectrogram).view(x.shape[0], -1)
        # (patient, label_size)
        logits = self.fc(emb)

        # Obtain y_true, loss, y_prob.
        y_true = kwargs[self.label_keys[0]].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = emb
        return results