Source code for pyhealth.models.transformer_deid
"""
PyHealth model for transformer-based clinical text de-identification.
Performs token-level NER to detect PHI (protected health information)
in clinical notes using a pre-trained transformer with a classification
head.
Paper: Johnson, Alistair E.W., et al. "Deidentification of free-text
medical records using pre-trained bidirectional transformers."
Proceedings of the ACM Conference on Health, Inference, and
Learning (CHIL), 2020.
Paper link:
https://doi.org/10.1145/3368555.3384455
Model structure (dropout + linear head) follows PyHealth's
TransformersModel (pyhealth/models/transformers_model.py), adapted
for token-level classification instead of sequence-level.
Subword alignment follows the standard HuggingFace token
classification pattern (see BertForTokenClassification).
Author:
Matt McKenna (mtm16@illinois.edu)
"""
import logging
from typing import Dict, List
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from ..datasets import SampleDataset
from .base_model import BaseModel
logger = logging.getLogger(__name__)

# The seven PHI categories tagged by the model. Every category gets a
# "B-" (first token of an entity) and an "I-" (continuation) tag, and
# "O" marks non-PHI tokens at index 0.
_PHI_CATEGORIES = (
    "AGE",
    "CONTACT",
    "DATE",
    "ID",
    "LOCATION",
    "NAME",
    "PROFESSION",
)

# Label -> index map: {"O": 0, "B-AGE": 1, "I-AGE": 2, "B-CONTACT": 3, ...}.
# Category k (0-based, in the order above) owns indices 2k + 1 ("B-")
# and 2k + 2 ("I-").
LABEL_VOCAB = {"O": 0}
LABEL_VOCAB.update(
    (f"{prefix}-{category}", 2 * k + offset)
    for k, category in enumerate(_PHI_CATEGORIES)
    for offset, prefix in ((1, "B"), (2, "I"))
)

# Label index that PyTorch's cross-entropy skips by convention.
IGNORE_INDEX = -100
def align_labels(
    word_ids: List[int | None],
    word_labels: List[int],
) -> List[int]:
    """Project word-level labels onto the subword tokens of one sequence.

    Subword tokenizers may break a single word into several pieces (for
    example "Smith" -> ["Sm", "##ith"]). Only the first piece of each word
    carries that word's label. Every later piece of the same word, and every
    special or padding token ([CLS], [SEP], [PAD], whose word id is None),
    is marked IGNORE_INDEX so the loss function skips it.

    Args:
        word_ids: Output of tokenizer.word_ids(): for each token, the index
            of the source word it came from, or None for special tokens.
        word_labels: Label index for each word in the original text.

    Returns:
        One label index per token. Special tokens and non-first subtokens
        are IGNORE_INDEX (-100).

    Raises:
        IndexError: If a word id is out of range for ``word_labels``.
    """
    ids = list(word_ids)
    # Pair each token's word id with the id of the token before it; the
    # first token is treated as following a special token (None).
    preceding_ids = [None, *ids[:-1]]
    # A token opens a new word when it belongs to a word and the token
    # before it belonged to something else (another word or a special token).
    return [
        word_labels[current]
        if current is not None and current != preceding
        else IGNORE_INDEX
        for current, preceding in zip(ids, preceding_ids)
    ]
class TransformerDeID(BaseModel):
    """Transformer-based token classifier for clinical text de-identification.

    Uses a pre-trained transformer encoder with a linear classification
    head to predict BIO-tagged PHI labels (see LABEL_VOCAB) for each token.

    Args:
        dataset: A SampleDataset from set_task(). It must expose exactly
            one input feature (space-joined words of a note) and exactly
            one label key (space-joined BIO labels, one per word).
        model_name: HuggingFace model name. Default "bert-base-uncased".
        max_length: Maximum token sequence length, special tokens
            included. Default 512.
        dropout: Dropout rate, used for the encoder's hidden and attention
            dropout and before the classification head. Default 0.1.

    Examples:
        >>> from pyhealth.datasets import PhysioNetDeIDDataset
        >>> from pyhealth.tasks import DeIDNERTask
        >>> from pyhealth.models import TransformerDeID
        >>> dataset = PhysioNetDeIDDataset(root="/path/to/data")
        >>> samples = dataset.set_task(DeIDNERTask())
        >>> model = TransformerDeID(dataset=samples)  # BERT
        >>> model = TransformerDeID(dataset=samples, model_name="roberta-base")
    """

    def __init__(
        self,
        dataset: SampleDataset,
        model_name: str = "bert-base-uncased",
        max_length: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__(dataset=dataset)
        assert len(self.feature_keys) == 1, (
            "TransformerDeID expects exactly one input feature (text)."
        )
        assert len(self.label_keys) == 1, (
            "TransformerDeID expects exactly one label key."
        )
        self.feature_key = self.feature_keys[0]
        self.label_key = self.label_keys[0]
        self.model_name = model_name
        self.max_length = max_length
        self.label_vocab = LABEL_VOCAB
        self.num_labels = len(LABEL_VOCAB)

        # add_prefix_space=True is required for RoBERTa when using
        # is_split_into_words=True in the forward pass.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, add_prefix_space=True
        )
        # The same dropout rate is applied inside the encoder.
        self.encoder = AutoModel.from_pretrained(
            model_name,
            hidden_dropout_prob=dropout,
            attention_probs_dropout_prob=dropout,
        )
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward pass.

        Args:
            **kwargs: Must contain ``self.feature_key`` (a list of
                space-joined word strings) and ``self.label_key`` (a list
                of space-joined BIO label strings, one label per word).

        Returns:
            Dict with keys:
                loss: Token-level cross-entropy over first subtokens.
                logit: Per-token logits; last dimension is num_labels.
                y_prob: Softmax of ``logit`` over the label dimension.
                y_true: Aligned label indices, IGNORE_INDEX at special,
                    padding and non-first-subtoken positions.

        Raises:
            ValueError: If the number of texts and label strings differ,
                or a sample's label count differs from its word count.
            KeyError: If a label string is not in LABEL_VOCAB.
        """
        texts: List[str] = kwargs[self.feature_key]
        label_strings: List[str] = kwargs[self.label_key]
        if len(texts) != len(label_strings):
            raise ValueError(
                f"Got {len(texts)} texts but {len(label_strings)} label "
                "strings; expected one label string per text."
            )

        # Tokenize with is_split_into_words=True so the tokenizer
        # knows word boundaries and word_ids() works correctly.
        words_batch = [t.split(" ") for t in texts]
        encoding = self.tokenizer(
            words_batch,
            is_split_into_words=True,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Convert word-level label strings to indices, then align them to
        # subword tokens. Special tokens, non-first subtokens and padding
        # get IGNORE_INDEX (-100), which cross-entropy skips.
        aligned_labels = []
        for i, (words, label_str) in enumerate(zip(words_batch, label_strings)):
            word_labels = [
                self.label_vocab[lbl] for lbl in label_str.split(" ")
            ]
            # A count mismatch would silently shift labels onto the wrong
            # words (or raise an opaque IndexError in align_labels).
            if len(word_labels) != len(words):
                raise ValueError(
                    f"Sample {i}: {len(words)} words but {len(word_labels)} "
                    "labels; each word needs exactly one BIO label."
                )
            word_ids = encoding.word_ids(batch_index=i)
            aligned_labels.append(align_labels(word_ids, word_labels))
        labels = torch.tensor(
            aligned_labels, dtype=torch.long, device=self.device
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        # Encoder -> dropout -> classifier (per-token logits).
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        logits = self.classifier(self.dropout(hidden_states))

        # Token-level cross-entropy. BaseModel.get_loss_function() assumes
        # one label per sample, so cross_entropy is called directly; the
        # flatten + ignore_index pattern follows HuggingFace's
        # BertForTokenClassification.forward().
        loss = nn.functional.cross_entropy(
            logits.view(-1, self.num_labels),
            labels.view(-1),
            ignore_index=IGNORE_INDEX,
        )

        # Per-token probabilities via softmax.
        y_prob = torch.softmax(logits, dim=-1)
        return {
            "loss": loss,
            "logit": logits,
            "y_prob": y_prob,
            "y_true": labels,
        }

    def deidentify(self, text: str, redact: str = "[REDACTED]") -> str:
        """Replace PHI in a clinical note with a redaction marker.

        The note is split on whitespace and every word predicted as a
        non-"O" label is replaced by ``redact``. Words are re-joined with
        single spaces, so the original whitespace layout is not preserved.
        Notes longer than ``max_length`` tokens are processed in consecutive
        windows so that every word is classified. A word that never receives
        a prediction (e.g. it tokenizes to nothing) is redacted as a
        fail-safe. The model's train/eval mode is restored before returning.

        Args:
            text: Raw clinical note as a string.
            redact: Replacement string for PHI words.

        Returns:
            The note with PHI words replaced ("" for an empty note).

        Example::

            >>> model.deidentify("Patient John Smith was seen")
            'Patient [REDACTED] [REDACTED] was seen'
        """
        words = text.split()
        if not words:
            return ""

        o_index = self.label_vocab["O"]
        # Predicted label index per word; None until the word is scored.
        word_preds: List[int | None] = [None] * len(words)

        # Inference must run without dropout, but the caller's mode (for
        # example, mid-training) is restored afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                start = 0
                while start < len(words):
                    consumed = self._classify_window(words, start, word_preds)
                    if consumed == 0:
                        # No remaining word produced any token (e.g.
                        # max_length leaves no room after special tokens),
                        # so no further progress is possible.
                        logger.warning(
                            "deidentify: %d trailing word(s) could not be "
                            "scored with max_length=%d; redacting them.",
                            len(words) - start,
                            self.max_length,
                        )
                        break
                    start += consumed
        finally:
            self.train(was_training)

        # Anything not predicted as "O", including words that were never
        # scored, is redacted so that a scoring failure cannot leak PHI.
        return " ".join(
            word if pred == o_index else redact
            for word, pred in zip(words, word_preds)
        )

    def _classify_window(
        self,
        words: List[str],
        start: int,
        word_preds: List[int | None],
    ) -> int:
        """Score ``words[start:]`` in one forward pass, up to max_length.

        Writes the predicted label index of every word whose first subtoken
        fits in the window into ``word_preds`` (at its absolute position)
        and returns how many words, counted from ``start``, were covered.
        """
        window = words[start:]
        # Use the model's own feature/label keys. The labels are dummies
        # ("O") because only the logits are needed.
        result = self(
            **{
                self.feature_key: [" ".join(window)],
                self.label_key: [" ".join(["O"] * len(window))],
            }
        )
        preds = result["logit"][0].argmax(dim=-1).tolist()

        # forward() re-splits the joined window on " ", which gives back
        # exactly `window` (str.split() words are non-empty and contain no
        # whitespace). Tokenizing `window` with the same settings therefore
        # reproduces forward()'s token -> word mapping.
        word_ids = self.tokenizer(
            [window],
            is_split_into_words=True,
            truncation=True,
            max_length=self.max_length,
        ).word_ids(batch_index=0)

        covered = 0
        prev_word_id = None
        for position, word_id in enumerate(word_ids):
            if word_id is not None and word_id != prev_word_id:
                # First subtoken of a word: its prediction is the word's.
                word_preds[start + word_id] = preds[position]
                covered = word_id + 1
            prev_word_id = word_id
        return covered