"""Text embedding module for multimodal PyHealth pipelines.
This module provides a Transformer-based text encoder for clinical/medical text.
It is designed to integrate with PyHealth's multimodal fusion architecture.
Key Features:
- Uses pretrained medical language models (default: Bio_ClinicalBERT)
- Splits long texts into fixed-size chunks (default: 128 tokens, including [CLS]/[SEP]) for consistent encoding
- Projects embeddings to a shared dimension for multimodal concatenation
- Returns attention masks compatible with PyHealth's TransformerLayer
Dependencies:
- transformers >= 4.20.0 (pinned as ~=4.53.2 in pyproject.toml)
- torch
Example:
>>> from pyhealth.models.text_embedding import TextEmbedding
>>> encoder = TextEmbedding(embedding_dim=256)
>>> embeddings, mask = encoder(["Patient has fever.", "Follow-up."])
>>> embeddings.shape # [2, T, 256]
"""
from typing import List, Optional, Tuple, Union
import logging
import warnings
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
logger = logging.getLogger(__name__)
class TextEmbedding(nn.Module):
    """Encodes clinical text into embeddings for multimodal fusion.

    This module wraps a pretrained Hugging Face transformer (default:
    Bio_ClinicalBERT) and handles the complexities of encoding long clinical
    notes that exceed typical transformer context windows.

    Input/Output:
        Input: List[str] or str - raw text strings
        Output: (embeddings, mask) tuple or just embeddings tensor
            embeddings: torch.Tensor of shape [batch, seq_len, embedding_dim]
            mask: torch.BoolTensor of shape [batch, seq_len] where True=valid

    Chunking Behavior:
        Long texts are split into non-overlapping chunks. Each chunk:
            1. Contains at most (chunk_size - 2) content tokens
            2. Gets [CLS] prepended and [SEP] appended
            3. Is encoded independently by the transformer
            4. Embeddings are concatenated along the sequence dimension

        Example: A 300-token note with chunk_size=128 becomes 3 chunks:
            Chunk 1: [CLS] + tokens[0:126] + [SEP] = 128 tokens
            Chunk 2: [CLS] + tokens[126:252] + [SEP] = 128 tokens
            Chunk 3: [CLS] + tokens[252:300] + [SEP] = 50 tokens

        An empty string is encoded as a single chunk holding only
        [CLS] and [SEP].

    Pooling Modes:
        - "none" (default): All token embeddings [B, T, E'] where T = total
          tokens (the per-chunk [CLS]/[SEP] positions are included)
        - "cls": One [CLS] embedding per chunk [B, C, E'] where C = num chunks
        - "mean": Mean-pooled embedding per chunk [B, C, E']

    Design Decisions:
        return_mask parameter (backward compatibility):
            Earlier versions returned only embeddings. The mask is now
            returned by default (`return_mask=True`) because downstream
            attention layers need it; callers written against the old
            single-tensor return can pass `return_mask=False` to keep
            that behavior.

        max_chunks parameter (performance guardrail):
            Clinical notes can be extremely long (10,000+ tokens). Without a
            limit, this causes:
                - Memory explosion (O(chunks * chunk_size * hidden_dim))
                - Silent OOMs in production
                - Unexpectedly slow inference
            Default max_chunks=64 caps output at approximately 8K tokens.
            A UserWarning alerts when truncation occurs. Users can:
                - Increase max_chunks if memory permits
                - Pre-summarize long notes before encoding
                - Use chunk-level pooling instead of token-level

        freeze parameter:
            Medical transformers are expensive to fine-tune. For multimodal
            fusion where the text encoder is one component among several,
            freezing keeps the pretrained weights fixed, so no gradients
            are required for the encoder parameters. Freezing only sets
            ``requires_grad=False``; it does not change the module's
            train/eval mode.

    Mask Convention:
        Returns torch.bool tensor matching PyHealth's TransformerLayer:
            - True (or 1) = valid token position
            - False (or 0) = padding position
        TransformerLayer uses: scores.masked_fill(mask == 0, -1e9)
        So True positions are attended, False positions are masked out.

    Args:
        model_name: Hugging Face model identifier.
            Default: "emilyalsentzer/Bio_ClinicalBERT" (clinical domain)
            Alternatives: "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
        embedding_dim: Output embedding dimension (E'). Default: 128.
            Should match other modality encoders for concatenation.
        chunk_size: Tokens per chunk including [CLS]/[SEP]. Default: 128.
            Matches the "128 token text bits" from the multimodal study.
        max_chunks: Maximum chunks to keep. Default: 64.
            Set to None for unlimited (use with caution on long texts).
            Must be None or >= 1.
        pooling: How to aggregate token embeddings. Default: "none".
            "none" = all tokens, "cls" = [CLS] per chunk,
            "mean" = mean per chunk.
        freeze: If True, freeze transformer weights. Default: False.
            Recommended for multimodal fusion to prevent overfitting.
        return_mask: If True, return (embeddings, mask) tuple. Default: True.
            Set to False for backward compatibility with single-tensor
            callers.

    Example:
        Basic usage with default parameters:

        >>> encoder = TextEmbedding(embedding_dim=256)
        >>> texts = ["Patient presents with chest pain.", "Routine checkup."]
        >>> embeddings, mask = encoder(texts)
        >>> embeddings.shape  # [batch=2, T, 256]; T = longest sample in batch
        >>> mask.shape  # [batch=2, T]

        Using chunk-level pooling for efficiency:

        >>> encoder = TextEmbedding(pooling="cls", embedding_dim=128)
        >>> long_note = "..." * 1000  # Very long clinical note
        >>> emb, mask = encoder([long_note])
        >>> emb.shape  # [1, num_chunks, 128] instead of [1, thousands, 128]

        Backward-compatible single tensor return:

        >>> encoder = TextEmbedding(return_mask=False)
        >>> embeddings = encoder(["Test"])  # Just tensor, no tuple
    """

    # Closed set of supported values for the ``pooling`` argument.
    _POOLING_MODES = ("none", "cls", "mean")

    def __init__(
        self,
        model_name: str = "emilyalsentzer/Bio_ClinicalBERT",
        embedding_dim: int = 128,
        chunk_size: int = 128,
        max_chunks: Optional[int] = 64,
        pooling: str = "none",
        freeze: bool = False,
        return_mask: bool = True,
    ):
        """Initialize the text embedding module.

        Loads the pretrained tokenizer and transformer model from Hugging
        Face and creates a projection layer that maps the transformer hidden
        size to ``embedding_dim``. All arguments are validated before any
        model is loaded, so a bad configuration fails without a download.

        Raises:
            ValueError: If pooling is not one of "none", "cls", "mean".
            ValueError: If chunk_size < 4 (need room for [CLS], [SEP], content).
            ValueError: If max_chunks is neither None nor >= 1.
            ValueError: If the tokenizer does not define [CLS]/[SEP] ids,
                which the per-chunk framing requires.
        """
        super().__init__()

        # Validate first: cheap checks should fail before the expensive load.
        if pooling not in self._POOLING_MODES:
            raise ValueError(f"pooling must be 'none', 'cls', or 'mean', got {pooling}")
        if chunk_size < 4:
            raise ValueError(f"chunk_size must be >= 4, got {chunk_size}")
        if max_chunks is not None and max_chunks < 1:
            # 0 would leave no chunk to encode; negatives would silently
            # drop trailing chunks through slice semantics.
            raise ValueError(f"max_chunks must be None or >= 1, got {max_chunks}")

        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.max_chunks = max_chunks
        self.pooling = pooling
        self.return_mask = return_mask

        # Load tokenizer and model from Hugging Face
        # First use downloads ~420MB to HF_HOME or ~/.cache/huggingface/
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.cls_token_id is None or self.tokenizer.sep_token_id is None:
            # Every chunk is framed as [CLS] ... [SEP]; without these ids the
            # id lists would contain None and fail later with an opaque error.
            raise ValueError(
                f"Tokenizer for '{model_name}' must define cls_token_id and "
                f"sep_token_id to frame chunks"
            )
        self.transformer = AutoModel.from_pretrained(model_name)

        if freeze:
            for param in self.transformer.parameters():
                param.requires_grad = False

        # Projection: transformer hidden_size (e.g., 768) -> embedding_dim (e.g., 128)
        # This aligns text embeddings with other modalities in a shared E' space
        self.fc = nn.Linear(self.transformer.config.hidden_size, embedding_dim)

    def _chunk_and_encode(
        self, text: str, device: torch.device
    ) -> torch.Tensor:
        """Tokenize, chunk, and encode a single text string.

        This is the core encoding logic. For a single text:
            1. Tokenize without truncation to get all tokens
            2. Apply the max_chunks limit (warning on truncation)
            3. Split into chunks of (chunk_size - 2) tokens each and add
               [CLS] and [SEP] to each chunk
            4. Batch-encode all chunks through the transformer
            5. Apply the projection layer
            6. Return embeddings based on pooling mode

        Args:
            text: A single text string to encode. Can be empty.
            device: torch.device to place the input tensors on.

        Returns:
            torch.Tensor: Encoded embeddings.
                - pooling="none": shape [total_tokens, embedding_dim]
                - pooling="cls": shape [num_chunks, embedding_dim]
                - pooling="mean": shape [num_chunks, embedding_dim]

        Side Effects:
            Emits UserWarning if text exceeds max_chunks and gets truncated.
        """
        tokenizer = self.tokenizer

        # Step 1: Tokenize without truncation to get ALL tokens.
        # Special tokens are added manually per chunk; return_tensors=None
        # yields plain Python lists.
        input_ids = tokenizer(
            text,
            add_special_tokens=False,
            return_tensors=None,
        )["input_ids"]

        # Step 2: Apply max_chunks limit (performance guardrail).
        # Each chunk reserves 2 positions for [CLS] and [SEP]. The ids are
        # trimmed *before* chunk lists are built so over-long notes do not
        # allocate chunks that would be thrown away. We warn rather than
        # silently truncate so users can adjust.
        content_len = self.chunk_size - 2
        total_chunks = (len(input_ids) + content_len - 1) // content_len
        if self.max_chunks is not None and total_chunks > self.max_chunks:
            warnings.warn(
                f"Text produced {total_chunks} chunks, truncated to {self.max_chunks}. "
                f"Consider increasing max_chunks or summarizing input.",
                UserWarning,
            )
            input_ids = input_ids[: self.max_chunks * content_len]

        # Step 3: Split into chunks framed as [CLS] content... [SEP].
        cls_id = tokenizer.cls_token_id
        sep_id = tokenizer.sep_token_id
        chunks = [
            [cls_id, *input_ids[start : start + content_len], sep_id]
            for start in range(0, len(input_ids), content_len)
        ]
        if not chunks:
            # Empty text still yields one (content-free) chunk.
            chunks = [[cls_id, sep_id]]

        # Step 4: Pad chunks to uniform length for batched encoding.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0  # Fallback for tokenizers without explicit pad token
        lengths = [len(chunk) for chunk in chunks]
        max_len = max(lengths)
        padded = [
            chunk + [pad_id] * (max_len - n) for chunk, n in zip(chunks, lengths)
        ]
        attention_masks = [[1] * n + [0] * (max_len - n) for n in lengths]

        input_ids_tensor = torch.tensor(padded, dtype=torch.long, device=device)
        attention_mask_tensor = torch.tensor(
            attention_masks, dtype=torch.long, device=device
        )

        # Step 5: Encode all chunks through transformer in one forward pass.
        outputs = self.transformer(
            input_ids=input_ids_tensor,
            attention_mask=attention_mask_tensor,
        )
        hidden_states = outputs.last_hidden_state  # [num_chunks, max_len, hidden_size]

        # Step 6: Project from hidden_size to embedding_dim.
        projected = self.fc(hidden_states)  # [num_chunks, max_len, embedding_dim]

        # Step 7: Apply pooling strategy.
        if self.pooling == "cls":
            # [CLS] token (position 0) from each chunk.
            return projected[:, 0, :]  # [num_chunks, embedding_dim]

        if self.pooling == "mean":
            # Mean pool over non-padding positions.
            weights = attention_mask_tensor.unsqueeze(-1).float()  # [num_chunks, max_len, 1]
            summed = (projected * weights).sum(dim=1)  # [num_chunks, embedding_dim]
            counts = weights.sum(dim=1).clamp(min=1)  # [num_chunks, 1]
            return summed / counts  # [num_chunks, embedding_dim]

        # pooling == "none": keep every non-padding token. Boolean indexing
        # walks the mask in row-major order, i.e. chunk by chunk, which is
        # the same order as concatenating each chunk's valid prefix.
        return projected[attention_mask_tensor.bool()]  # [total_tokens, embedding_dim]

    @staticmethod
    def _pad_and_stack(
        sequences: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Right-pad per-sample embeddings with zeros and stack them.

        Args:
            sequences: Non-empty list of tensors, each of shape [T_i, E].

        Returns:
            Tuple of (embeddings, mask):
                embeddings: [B, max(T_i), E], same dtype/device as the inputs.
                mask: [B, max(T_i)] torch.bool, True at valid positions.
        """
        reference = sequences[0]
        lengths = [seq.shape[0] for seq in sequences]
        max_seq = max(lengths)

        # new_zeros inherits dtype and device from the embeddings, so the
        # padding never disagrees with half-precision or autocast outputs.
        embeddings = reference.new_zeros(len(sequences), max_seq, reference.shape[1])
        for row, (seq, n) in enumerate(zip(sequences, lengths)):
            embeddings[row, :n] = seq

        positions = torch.arange(max_seq, device=reference.device)
        lengths_tensor = torch.tensor(lengths, device=reference.device)
        mask = positions.unsqueeze(0) < lengths_tensor.unsqueeze(1)  # [B, max_seq]
        return embeddings, mask

    def forward(
        self, text: Union[List[str], str]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode a batch of texts into embeddings.

        Main entry point for encoding text. Handles batching, padding across
        samples, and optionally returning attention masks.

        Args:
            text: Single string or list of strings to encode.
                Each string can be any length; chunking handles overflow.
                An empty list yields empty tensors instead of raising.

        Returns:
            Depends on self.return_mask setting:

            If return_mask=True (default):
                Tuple[torch.Tensor, torch.Tensor]:
                    embeddings: Shape [B, T, E'] where:
                        B = batch size (number of input strings)
                        T = max sequence length across batch (tokens or chunks)
                        E' = embedding_dim
                    mask: Shape [B, T], dtype=torch.bool
                        True at positions with valid embeddings
                        False at padding positions

            If return_mask=False (backward compatibility):
                torch.Tensor: Just the embeddings tensor [B, T, E']

            For an empty batch the shapes are [0, 0, E'] and [0, 0].

        Note:
            The return_mask parameter exists for backward compatibility.
            New code should use the default return_mask=True to get masks
            needed for downstream attention layers.

        Example:
            >>> encoder = TextEmbedding(embedding_dim=128)
            >>> emb, mask = encoder(["Hello world", "A longer text here"])
            >>> emb.shape  # [2, T, 128] where T is max tokens
            >>> mask.shape  # [2, T]
            >>> mask[0].sum()  # Number of valid tokens in first sample
        """
        # Normalize a single string to a one-element batch; materialize any
        # other iterable so emptiness can be checked reliably.
        texts = [text] if isinstance(text, str) else list(text)

        # Inputs are placed on whatever device the transformer lives on.
        device = next(self.transformer.parameters()).device

        if not texts:
            # Nothing to encode: return well-formed empty tensors rather than
            # failing on max() over an empty sequence.
            embeddings = torch.zeros(
                0, 0, self.embedding_dim, dtype=self.fc.weight.dtype, device=device
            )
            mask = torch.zeros(0, 0, dtype=torch.bool, device=device)
        else:
            # Encode each text independently (chunking happens inside), then
            # pad all samples to the longest one in the batch.
            encoded = [self._chunk_and_encode(t, device) for t in texts]
            embeddings, mask = self._pad_and_stack(encoded)

        if self.return_mask:
            return embeddings, mask
        return embeddings