# Source code for pyhealth.models.sparcnet

# Author: Joshua Chen
# Paper: Development of Expert-Level Classification of Seizures and Rhythmic and Periodic Patterns During EEG Interpretation
# Paper Link: https://pubmed.ncbi.nlm.nih.gov/36878708/
# Description: SparcNet implementation for PyHealth 2.0
import math
from collections import OrderedDict
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


class DenseLayer(nn.Sequential):
    """Densely connected layer.

    The layer runs its input through a 1x1 bottleneck convolution followed by
    a kernel-size-3 convolution (each preceded by ELU and, optionally, batch
    normalization), applies dropout to the result and concatenates it to the
    input along the channel axis. The sequence length is preserved.

    Args:
        input_channels: number of input channels.
        growth_rate: number of channels this layer appends to its input.
        bn_size: multiplicative factor for the bottleneck layer
            (does not affect the output size).
        drop_rate: dropout rate applied to the newly computed features.
        conv_bias: whether to use bias in convolutional layers.
        batch_norm: whether to use batch normalization.

    Example:
        >>> x = torch.randn(128, 5, 1000)
        >>> batch, channels, length = x.shape
        >>> model = DenseLayer(channels, 5, 2)
        >>> y = model(x)
        >>> y.shape
        torch.Size([128, 10, 1000])
    """

    def __init__(
        self,
        input_channels: int,
        growth_rate: int,
        bn_size: int,
        drop_rate: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
    ) -> None:
        super().__init__()
        bottleneck_channels = bn_size * growth_rate

        # Sub-modules are registered in execution order; nn.Sequential runs
        # them in exactly this order and the names are the state_dict keys.
        if batch_norm:
            self.add_module("norm1", nn.BatchNorm1d(input_channels))
        self.add_module("elu1", nn.ELU())
        self.add_module(
            "conv1",
            nn.Conv1d(
                input_channels,
                bottleneck_channels,
                kernel_size=1,
                stride=1,
                bias=conv_bias,
            ),
        )
        if batch_norm:
            self.add_module("norm2", nn.BatchNorm1d(bottleneck_channels))
        self.add_module("elu2", nn.ELU())
        self.add_module(
            "conv2",
            nn.Conv1d(
                bottleneck_channels,
                growth_rate,
                kernel_size=3,
                stride=1,
                padding=1,  # keeps the sequence length unchanged
                bias=conv_bias,
            ),
        )
        self.drop_rate = drop_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute new features and append them to ``x`` on the channel axis.

        Args:
            x: input tensor; the example above uses (batch, channels, length).

        Returns:
            ``x`` concatenated with the ``growth_rate`` new feature channels
            along dimension 1.
        """
        new_features = super().forward(x)
        # Functional dropout so that it follows self.training automatically.
        new_features = F.dropout(
            new_features, p=self.drop_rate, training=self.training
        )
        return torch.cat([x, new_features], dim=1)
class DenseBlock(nn.Sequential):
    """Densely connected block.

    A stack of ``num_layers`` :class:`DenseLayer` modules. Because every
    layer appends ``growth_rate`` channels to its input, the i-th layer
    (0-based) receives ``input_channels + i * growth_rate`` channels and the
    block outputs ``input_channels + num_layers * growth_rate`` channels.

    Args:
        num_layers: number of layers in this block.
        input_channels: number of input channels.
        growth_rate: rate of growth of channels in each layer.
        bn_size: multiplicative factor for the bottleneck layer
            (does not affect the output size).
        drop_rate: dropout rate.
        conv_bias: whether to use bias in convolutional layers.
        batch_norm: whether to use batch normalization.

    Example:
        >>> x = torch.randn(128, 5, 1000)
        >>> batch, channels, length = x.shape
        >>> model = DenseBlock(3, channels, 5, 2)
        >>> y = model(x)
        >>> y.shape
        torch.Size([128, 20, 1000])
    """

    def __init__(
        self,
        num_layers: int,
        input_channels: int,
        growth_rate: int,
        bn_size: int,
        drop_rate: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
    ) -> None:
        super().__init__()
        for idx_layer in range(num_layers):
            layer = DenseLayer(
                # each preceding layer has already added growth_rate channels
                input_channels + idx_layer * growth_rate,
                growth_rate,
                bn_size,
                drop_rate,
                conv_bias,
                batch_norm,
            )
            # 1-based names ("denselayer1", ...) are the state_dict keys.
            self.add_module(f"denselayer{idx_layer + 1}", layer)
class TransitionLayer(nn.Sequential):
    """Pooling transition layer.

    Applies (optional) batch normalization, ELU, a 1x1 convolution that maps
    ``input_channels`` to ``output_channels``, and an average pooling with
    kernel size 2 and stride 2 that halves the sequence length.

    Args:
        input_channels: number of input channels.
        output_channels: number of output channels.
        conv_bias: whether to use bias in convolutional layers.
        batch_norm: whether to use batch normalization.

    Example:
        >>> x = torch.randn(128, 5, 1000)
        >>> model = TransitionLayer(5, 18)
        >>> y = model(x)
        >>> y.shape
        torch.Size([128, 18, 500])
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        conv_bias: bool = True,
        batch_norm: bool = True,
    ) -> None:
        super().__init__()
        # Registration order is execution order; names are state_dict keys.
        if batch_norm:
            self.add_module("norm", nn.BatchNorm1d(input_channels))
        self.add_module("elu", nn.ELU())
        self.add_module(
            "conv",
            nn.Conv1d(
                input_channels,
                output_channels,
                kernel_size=1,
                stride=1,
                bias=conv_bias,
            ),
        )
        self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
class SparcNet(BaseModel):
    """The SparcNet model for EEG signal classification.

    A 1D DenseNet-style encoder (initial convolution + max pooling, followed
    by alternating dense blocks and pooling transition layers) with a linear
    prediction head.

    Paper: Jin Jing, et al. Development of Expert-level Classification of
    Seizures and Rhythmic and Periodic Patterns During EEG Interpretation.
    Neurology 2023.

    Note:
        We use one encoder to handle multiple channels together.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: (not used now) the embedding dimension. Default is 128.
        hidden_dim: (not used now) the hidden dimension. Default is 128.
        block_layers: the number of layers in each dense block. Default is 4.
        growth_rate: the growth rate of each dense layer. Default is 16.
        bn_size: the bottleneck size of each dense layer. Default is 16.
        drop_rate: the dropout rate of each dense layer. Default is 0.5.
        conv_bias: whether to use bias in convolutional layers. Default is
            True.
        batch_norm: whether to use batch normalization in the dense blocks
            and transition layers. Default is True.

    Examples:
        >>> import numpy as np
        >>> from pyhealth.datasets import create_sample_dataset
        >>> samples = [
        ...     {
        ...         "patient_id": "p0",
        ...         "visit_id": "v0",
        ...         "signal": np.random.randn(2, 256).astype(np.float32),
        ...         "label": 0,
        ...     },
        ...     {
        ...         "patient_id": "p1",
        ...         "visit_id": "v0",
        ...         "signal": np.random.randn(2, 256).astype(np.float32),
        ...         "label": 1,
        ...     }
        ... ]
        >>> dataset = create_sample_dataset(
        ...     samples=samples,
        ...     input_schema={"signal": "tensor"},
        ...     output_schema={"label": "multiclass"},
        ...     dataset_name="test",
        ... )
        >>>
        >>> from pyhealth.models import SparcNet
        >>> model = SparcNet(dataset=dataset)
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.6530, device='cuda:0', grad_fn=<NllLossBackward0>),
            'y_prob': tensor([[0.4459, 0.5541],
                              [0.5111, 0.4889]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
            'y_true': tensor([1, 1], device='cuda:0'),
            'logit': tensor([[-0.2750, -0.0577],
                             [-0.1319, -0.1763]], device='cuda:0', grad_fn=<AddmmBackward0>)
        }
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        block_layers: int = 4,
        growth_rate: int = 16,
        bn_size: int = 16,
        drop_rate: float = 0.5,
        conv_bias: bool = True,
        batch_norm: bool = True,
    ):
        super().__init__(dataset=dataset)

        # ---- common ----
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        assert len(self.label_keys) == 1, (
            "Only one label key is supported if SparcNet is initialized"
        )
        assert len(self.feature_keys) == 1, (
            "Only one feature key is supported if SparcNet is initialized"
        )

        # ---- input statistics ----
        print("\n=== Input data statistics ===")
        # input signal size (batch, n_channels, length)
        in_channels, length = self._determine_input_channels_length()
        print(f"n_channels: {in_channels}")
        print(f"length: {length}")

        # ---- define sparcnet ----
        # Initial convolution: the channel count is the smallest power of two
        # strictly greater than in_channels, i.e. 2 ** (floor(log2(c)) + 1).
        out_channels = 1 << int(in_channels).bit_length()
        # NOTE: "norm0" is added regardless of the batch_norm flag; this is
        # left as-is so the encoder's state_dict keys do not change.
        stem = OrderedDict(
            [
                (
                    "conv0",
                    nn.Conv1d(
                        in_channels,
                        out_channels,
                        kernel_size=7,
                        stride=2,
                        padding=3,
                        bias=conv_bias,
                    ),
                ),
                ("norm0", nn.BatchNorm1d(out_channels)),
                ("elu0", nn.ELU()),
                ("pool0", nn.MaxPool1d(kernel_size=3, stride=2, padding=1)),
            ]
        )
        self.encoder = nn.Sequential(stem)

        # Dense blocks, each followed by a transition that halves both the
        # channel count and the sequence length.
        n_channels = out_channels
        for block_idx in range(1, self._num_dense_blocks(int(length)) + 1):
            block = DenseBlock(
                num_layers=block_layers,
                input_channels=n_channels,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
            )
            self.encoder.add_module(f"denseblock{block_idx}", block)
            # update number of channels after each dense block
            n_channels += block_layers * growth_rate

            trans = TransitionLayer(
                input_channels=n_channels,
                output_channels=n_channels // 2,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
            )
            self.encoder.add_module(f"transition{block_idx}", trans)
            # update number of channels after each transition layer
            n_channels //= 2

        # ---- prediction layer ----
        output_size = self.get_output_size()
        self.fc = nn.Linear(n_channels, output_size)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    @staticmethod
    def _num_dense_blocks(length: int) -> int:
        """Return how many dense block / transition pairs the encoder needs.

        The stem (stride-2 convolution with padding 3 and kernel 7, then
        stride-2 max pooling with padding 1 and kernel 3) turns ``length``
        samples into ``ceil(length / 4)`` samples, and every transition layer
        floor-halves the length. Using ``floor(log2(ceil(length / 4)))``
        pairs therefore leaves exactly one time step, which is what the
        linear head expects after flattening.

        The count must be derived from the ceiling, not ``length // 4``:
        for e.g. ``length == 255`` the stem emits 64 samples, and counting
        from 63 would leave two time steps and break the linear head.
        """
        stem_length = (length + 3) // 4  # ceil(length / 4)
        # bit_length() - 1 == floor(log2(n)) for n >= 1, with no float error.
        return max(stem_length.bit_length() - 1, 0)

    def _determine_input_channels_length(self):
        """Infer ``(n_channels, length)`` from the first usable sample.

        Samples that do not contain the feature key are skipped. A 1-D
        feature is treated as a single channel; a 2-D feature is read as
        ``(n_channels, length)``.

        Raises:
            ValueError: if the first sample containing the feature key has a
                feature that is neither 1-D nor 2-D, or if no sample contains
                the feature key.
        """
        for sample in self.dataset:
            if self.feature_keys[0] not in sample:
                continue

            shape = sample[self.feature_keys[0]].shape
            if len(shape) == 1:
                return 1, shape[0]
            if len(shape) == 2:
                return shape[0], shape[1]

            raise ValueError(
                f"Invalid shape for feature key {self.feature_keys[0]}: {sample[self.feature_keys[0]].shape}"
            )

        raise ValueError(
            f"Unable to infer input channels and length from dataset for feature key {self.feature_keys[0]}"
        )

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: batch dictionary. It must contain the model's single
                feature key (the signal tensor) and its single label key.
                If ``embed`` is present and truthy, the encoder embedding is
                also returned.

        Returns:
            A dictionary with keys ``loss``, ``y_prob``, ``y_true`` and
            ``logit``, plus ``embed`` when requested.
        """
        # concat the info within one batch (batch, channel, length)
        x = kwargs[self.feature_keys[0]].to(self.device)

        # final layer embedding (batch, embedding)
        emb = self.encoder(x).view(x.shape[0], -1)

        # (patient, label_size)
        logits = self.fc(emb)

        # obtain y_true, loss, y_prob
        y_true = kwargs[self.label_keys[0]].to(self.device)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)

        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = emb
        return results