import functools
import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from pyhealth.datasets import BaseSignalDataset
from pyhealth.models import BaseModel
class ResBlock2D(nn.Module):
    """Convolutional Residual Block 2D

    Two 3x3 convolutions, each followed by batch normalization (an ELU sits
    between them), an optional residual branch (a strided 3x3 convolution +
    batch normalization applied to the block input), optional 2x2 max
    pooling, and a final dropout with p=0.5.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the first convolution and of the residual-branch
            convolution.
        downsample: whether to add the convolutional residual branch. When
            False, no residual connection is added at all.
        pooling: whether to apply 2x2 max pooling (stride 2) after the sum.

    Example:
        >>> import torch
        >>> from pyhealth.models import ResBlock2D
        >>>
        >>> model = ResBlock2D(6, 16, 1, True, True)
        >>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
        >>> output = model(input_)
        >>> output.shape
        torch.Size([16, 16, 14, 75])
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        downsample: bool = True,
        pooling: bool = True,
    ):
        super().__init__()
        # shared 3x3 "same-padding" settings for every convolution in the block
        same_3x3 = {"kernel_size": 3, "padding": 1}

        # main path (creation order is kept stable so parameter initialization
        # and state_dict keys stay the same)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **same_3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ELU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.maxpool = nn.MaxPool2d(2, stride=2)

        # residual branch: projects the input to the main path's output shape
        shortcut_conv = nn.Conv2d(in_channels, out_channels, stride=stride, **same_3x3)
        self.downsample = nn.Sequential(shortcut_conv, nn.BatchNorm2d(out_channels))

        self.downsampleOrNot = downsample
        self.pooling = pooling
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Forward propagation.

        Args:
            x: input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            out: output tensor of shape (batch_size, out_channels, *, *).
        """
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        if self.downsampleOrNot:
            hidden += self.downsample(x)
        if self.pooling:
            hidden = self.maxpool(hidden)
        return self.dropout(hidden)
class ContraWR(BaseModel):
    """The encoder model of ContraWR (a supervised model, STFT + 2D CNN layers)

    Paper: Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun.
    "Self-supervised eeg representation learning for automatic sleep staging."
    arXiv preprint arXiv:2110.15278 (2021).

    Note:
        We use one encoder to handle multiple channels together: the STFT
        magnitude spectrogram of every channel is kept on the channel axis and
        the whole (channel, freq, time) tensor goes through a stack of
        ResBlock2D layers, followed by one fully connected layer.

    Args:
        dataset: the dataset to train the model. Its
            ``input_info["signal"]`` entry (``n_channels`` and ``length``) is
            used to size the encoder, and it is used to build the label
            tokenizer.
        feature_keys: list of keys in samples to use as features. Only the
            first key is read in ``forward``, e.g. ["signal"].
        label_key: key in samples to use as label (e.g., "label").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: the embedding dimension. Default is 128. Stored on the
            model but not used by the encoder.
        hidden_dim: the hidden dimension. Default is 128. Stored on the model
            but not used by the encoder.
        n_fft: number of FFT points of the STFT; the hop length is
            ``n_fft // 4``. Default is 128.
        **kwargs: other parameters; currently ignored.

    Examples:
        >>> from pyhealth.datasets import SampleSignalDataset
        >>> samples = [
        ...         {
        ...             "record_id": "SC4001-0",
        ...             "patient_id": "SC4001",
        ...             "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
        ...             "label": "W",
        ...         },
        ...         {
        ...             "record_id": "SC4001-1",
        ...             "patient_id": "SC4001",
        ...             "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
        ...             "label": "R",
        ...         }
        ...     ]
        >>> dataset = SampleSignalDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import ContraWR
        >>> model = ContraWR(
        ...         dataset=dataset,
        ...         feature_keys=["signal"], # dataloader will load the signal from "epoch_path" and put it in "signal"
        ...         label_key="label",
        ...         mode="multiclass",
        ...     )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(2.8425, device='cuda:0', grad_fn=<NllLossBackward0>),
            'y_prob': tensor([[0.9345, 0.0655],
                            [0.9482, 0.0518]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
            'y_true': tensor([1, 1], device='cuda:0'),
            'logit': tensor([[ 0.1472, -2.5104],
                            [2.1584, -0.7481]], device='cuda:0', grad_fn=<AddmmBackward0>)
        }
        >>>
    """

    def __init__(
        self,
        dataset: BaseSignalDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        n_fft: int = 128,
        **kwargs,
    ):
        super(ContraWR, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_fft = n_fft

        self.label_tokenizer = self.get_label_tokenizer()

        # the ContraWR encoder: one ResBlock2D (stride 2, residual branch,
        # max pooling) per consecutive pair of channel sizes
        channels, emb_size = self.cal_encoder_stat()
        self.encoder = nn.Sequential(
            *[
                ResBlock2D(channels[i], channels[i + 1], 2, True, True)
                for i in range(len(channels) - 1)
            ]
        )
        output_size = self.get_output_size(self.label_tokenizer)
        # the fully connected layer
        self.fc = nn.Linear(emb_size, output_size)

    def cal_encoder_stat(self):
        """Obtain the convolution encoder initialization statistics.

        The input size is read from ``self.dataset.input_info["signal"]``.
        Starting from the STFT spectrogram size, ResBlock2D layers (stride-2
        3x3 convolution followed by 2x2 max pooling) are added one at a time;
        each maps a spatial size ``d`` to ``(d + 1) // 4``. Layers keep being
        added while both spatial dims are >= 4 and the flattened size exceeds
        256 (at least one layer is always added). Each new layer's channel
        count is the smallest power of two strictly greater than the previous
        channel count. The statistics are printed to stdout.

        Note:
            We show an example to illustrate the encoder statistics.
            input x:
                - torch.Size([5, 7, 3000])
            after stft transform (n_fft = 128)
                - torch.Size([5, 7, 65, 90])
            we design the first CNN (out_channels = 8)
                - torch.Size([5, 8, 16, 22])
                - here: 8 * 16 * 22 > 256, we continue the convolution
            we design the second CNN (out_channels = 16)
                - torch.Size([5, 16, 4, 5])
                - here: 16 * 4 * 5 > 256, we continue the convolution
            we design the third CNN (out_channels = 32)
                - torch.Size([5, 32, 1, 1])
                - here: 32 * 1 * 1 = 32 <= 256, we stop the convolution
            output:
                - channels = [7, 8, 16, 32]
                - emb_size = 32 * 1 * 1 = 32

        Returns:
            channels: list of channel sizes, starting with the input channels.
            emb_size: flattened size of the last encoder layer's output.

        Raises:
            ValueError: if the spectrogram would have fewer than 4 frequency
                bins or time steps (input too short or n_fft too small).
        """
        print("\n=== Input data statistics ===")
        # obtain input signal size
        signal_info = self.dataset.input_info["signal"]
        in_channels, length = signal_info["n_channels"], signal_info["length"]
        # input signal size (batch, n_channels, length)
        print(f"n_channels: {in_channels}")
        print(f"length: {length}")

        # after stft transform (batch, n_channels, freq, time_steps);
        # these formulas match torch.stft with center=False, onesided=True
        freq = self.n_fft // 2 + 1
        hop_length = self.n_fft // 4
        if hop_length < 1:
            # n_fft < 4 gives a zero hop length (and fewer than 4 freq bins);
            # fail with the same error instead of a ZeroDivisionError
            raise ValueError("The input signal is too short or n_fft is too small.")
        time_steps = (length - self.n_fft) // hop_length + 1
        print("=== Spectrogram statistics ===")
        print(f"n_channels: {in_channels}")
        print(f"freq_dim: {freq}")
        print(f"time_steps: {time_steps}")

        if freq < 4 or time_steps < 4:
            raise ValueError("The input signal is too short or n_fft is too small.")

        # obtain stats at each cnn layer
        channels = [in_channels]
        cur_freq_dim = freq
        cur_time_dim = time_steps
        print("=== Convolution Statistics ===")
        while (cur_freq_dim >= 4 and cur_time_dim >= 4) and (
            len(channels) == 1 or cur_freq_dim * cur_time_dim * channels[-1] > 256
        ):
            # next power of two strictly greater than the current channel count
            channels.append(2 ** (math.floor(np.log2(channels[-1])) + 1))
            # stride-2 conv gives (d + 1) // 2, then 2x2 pooling halves it
            cur_freq_dim = (cur_freq_dim + 1) // 4
            cur_time_dim = (cur_time_dim + 1) // 4
            print(
                f"in_channels: {channels[-2]}, out_channels: {channels[-1]}, freq_dim: {cur_freq_dim}, time_steps: {cur_time_dim}"
            )
        print()
        emb_size = cur_freq_dim * cur_time_dim * channels[-1]
        return channels, emb_size

    def torch_stft(self, X):
        """torch short time fourier transform (STFT) magnitude

        All channels are transformed in a single batched ``torch.stft`` call
        with a rectangular window (all ones, the torch default), hop length
        ``n_fft // 4``, ``center=False`` and a onesided spectrum.

        Args:
            X: (batch, n_channels, length)

        Returns:
            signal: (batch, n_channels, freq, time_steps), the magnitude of
                the complex STFT.
        """
        batch_size, n_channels, length = X.shape
        spectral = torch.stft(
            X.reshape(batch_size * n_channels, length),
            n_fft=self.n_fft,
            hop_length=self.n_fft // 4,
            # explicit rectangular window: same as the default, without the
            # "no window provided" warning
            window=torch.ones(self.n_fft, device=X.device, dtype=X.dtype),
            center=False,
            onesided=True,
            return_complex=True,
        )
        # |re + i*im| for every bin: (batch * n_channels, freq, time_steps)
        magnitude = spectral.abs()
        return magnitude.reshape(batch_size, n_channels, *magnitude.shape[-2:])

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: one batch from the dataloader.
                ``kwargs[self.feature_keys[0]]`` holds the signals (assumed to
                stack into an array of shape (batch, n_channels, length)),
                ``kwargs[self.label_key]`` holds the raw labels, and an
                optional ``embed=True`` also returns the encoder embedding.

        Returns:
            A dictionary with keys "loss", "y_prob", "y_true", "logit", and
            "embed" (only when ``embed=True`` was passed).
        """
        # concat the info within one batch (batch, channel, length)
        x = torch.tensor(
            np.array(kwargs[self.feature_keys[0]]),
            dtype=torch.float32,
            device=self.device,
        )
        # obtain the stft spectrogram (batch, channel, freq, time step)
        x_spectrogram = self.torch_stft(x)
        # final layer embedding (batch, embedding)
        emb = self.encoder(x_spectrogram).reshape(x.shape[0], -1)
        # (patient, label_size)
        logits = self.fc(emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = emb
        return results
if __name__ == "__main__":
    # Manual smoke tests for ResBlock2D and ContraWR. They depend on local
    # data paths, so they are not meant to run outside the author's machine.

    # --- test the ResBlock2D ---
    # import torch
    # input_ = torch.randn((16, 6, 3, 75))  # (batch, channel, height, width)
    # model = ResBlock2D(6, 16, 1, True, True)
    # output = model(input_)
    # print("input shape: ", input_.shape)
    # print("output.shape:", output.shape)

    # --- test ContraWR on SleepEDF ---
    # from pyhealth.datasets import split_by_patient, get_dataloader
    # from pyhealth.trainer import Trainer
    # from pyhealth.datasets import SleepEDFDataset
    # from pyhealth.tasks import sleep_staging_sleepedf_fn
    # # step 1: load signal data
    # dataset = SleepEDFDataset(
    #     root="/srv/local/data/SLEEPEDF/sleep-edf-database-expanded-1.0.0/sleep-cassette",
    #     dev=True,
    #     refresh_cache=False,
    # )
    # # step 2: set task
    # sleep_staging_ds = dataset.set_task(sleep_staging_sleepedf_fn)
    # sleep_staging_ds.stat()
    # print(sleep_staging_ds.input_info)
    # # split dataset
    # train_dataset, val_dataset, test_dataset = split_by_patient(
    #     sleep_staging_ds, [0.6, 0.2, 0.2]
    # )
    # train_dataloader = get_dataloader(train_dataset, batch_size=5, shuffle=True)
    # val_dataloader = get_dataloader(val_dataset, batch_size=5, shuffle=False)
    # test_dataloader = get_dataloader(test_dataset, batch_size=5, shuffle=False)
    # print(
    #     "loader size: train/val/test",
    #     len(train_dataset),
    #     len(val_dataset),
    #     len(test_dataset),
    # )
    # batch = next(iter(train_dataloader))
    # # step 3: define model
    # model = ContraWR(
    #     sleep_staging_ds,
    #     feature_keys=["signal"],
    #     label_key="label",
    #     mode="multiclass",
    #     n_fft=128,
    # )
    # result = model(**batch)
    # print(result)

    # --- test ContraWR 2 ---
    from pyhealth.datasets import SampleSignalDataset, get_dataloader

    samples = [
        {
            "record_id": "SC4001-0",
            "patient_id": "SC4001",
            "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
            "label": "W",
        },
        {
            # each sample gets its own record id, matching its epoch file
            "record_id": "SC4001-1",
            "patient_id": "SC4001",
            "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
            "label": "R",
        },
    ]

    # dataset
    dataset = SampleSignalDataset(samples=samples, dataset_name="test")

    # data loader
    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model (fall back to CPU when no GPU is available)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = ContraWR(
        dataset=dataset,
        feature_keys=["signal"],
        label_key="label",
        mode="multiclass",
    ).to(device)

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()