Source code for pyhealth.models.sparcnet

import math
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.datasets import BaseSignalDataset
from pyhealth.models import BaseModel


class DenseLayer(nn.Sequential):
    """Densely connected layer.

    Args:
        input_channels: number of input channels.
        growth_rate: rate of growth of channels in this layer.
        bn_size: multiplicative factor for the bottleneck layer
            (does not affect the output size).
        drop_rate: dropout rate.
        conv_bias: whether to use bias in convolutional layers.
        batch_norm: whether to use batch normalization.

    Example:
        >>> x = torch.randn(128, 5, 1000)
        >>> batch, channels, length = x.shape
        >>> model = DenseLayer(channels, 5, 2)
        >>> y = model(x)
        >>> y.shape
        torch.Size([128, 10, 1000])
    """

    def __init__(
        self,
        input_channels,
        growth_rate,
        bn_size,
        drop_rate=0.5,
        conv_bias=True,
        batch_norm=True,
    ):
        super(DenseLayer, self).__init__()
        if batch_norm:
            self.add_module("norm1", nn.BatchNorm1d(input_channels))
        self.add_module("elu1", nn.ELU())
        # 1x1 bottleneck convolution
        self.add_module(
            "conv1",
            nn.Conv1d(
                input_channels,
                bn_size * growth_rate,
                kernel_size=1,
                stride=1,
                bias=conv_bias,
            ),
        )
        if batch_norm:
            self.add_module("norm2", nn.BatchNorm1d(bn_size * growth_rate))
        self.add_module("elu2", nn.ELU())
        # 3x3 convolution producing growth_rate new feature maps
        self.add_module(
            "conv2",
            nn.Conv1d(
                bn_size * growth_rate,
                growth_rate,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=conv_bias,
            ),
        )
        self.drop_rate = drop_rate
    def forward(self, x):
        new_features = super(DenseLayer, self).forward(x)
        new_features = F.dropout(
            new_features, p=self.drop_rate, training=self.training
        )
        return torch.cat([x, new_features], 1)
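# Illustrative note (not part of the original module): because forward()
# concatenates the input with the `growth_rate` new feature maps, a DenseLayer
# always emits input_channels + growth_rate channels and preserves the length.
# A minimal check, using the toy shapes from the docstrings in this file:
#
#     x = torch.randn(128, 5, 1000)                 # (batch, channels, length)
#     layer = DenseLayer(5, growth_rate=5, bn_size=2)
#     assert layer(x).shape == (128, 5 + 5, 1000)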
class DenseBlock(nn.Sequential):
    """Densely connected block.

    Args:
        num_layers: number of layers in this block.
        input_channels: number of input channels.
        growth_rate: rate of growth of channels in this layer.
        bn_size: multiplicative factor for the bottleneck layer
            (does not affect the output size).
        drop_rate: dropout rate.
        conv_bias: whether to use bias in convolutional layers.
        batch_norm: whether to use batch normalization.

    Example:
        >>> x = torch.randn(128, 5, 1000)
        >>> batch, channels, length = x.shape
        >>> model = DenseBlock(3, channels, 5, 2)
        >>> y = model(x)
        >>> y.shape
        torch.Size([128, 20, 1000])
    """

    def __init__(
        self,
        num_layers,
        input_channels,
        growth_rate,
        bn_size,
        drop_rate=0.5,
        conv_bias=True,
        batch_norm=True,
    ):
        super(DenseBlock, self).__init__()
        for idx_layer in range(num_layers):
            # each layer sees the channels produced by all previous layers
            layer = DenseLayer(
                input_channels + idx_layer * growth_rate,
                growth_rate,
                bn_size,
                drop_rate,
                conv_bias,
                batch_norm,
            )
            self.add_module("denselayer%d" % (idx_layer + 1), layer)
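# Illustrative note (not part of the original module): stacking num_layers
# DenseLayers grows the channel count linearly, so a DenseBlock emits
# input_channels + num_layers * growth_rate channels, e.g.:
#
#     x = torch.randn(128, 5, 1000)
#     block = DenseBlock(num_layers=3, input_channels=5, growth_rate=5, bn_size=2)
#     assert block(x).shape == (128, 5 + 3 * 5, 1000)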
class TransitionLayer(nn.Sequential):
    """Pooling transition layer.

    Args:
        input_channels: number of input channels.
        output_channels: number of output channels.
        conv_bias: whether to use bias in convolutional layers.
        batch_norm: whether to use batch normalization.

    Example:
        >>> x = torch.randn(128, 5, 1000)
        >>> model = TransitionLayer(5, 18)
        >>> y = model(x)
        >>> y.shape
        torch.Size([128, 18, 500])
    """

    def __init__(
        self, input_channels, output_channels, conv_bias=True, batch_norm=True
    ):
        super(TransitionLayer, self).__init__()
        if batch_norm:
            self.add_module("norm", nn.BatchNorm1d(input_channels))
        self.add_module("elu", nn.ELU())
        # 1x1 convolution to change the channel count
        self.add_module(
            "conv",
            nn.Conv1d(
                input_channels,
                output_channels,
                kernel_size=1,
                stride=1,
                bias=conv_bias,
            ),
        )
        # average pooling halves the temporal length
        self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
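# Illustrative note (not part of the original module): the 1x1 convolution
# sets the channel count and AvgPool1d(kernel_size=2, stride=2) halves the
# temporal length, e.g.:
#
#     x = torch.randn(128, 5, 1000)
#     trans = TransitionLayer(input_channels=5, output_channels=18)
#     assert trans(x).shape == (128, 18, 1000 // 2)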
class SparcNet(BaseModel):
    """The SparcNet model for sleep staging.

    Paper: Jin Jing, et al. Development of Expert-level Classification of
    Seizures and Rhythmic and Periodic Patterns During EEG Interpretation.
    Neurology 2023.

    Note:
        We use one encoder to handle multiple channels together.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        feature_keys: list of keys in samples to use as features,
            e.g. ["conditions", "procedures"].
        label_key: key in samples to use as label (e.g., "drugs").
        mode: one of "binary", "multiclass", or "multilabel".
        embedding_dim: (not used now) the embedding dimension. Default is 128.
        hidden_dim: (not used now) the hidden dimension. Default is 128.
        block_layers: the number of layers in each dense block. Default is 4.
        growth_rate: the growth rate of each dense layer. Default is 16.
        bn_size: the bottleneck size of each dense layer. Default is 16.
        drop_rate: the dropout rate of each dense layer. Default is 0.5.
        conv_bias: whether to use bias in convolutional layers. Default is True.
        batch_norm: whether to use batch normalization. Default is True.
        **kwargs: other parameters for the SparcNet layer.

    Examples:
        >>> from pyhealth.datasets import SampleSignalDataset
        >>> samples = [
        ...     {
        ...         "record_id": "SC4001-0",
        ...         "patient_id": "SC4001",
        ...         "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
        ...         "label": "W",
        ...     },
        ...     {
        ...         "record_id": "SC4001-1",
        ...         "patient_id": "SC4001",
        ...         "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
        ...         "label": "R",
        ...     }
        ... ]
        >>> dataset = SampleSignalDataset(samples=samples, dataset_name="test")
        >>>
        >>> from pyhealth.models import SparcNet
        >>> model = SparcNet(
        ...     dataset=dataset,
        ...     feature_keys=["signal"],  # dataloader will load the signal from "epoch_path" and put it in "signal"
        ...     label_key="label",
        ...     mode="multiclass",
        ... )
        >>>
        >>> from pyhealth.datasets import get_dataloader
        >>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
        >>> data_batch = next(iter(train_loader))
        >>>
        >>> ret = model(**data_batch)
        >>> print(ret)
        {
            'loss': tensor(0.6530, device='cuda:0', grad_fn=<NllLossBackward0>),
            'y_prob': tensor([[0.4459, 0.5541],
                              [0.5111, 0.4889]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
            'y_true': tensor([1, 1], device='cuda:0'),
            'logit': tensor([[-0.2750, -0.0577],
                             [-0.1319, -0.1763]], device='cuda:0', grad_fn=<AddmmBackward0>)
        }
    """

    def __init__(
        self,
        dataset: BaseSignalDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        block_layers=4,
        growth_rate=16,
        bn_size=16,
        drop_rate=0.5,
        conv_bias=True,
        batch_norm=True,
        **kwargs,
    ):
        super(SparcNet, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )

        """ common """
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        # TODO: Use more tokens for <gap> for different lengths once the input has such information
        self.label_tokenizer = self.get_label_tokenizer()

        """ input statistics """
        print(f"\n=== Input data statistics ===")
        # obtain input signal size (batch, n_channels, length)
        signal_info = self.dataset.input_info["signal"]
        in_channels, length = signal_info["n_channels"], signal_info["length"]
        print(f"n_channels: {in_channels}")
        print(f"length: {length}")

        """ define sparcnet """
        # initial convolution; maps to the next power of two above in_channels
        out_channels = 2 ** (math.floor(np.log2(in_channels)) + 1)
        first_conv = OrderedDict(
            [
                (
                    "conv0",
                    nn.Conv1d(
                        in_channels,
                        out_channels,
                        kernel_size=7,
                        stride=2,
                        padding=3,
                        bias=conv_bias,
                    ),
                )
            ]
        )
        first_conv["norm0"] = nn.BatchNorm1d(out_channels)
        first_conv["elu0"] = nn.ELU()
        first_conv["pool0"] = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.encoder = nn.Sequential(first_conv)

        n_channels = out_channels

        # add dense blocks
        for n_layer in np.arange(math.floor(np.log2(length // 4))):
            block = DenseBlock(
                num_layers=block_layers,
                input_channels=n_channels,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
            )
            self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
            # update number of channels after each dense block
            n_channels = n_channels + block_layers * growth_rate

            trans = TransitionLayer(
                input_channels=n_channels,
                output_channels=n_channels // 2,
                conv_bias=conv_bias,
                batch_norm=batch_norm,
            )
            self.encoder.add_module("transition%d" % (n_layer + 1), trans)
            # update number of channels after each transition layer
            n_channels = n_channels // 2

        """ prediction layer """
        output_size = self.get_output_size(self.label_tokenizer)
        self.fc = nn.Linear(n_channels, output_size)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
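    # Illustrative shape walkthrough (assumes a hypothetical input with
    # in_channels=5 and length=1000, matching the layer docstrings above;
    # the real values come from dataset.input_info["signal"]):
    #
    #   * conv0 (stride 2) then pool0 (stride 2) reduce length 1000 -> 500 -> 250,
    #     and out_channels = 2 ** (floor(log2(5)) + 1) = 8.
    #   * the loop appends floor(log2(1000 // 4)) = 7 dense-block/transition
    #     pairs; each block adds block_layers * growth_rate = 64 channels, and
    #     each transition halves channels and length (250 -> 125 -> ... -> 1).
    #   * with the temporal dimension reduced to 1, the flattened encoder
    #     output has exactly n_channels features, matching self.fc.in_features.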
    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation."""
        # concat the info within one batch (batch, channel, length)
        x = torch.tensor(
            np.array(kwargs[self.feature_keys[0]]), device=self.device
        ).float()
        # final layer embedding (batch, embedding)
        emb = self.encoder(x).view(x.shape[0], -1)
        # (patient, label_size)
        logits = self.fc(emb)
        # obtain y_true, loss, y_prob
        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        results = {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
            "logit": logits,
        }
        if kwargs.get("embed", False):
            results["embed"] = emb
        return results
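# Usage note (illustrative, not from the original file): forward() also
# returns the flattened encoder output when the batch kwargs include
# embed=True, e.g.:
#
#     ret = model(**data_batch, embed=True)
#     ret["embed"]  # (batch, n_channels) representation before self.fc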
if __name__ == "__main__":
    """ For dense layer """
    # x = torch.randn(128, 5, 1000)
    # batch, channels, length = x.shape
    # model = DenseLayer(channels, 5, 2)
    # y = model(x)
    # print(y.shape)

    """ For dense block """
    # x = torch.randn(128, 5, 1000)
    # batch, channels, length = x.shape
    # model = DenseBlock(3, channels, 5, 2)
    # y = model(x)
    # print(y.shape)

    """ For transition layer """
    # x = torch.randn(128, 5, 1000)
    # batch, channels, length = x.shape
    # model = TransitionLayer(channels, 18)
    # y = model(x)
    # print(y.shape)

    """ For sparcnet """
    from pyhealth.datasets import SampleSignalDataset, get_dataloader

    samples = [
        {
            "record_id": "SC4001-0",
            "patient_id": "SC4001",
            "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
            "label": "W",
        },
        {
            "record_id": "SC4001-1",
            "patient_id": "SC4001",
            "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
            "label": "R",
        },
    ]

    # dataset
    dataset = SampleSignalDataset(samples=samples, dataset_name="test")

    # data loader
    train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)

    # model
    model = SparcNet(
        dataset=dataset,
        feature_keys=["signal"],
        label_key="label",
        mode="multiclass",
    ).to("cuda:0")

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()