Source code for pyhealth.models.transformers_model

from typing import Dict, List

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from ..datasets import SampleDataset
from .base_model import BaseModel


[docs]class TransformersModel(BaseModel): """Transformers class for Huggingface models.""" def __init__( self, dataset: SampleDataset, model_name: str, dropout: float = 0.1, ): super(TransformersModel, self).__init__( dataset=dataset, ) self.model_name = model_name self.model = AutoModel.from_pretrained( model_name, hidden_dropout_prob=dropout, attention_probs_dropout_prob=dropout, ) assert ( len(self.feature_keys) == 1 ), "Only one feature key is supported if Transformers is initialized" self.feature_key = self.feature_keys[0] assert ( len(self.label_keys) == 1 ), "Only one label key is supported if RNN is initialized" self.label_key = self.label_keys[0] self.tokenizer = AutoTokenizer.from_pretrained(model_name) output_size = self.get_output_size() hidden_dim = self.model.config.hidden_size self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(hidden_dim, output_size)
[docs] def forward(self, **kwargs) -> Dict[str, torch.Tensor]: """Forward propagation.""" # concat the info within one batch (batch, channel, length) x = kwargs[self.feature_key] # TODO: max_length should be a parameter x = self.tokenizer( x, return_tensors="pt", padding=True, truncation=True, max_length=256 ) x = x.to(self.device) # TODO: should not use pooler_output, but use the last hidden state embeddings = self.model(**x).pooler_output logits = self.fc(self.dropout(embeddings)) y_true = kwargs[self.label_key].to(self.device) loss = self.get_loss_function()(logits, y_true) y_prob = self.prepare_y_prob(logits) return { "loss": loss, "y_prob": y_prob, "y_true": y_true, }
if __name__ == "__main__": from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader base_dataset = MedicalTranscriptionsDataset( root="/srv/local/data/zw12/raw_data/MedicalTranscriptions" ) sample_dataset = base_dataset.set_task() train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) model = TransformersModel( dataset=sample_dataset, feature_keys=["transcription"], label_key="label", mode="multiclass", model_name="emilyalsentzer/Bio_ClinicalBERT", ) # data batch data_batch = next(iter(train_loader)) # try the model ret = model(**data_batch) print(ret) # try loss backward ret["loss"].backward()