pyhealth.models.ContraWR
This module provides the standalone ResBlock2D layer and the complete ContraWR model.
- class pyhealth.models.ResBlock2D(in_channels, out_channels, stride=2, downsample=True, pooling=True)
Bases: Module
Convolutional Residual Block 2D
This block stacks two convolutional layers with batch normalization, max pooling, dropout, and a residual connection.
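- Parameters:
in_channels (int) – number of input channels.
out_channels (int) – number of output channels.
stride (int) – stride of the convolutional layers. Default is 2.
downsample (bool) – whether to use a downsampling residual connection. Default is True.
pooling (bool) – whether to use max pooling. Default is True.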
Example
>>> import torch
>>> from pyhealth.models import ResBlock2D
>>>
>>> model = ResBlock2D(6, 16, 1, True, True)
>>> input_ = torch.randn((16, 6, 28, 150))  # (batch, channel, height, width)
>>> output = model(input_)
>>> output.shape
torch.Size([16, 16, 14, 75])
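The block described above can be sketched in PyTorch as follows. This is a hypothetical re-creation, not PyHealth's source: the class name SketchResBlock2D, the 3x3 kernels with padding 1, the ELU activation, the strided conv + batch-norm shortcut, the 2x2 max pooling, and the dropout rate of 0.5 are all our assumptions. They were chosen so that the sketch reproduces the example's shapes: with stride 1 the convolutions keep 28 × 150, and pooling halves it to 14 × 75.

```python
import torch
import torch.nn as nn


class SketchResBlock2D(nn.Module):
    """Hypothetical 2D residual block mirroring the description above."""

    def __init__(self, in_channels, out_channels, stride=2, downsample=True, pooling=True):
        super().__init__()
        # Main path: two 3x3 convolutions, each followed by batch normalization.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Assumption: `downsample` switches on a strided conv + BN shortcut whose output
        # matches the main path; with downsample=False this sketch adds no residual.
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
            )
            if downsample
            else None
        )
        # Assumption: 2x2 max pooling (halves height and width) and dropout p=0.5.
        self.pool = nn.MaxPool2d(2, stride=2) if pooling else nn.Identity()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.act(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.shortcut is not None:
            out = out + self.shortcut(x)  # residual connection
        return self.dropout(self.pool(out))


block = SketchResBlock2D(6, 16, 1, True, True)
out = block(torch.randn((16, 6, 28, 150)))  # (batch, channel, height, width)
print(out.shape)  # torch.Size([16, 16, 14, 75])
```

With the default stride of 2, the same sketch shrinks the spatial size twice: once in the strided convolution and once in the pooling.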
- class pyhealth.models.ContraWR(dataset, feature_keys, label_key, mode, embedding_dim=128, hidden_dim=128, n_fft=128, **kwargs)
Bases: BaseModel
The ContraWR encoder used as a supervised model: an STFT front end followed by 2D CNN layers.
Paper: Yang, Chaoqi, Danica Xiao, M. Brandon Westover, and Jimeng Sun. “Self-supervised EEG representation learning for automatic sleep staging.” arXiv preprint arXiv:2110.15278 (2021).
Note
We use a single encoder to process all signal channels jointly.
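To make the STFT + 2D CNN pipeline and the shared multichannel encoder concrete, here is a minimal PyTorch sketch. It is a hypothetical stand-in, not PyHealth's ContraWR. The class name SketchSpectrogramCNN, the STFT settings (Hann window, hop_length = n_fft // 4, center=False, magnitude spectrum), and the plain conv/BN/ELU/max-pool blocks are our assumptions. They were picked so that a 7-channel, 3000-sample input yields the 65 × 90 spectrogram and the 32-dimensional embedding described under cal_encoder_stat().

```python
import torch
import torch.nn as nn


class SketchSpectrogramCNN(nn.Module):
    """Hypothetical STFT + 2D-CNN classifier; a sketch, not PyHealth's ContraWR."""

    def __init__(self, channels=(7, 8, 16, 32), n_classes=2, n_fft=128):
        super().__init__()
        self.n_fft = n_fft
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            # Assumed block: strided 3x3 conv, batch norm, ELU, then 2x2 max pooling.
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ELU(),
                nn.MaxPool2d(2, stride=2),
            ]
        self.cnn = nn.Sequential(*layers)
        # Assumes the CNN reduces the spectrogram to 1 x 1, so the embedding has channels[-1] entries.
        self.fc = nn.Linear(channels[-1], n_classes)

    def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, n_channels, length) -> (batch, n_channels, freq, frames)."""
        b, c, t = x.shape
        # Fold channels into the batch so every channel gets the same STFT.
        spec = torch.stft(
            x.reshape(b * c, t),
            n_fft=self.n_fft,
            hop_length=self.n_fft // 4,
            win_length=self.n_fft,
            window=torch.hann_window(self.n_fft, device=x.device),
            center=False,
            onesided=True,
            return_complex=True,
        )
        # Unfold: the signal channels become the image channels of one shared 2D CNN.
        return spec.abs().reshape(b, c, spec.shape[-2], spec.shape[-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        emb = self.cnn(self.spectrogram(x)).flatten(1)  # (batch, channels[-1])
        return self.fc(emb)  # logits, (batch, n_classes)


model = SketchSpectrogramCNN()
x = torch.randn(5, 7, 3000)  # (batch, channel, length), as in the cal_encoder_stat note
print(model.spectrogram(x).shape)  # torch.Size([5, 7, 65, 90])
print(model(x).shape)  # torch.Size([5, 2])
```

Folding the channels into the batch dimension applies one STFT configuration to every channel. Unfolding afterwards lets a single CNN see all channels at once as the input planes of a 2D image.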
- Parameters:
dataset (BaseSignalDataset) – the dataset to train the model. It is used to query certain information such as the set of all tokens.
feature_keys (List[str]) – list of keys in samples to use as features, e.g. [“signal”].
label_key (str) – key in samples to use as label (e.g., “label”).
mode (str) – one of “binary”, “multiclass”, or “multilabel”.
embedding_dim (int) – the embedding dimension. Default is 128.
hidden_dim (int) – the hidden dimension. Default is 128.
n_fft (int) – the number of FFT points used by the STFT. Default is 128.
**kwargs – other parameters for the model.
Examples
>>> from pyhealth.datasets import SampleSignalDataset
>>> samples = [
...     {
...         "record_id": "SC4001-0",
...         "patient_id": "SC4001",
...         "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-0.pkl",
...         "label": "W",
...     },
...     {
...         "record_id": "SC4001-1",
...         "patient_id": "SC4001",
...         "epoch_path": "/home/chaoqiy2/.cache/pyhealth/datasets/2f06a9232e54254cbcb4b62624294d71/SC4001-1.pkl",
...         "label": "R",
...     }
... ]
>>> dataset = SampleSignalDataset(samples=samples, dataset_name="test")
>>>
>>> from pyhealth.models import ContraWR
>>> model = ContraWR(
...     dataset=dataset,
...     feature_keys=["signal"],  # dataloader will load the signal from "epoch_path" and put it in "signal"
...     label_key="label",
...     mode="multiclass",
... )
>>>
>>> from pyhealth.datasets import get_dataloader
>>> train_loader = get_dataloader(dataset, batch_size=2, shuffle=True)
>>> data_batch = next(iter(train_loader))
>>>
>>> ret = model(**data_batch)
>>> print(ret)
{
    'loss': tensor(2.8425, device='cuda:0', grad_fn=<NllLossBackward0>),
    'y_prob': tensor([[0.9345, 0.0655],
                      [0.9482, 0.0518]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
    'y_true': tensor([1, 1], device='cuda:0'),
    'logit': tensor([[ 0.1472, -2.5104],
                     [ 2.1584, -0.7481]], device='cuda:0', grad_fn=<AddmmBackward0>)
}
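The four returned entries are related. Suppose, as the grad_fn names NllLossBackward0 and SoftmaxBackward0 suggest, that y_prob is the softmax of logit and loss is the batch-mean cross-entropy against y_true. Under that assumption, the printed y_prob and loss can be recomputed from logit alone. The plain-Python check below uses helper names (softmax, mean_cross_entropy) that are our own, not PyHealth functions.

```python
import math


def softmax(row):
    # Subtract the row max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def mean_cross_entropy(logits, targets):
    # Average negative log-probability of the true class over the batch.
    losses = [-math.log(softmax(row)[t]) for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)


logit = [[0.1472, -2.5104], [2.1584, -0.7481]]  # copied from the printed output
y_true = [1, 1]
y_prob = [softmax(row) for row in logit]
loss = mean_cross_entropy(logit, y_true)
print([[round(p, 4) for p in row] for row in y_prob])
print(loss)
```

The recomputed probabilities agree with the printed y_prob to four decimals, and the loss agrees with 2.8425 up to rounding. Both samples are assigned most of their probability to class 0 while the true class is 1, which is why the loss is large.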
- cal_encoder_stat()
Obtain the convolutional encoder's initialization statistics: the channel size of each CNN block and the size of the final embedding.
Note
The following example illustrates the encoder statistics. Input x:
torch.Size([5, 7, 3000])
- after the STFT
torch.Size([5, 7, 65, 90])
- after the first CNN block (out_channels = 8)
torch.Size([5, 8, 16, 22])
here: 8 * 16 * 22 = 2816 > 256, so we continue the convolution
- after the second CNN block (out_channels = 16)
torch.Size([5, 16, 4, 5])
here: 16 * 4 * 5 = 320 > 256, so we continue the convolution
- after the third CNN block (out_channels = 32)
torch.Size([5, 32, 1, 1])
here: 32 * 1 * 1 = 32 ≤ 256, so we stop the convolution
- output:
channels = [7, 8, 16, 32]
emb_size = 32 * 1 * 1 = 32
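The shape trace above can be reproduced with integer arithmetic alone. The sketch below rests on assumptions that are ours, not taken from PyHealth's source: a one-sided STFT with hop_length = n_fft // 4 and no center padding; each CNN block being a 3x3 convolution (stride 2, padding 1) followed by 2x2 max pooling; and each new channel count being the next power of two above the previous one. All of these were chosen because they reproduce every number in the example. The function names stft_shape, block_out, and encoder_stat are hypothetical.

```python
def stft_shape(length: int, n_fft: int = 128) -> tuple:
    # Assumption: one-sided STFT, hop_length = n_fft // 4, no center padding.
    hop = n_fft // 4
    n_freq = n_fft // 2 + 1
    n_frames = 1 + (length - n_fft) // hop
    return n_freq, n_frames


def block_out(size: int) -> int:
    # Assumption: 3x3 conv, stride 2, padding 1, then 2x2 max pooling with stride 2.
    after_conv = (size + 2 * 1 - 3) // 2 + 1
    return after_conv // 2


def encoder_stat(n_channels: int, length: int, n_fft: int = 128, limit: int = 256):
    h, w = stft_shape(length, n_fft)
    channels = [n_channels]
    # Keep adding blocks while the flattened feature map is larger than `limit`.
    while channels[-1] * h * w > limit:
        # Next power of two strictly above the current channel count (7 -> 8 -> 16 -> 32).
        channels.append(1 << channels[-1].bit_length())
        h, w = block_out(h), block_out(w)
    return channels, channels[-1] * h * w


print(stft_shape(3000))  # (65, 90)
channels, emb_size = encoder_stat(n_channels=7, length=3000)
print(channels, emb_size)  # [7, 8, 16, 32] 32
```

The stopping rule means longer recordings or more channels simply add more blocks until the flattened feature map is at most 256 values. The last channel count times the remaining spatial size then gives emb_size.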