pyhealth.models.ContraWR

The standalone ResBlock2D layer and the complete ContraWR model.

class pyhealth.models.ResBlock2D(in_channels, out_channels, stride=2, downsample=True, pooling=True)

Bases: Module

Convolutional Residual Block 2D

This block stacks two convolutional layers with batch normalization and adds a residual connection, max pooling, and dropout.

Parameters:
  • in_channels (int) – number of input channels.

  • out_channels (int) – number of output channels.

  • stride (int) – stride of the convolutional layers. Default is 2.

  • downsample (bool) – whether to use a downsampling residual connection. Default is True.

  • pooling (bool) – whether to use max pooling. Default is True.

Example

>>> import torch
>>> from pyhealth.models import ResBlock2D
>>>
>>> model = ResBlock2D(6, 16, 1, True, True)
>>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
>>> output = model(input_)
>>> output.shape
torch.Size([16, 16, 14, 75])
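The output shape in the example above follows from standard convolution and pooling size arithmetic. The sketch below reproduces it under an assumed layer configuration that this page does not state: a 3x3 first convolution with padding 1 and the given stride, a size-preserving second 3x3 convolution, and 2x2 max pooling with stride 2. The helpers `conv_out` and `resblock2d_out_hw` are hypothetical and not part of pyhealth.

```python
from typing import Tuple


def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Output length along one spatial axis of a conv/pool layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def resblock2d_out_hw(height: int, width: int, stride: int = 2,
                      pooling: bool = True) -> Tuple[int, int]:
    """Hypothetical (height, width) after one ResBlock2D.

    ASSUMED layout: 3x3 conv (padding 1, given stride), then a 3x3 conv
    (stride 1, padding 1) that keeps the size, then optional 2x2 max
    pooling with stride 2. A strided residual branch with the same
    first-conv settings yields the same shape, so it is not modeled.
    """
    h = conv_out(height, stride=stride)
    w = conv_out(width, stride=stride)
    if pooling:
        h = conv_out(h, kernel=2, stride=2, padding=0)
        w = conv_out(w, kernel=2, stride=2, padding=0)
    return h, w


# The example above: stride 1, pooling on, 28 x 150 input
print(resblock2d_out_hw(28, 150, stride=1))  # (14, 75)
```

Under these assumptions, stride 1 with pooling maps 28 x 150 to 14 x 75 as in the example, and stride 2 maps 65 x 90 to 16 x 22, which agrees with the first CNN shape listed under cal_encoder_stat.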
forward(x)

Forward propagation.

Parameters:

x – input tensor of shape (batch_size, in_channels, height, width).

Returns:

out – output tensor of shape (batch_size, out_channels, *, *), where the spatial sizes depend on stride and pooling.

Return type:

Tensor

training: bool
class pyhealth.models.ContraWR(dataset, feature_keys, label_key, mode, embedding_dim=128, hidden_dim=128, n_fft=128, **kwargs)

Bases: BaseModel

The encoder model of ContraWR, used here as a supervised model (STFT followed by 2D CNN layers).

Paper: Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun. “Self-supervised EEG representation learning for automatic sleep staging.” arXiv preprint arXiv:2110.15278 (2021).

Note

A single encoder handles all signal channels together: the per-channel spectrograms are stacked along the channel axis and fed jointly into the 2D CNN.

Parameters:
  • dataset (BaseSignalDataset) – the dataset to train the model. It is used to query information such as the label set and the shape of the input signals.

  • feature_keys (List[str]) – list of keys in samples to use as features, e.g. [“signal”].

  • label_key (str) – key in samples to use as label (e.g., “label”).

  • mode (str) – one of “binary”, “multiclass”, or “multilabel”.

  • embedding_dim (int) – the embedding dimension. Default is 128.

  • hidden_dim (int) – the hidden dimension. Default is 128.

  • n_fft (int) – the number of FFT points used by the STFT. Default is 128.

  • **kwargs – other parameters for the model.

Examples

>>> from pyhealth.datasets import SampleSignalDataset
>>> samples = [
...         {
...             "record_id": "SC4001-0",
...             "patient_id": "SC4001",
...             "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
...             "label": "W",
...         },
...         {
...             "record_id": "SC4001-1",
...             "patient_id": "SC4001",
...             "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
...             "label": "R",
...         }
...     ]
>>> dataset = SampleSignalDataset(samples=samples, dataset_name="test")
>>>
>>> from pyhealth.models import ContraWR
>>> model = ContraWR(
...         dataset=dataset,
...         feature_keys=["signal"], # dataloader will load the signal from "epoch_path" and put it in "signal"
...         label_key="label",
...         mode="multiclass",
...     )
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
    'loss': tensor(2.8425, device='cuda:0', grad_fn=<NllLossBackward0>),
    'y_prob': tensor([[0.9345, 0.0655],
                    [0.9482, 0.0518]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
    'y_true': tensor([1, 1], device='cuda:0'),
    'logit': tensor([[ 0.1472, -2.5104],
                    [2.1584, -0.7481]], device='cuda:0', grad_fn=<AddmmBackward0>)
}
>>>
cal_encoder_stat()

Obtain the convolutional encoder's initialization statistics: the number of channels of each CNN layer and the final embedding size.

Note

We illustrate the encoder statistics with an example. Input x:

  • torch.Size([5, 7, 3000])

After the STFT:
  • torch.Size([5, 7, 65, 90])

First CNN layer (out_channels = 8):
  • torch.Size([5, 8, 16, 22])

  • 8 * 16 * 22 = 2816 > 256, so we continue the convolution.

Second CNN layer (out_channels = 16):
  • torch.Size([5, 16, 4, 5])

  • 16 * 4 * 5 = 320 > 256, so we continue the convolution.

Third CNN layer (out_channels = 32):
  • torch.Size([5, 32, 1, 1])

  • 32 * 1 * 1 = 32 ≤ 256, so we stop the convolution.

Output:
  • channels = [7, 8, 16, 32]

  • emb_size = 32 * 1 * 1 = 32
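The stopping rule in the note above can be sketched as a loop over shape arithmetic. This is a hypothetical reconstruction, not the library's code: it assumes each block applies a stride-2 3x3 convolution (padding 1) followed by 2x2 max pooling, that the first block has 8 output channels and each later block doubles them, and that the loop stops once channels * freq * time is at most 256, the threshold used in the note. `block_out` and `encoder_stat` are illustrative names.

```python
from typing import List, Tuple


def block_out(size: int, stride: int = 2) -> int:
    """Hypothetical size after one block: 3x3 conv (padding 1) then 2x2 max pool."""
    size = (size + 2 - 3) // stride + 1  # strided 3x3 convolution, padding 1
    return (size - 2) // 2 + 1           # 2x2 max pooling, stride 2


def encoder_stat(n_channels: int, freq: int, time: int, first_out: int = 8,
                 threshold: int = 256) -> Tuple[List[int], int]:
    """Return (channels, emb_size) following the stopping rule in the note."""
    channels = [n_channels]
    out_ch = first_out
    # Add CNN blocks while the flattened feature map exceeds the threshold.
    while channels[-1] * freq * time > threshold:
        freq, time = block_out(freq), block_out(time)
        channels.append(out_ch)
        out_ch *= 2  # assumed doubling: 8, 16, 32, ...
    return channels, channels[-1] * freq * time


# STFT output from the note: 7 channels, 65 frequency bins, 90 time steps
print(encoder_stat(7, 65, 90))  # ([7, 8, 16, 32], 32)
```

With these assumptions the loop visits the same intermediate sizes as the note (16 x 22, then 4 x 5, then 1 x 1) and ends with channels = [7, 8, 16, 32] and emb_size = 32.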

torch_stft(X)

Short-time Fourier transform (STFT) computed with PyTorch.

Parameters:

X – input tensor of shape (batch, n_channels, length).

Returns:

signal – STFT output of shape (batch, n_channels, freq, time_steps).

Return type:

Tensor
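A naive, dependency-free sketch of a multichannel magnitude STFT can make the output layout concrete. It assumes a rectangular window, hop_length = n_fft // 4, no centering, and a one-sided spectrum; these settings are assumptions chosen because they reproduce the 65 x 90 shape from the cal_encoder_stat note for length 3000 and n_fft = 128, not confirmed details of torch_stft. The function names are hypothetical, and the batch axis is omitted.

```python
import cmath
import math
from typing import List, Optional, Tuple


def stft_shape(length: int, n_fft: int = 128,
               hop_length: Optional[int] = None) -> Tuple[int, int]:
    """(freq_bins, time_steps) of a one-sided STFT without centering."""
    hop = hop_length if hop_length is not None else n_fft // 4  # assumed hop
    freq_bins = n_fft // 2 + 1                 # one-sided spectrum
    time_steps = (length - n_fft) // hop + 1   # full frames only, no padding
    return freq_bins, time_steps


def stft_magnitude(x: List[float], n_fft: int = 128,
                   hop_length: Optional[int] = None) -> List[List[float]]:
    """Naive magnitude STFT of one channel with a rectangular window.

    Returns a freq_bins x time_steps nested list, i.e. the layout of the
    last two axes of torch_stft's (batch, n_channels, freq, time_steps) output.
    """
    hop = hop_length if hop_length is not None else n_fft // 4
    freq_bins, time_steps = stft_shape(len(x), n_fft, hop)
    out = [[0.0] * time_steps for _ in range(freq_bins)]
    for t in range(time_steps):
        frame = x[t * hop: t * hop + n_fft]
        for k in range(freq_bins):
            # k-th DFT coefficient of this frame
            coeff = sum(frame[n] * cmath.exp(-2j * math.pi * k * n / n_fft)
                        for n in range(n_fft))
            out[k][t] = abs(coeff)
    return out


def multichannel_stft(X: List[List[float]], n_fft: int = 128) -> List[List[List[float]]]:
    """(n_channels, length) -> (n_channels, freq, time_steps); channels stay separate."""
    return [stft_magnitude(channel, n_fft) for channel in X]


print(stft_shape(3000, 128))  # (65, 90), matching the cal_encoder_stat note
```

torch.stft computes the same complex coefficients with an FFT; the nested-list version only shows which samples enter each frame and frequency bin.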

forward(**kwargs)

Forward propagation.

Returns:

a dictionary with the keys loss, y_prob, y_true, and logit, as in the example above.

Return type:

Dict[str, Tensor]
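For mode="multiclass", the example output is consistent with y_prob being the row-wise softmax of logit and loss being the mean cross-entropy of the true classes (note the SoftmaxBackward0 and NllLossBackward0 grad functions). The stdlib-only sketch below recomputes those quantities from the printed logit and y_true values; `softmax` and `multiclass_outputs` are hypothetical helpers, not part of pyhealth.

```python
import math
from typing import List, Tuple


def softmax(row: List[float]) -> List[float]:
    """Row-wise softmax; subtracting the max keeps exp() numerically stable."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def multiclass_outputs(logit: List[List[float]],
                       y_true: List[int]) -> Tuple[float, List[List[float]]]:
    """Return (loss, y_prob): mean cross-entropy and softmax probabilities."""
    y_prob = [softmax(row) for row in logit]
    # Cross-entropy: mean negative log-probability assigned to the true class.
    loss = -sum(math.log(p[y]) for p, y in zip(y_prob, y_true)) / len(y_true)
    return loss, y_prob


# Values copied from the printed example output above
loss, y_prob = multiclass_outputs([[0.1472, -2.5104], [2.1584, -0.7481]], [1, 1])
print(round(loss, 4))
print([[round(p, 4) for p in row] for row in y_prob])
```

Up to the rounding of the printed logits, this recovers the example's y_prob values and a loss of about 2.8425.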

training: bool