Source code for pyhealth.models.embedding

from __future__ import annotations

from typing import Dict, Any, Optional, Union
import os

import torch
import torch.nn as nn

from ..datasets import SampleDataset
from ..processors import (
    MultiHotProcessor,
    NestedFloatsProcessor,
    NestedSequenceProcessor,
    SequenceProcessor,
    StageNetProcessor,
    StageNetTensorProcessor,
    TensorProcessor,
    TimeseriesProcessor,
    DeepNestedSequenceProcessor,
    DeepNestedFloatsProcessor,
)
from .base_model import BaseModel


def _iter_text_vectors(
    path: str,
    embedding_dim: int,
    wanted_tokens: set[str],
    encoding: str = "utf-8",
) -> Dict[str, torch.Tensor]:
    """Read a whitespace-separated word-vector text file, keeping wanted tokens.

    Each usable line has the form ``<token> <v1> <v2> ... <v_embedding_dim>``
    (GloVe-style). The file is streamed line by line, so only the vectors of
    tokens listed in ``wanted_tokens`` are ever held in memory.

    Lines are silently skipped when they are blank, carry fewer than
    ``embedding_dim + 1`` fields, name a token outside ``wanted_tokens``, or
    contain a value that cannot be parsed as a float. Fields beyond the first
    ``embedding_dim`` values are ignored. If a token occurs more than once,
    the last usable occurrence wins.

    Args:
        path: Location of the text vector file.
        embedding_dim: Number of float values to read per token.
        wanted_tokens: Tokens whose vectors should be retained.
        encoding: Text encoding used to open the file.

    Returns:
        Mapping from token to a 1-D float tensor of length ``embedding_dim``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """

    if not os.path.exists(path):
        raise FileNotFoundError(f"pretrained embedding file not found: {path}")

    # A usable line holds the token itself plus `embedding_dim` numbers.
    required_fields = embedding_dim + 1

    found: Dict[str, torch.Tensor] = {}
    with open(path, "r", encoding=encoding) as handle:
        for raw_line in handle:
            # str.split() with no separator already ignores surrounding
            # whitespace, so a blank line simply yields an empty list.
            fields = raw_line.split()
            if not fields or len(fields) < required_fields:
                continue

            name = fields[0]
            if name not in wanted_tokens:
                continue

            try:
                values = torch.tensor(
                    list(map(float, fields[1:required_fields])),
                    dtype=torch.float,
                )
            except ValueError:
                # Non-numeric payload (e.g. a token that itself contained a
                # space): ignore the line rather than abort the whole load.
                continue

            found[name] = values
    return found


def init_embedding_with_pretrained(
    embedding: nn.Embedding,
    code_vocab: Dict[Any, int],
    pretrained_path: str,
    embedding_dim: int,
    pad_token: str = "<pad>",
    unk_token: str = "<unk>",
    normalize: bool = False,
    freeze: bool = False,
) -> int:
    """Initializes an nn.Embedding from a pretrained text-vector file.

    Vocabulary keys are matched against the file by their ``str()`` form.
    Tokens not found in the pretrained file are left as the module's existing
    initialization. The row of ``pad_token`` (when it is in ``code_vocab``)
    and the row at ``embedding.padding_idx`` (when set) are always zeroed,
    even if the file provides a vector for them.

    Args:
        embedding: Embedding module whose weight rows are overwritten in place.
        code_vocab: Mapping from token to row index in ``embedding.weight``.
        pretrained_path: Path of the text vector file.
        embedding_dim: Number of values read per token from the file. It is
            expected to equal the row width of ``embedding.weight``; this is
            not validated here.
        pad_token: Vocabulary key whose row is forced to zero.
        unk_token: Currently unused; kept for interface compatibility.
        normalize: If True, each loaded vector is scaled to unit L2 norm
            (with a 1e-12 epsilon in the denominator) before being copied.
        freeze: If True, gradients are disabled on ``embedding.weight``.

    Returns:
        int: number of tokens successfully initialized from the file. Rows
        that are forced to zero (pad rows) are not counted, because their
        file vector does not survive.

    Raises:
        FileNotFoundError: If ``pretrained_path`` does not exist.
    """

    # Rows that end up zeroed no matter what the file contains. Copying a
    # vector into them first would be wasted work and would overstate the
    # returned count, so they are excluded from loading.
    zeroed_rows = set()
    if pad_token in code_vocab:
        zeroed_rows.add(code_vocab[pad_token])
    if embedding.padding_idx is not None:
        zeroed_rows.add(embedding.padding_idx)

    # The file is matched on stringified tokens because vocab keys may be
    # non-string objects.
    wanted = {str(tok) for tok in code_vocab}
    vectors = _iter_text_vectors(pretrained_path, embedding_dim, wanted)

    loaded = 0
    with torch.no_grad():
        for tok, idx in code_vocab.items():
            if idx in zeroed_rows:
                continue
            vec = vectors.get(str(tok))
            if vec is None:
                continue
            if normalize:
                vec = vec / (vec.norm(p=2) + 1e-12)
            embedding.weight[idx].copy_(vec)
            loaded += 1

        # Keep the pad row(s) at exactly zero.
        for row in zeroed_rows:
            embedding.weight[row].zero_()

    if freeze:
        embedding.weight.requires_grad_(False)

    return loaded


class EmbeddingModel(BaseModel):
    """
    EmbeddingModel is responsible for creating embedding layers for different
    types of input data.

    This model automatically creates appropriate embedding transformations
    based on the processor type:

    - SequenceProcessor / StageNetProcessor: nn.Embedding
        Input: (batch, seq_len)
        Output: (batch, seq_len, embedding_dim)

    - NestedSequenceProcessor: nn.Embedding
        Input: (batch, num_visits, max_codes_per_visit)
        Output: (batch, num_visits, max_codes_per_visit, embedding_dim)

    - DeepNestedSequenceProcessor: nn.Embedding
        Input: (batch, num_groups, num_visits, max_codes_per_visit)
        Output: (batch, num_groups, num_visits, max_codes_per_visit, embedding_dim)

    - TimeseriesProcessor / NestedFloatsProcessor / DeepNestedFloatsProcessor /
      StageNetTensorProcessor: nn.Linear over the last dimension
        Input: (..., size)
        Output: (..., embedding_dim)

    - TensorProcessor: nn.Linear (size inferred from first sample)

    - MultiHotProcessor: nn.Linear over multi-hot vector

    - Token-based processors (any processor whose ``is_token()`` returns a
      truthy value): a Hugging Face ``AutoModel`` loaded from
      ``processor.tokenizer_model``, followed by an nn.Linear projection
      (stored under ``"<field>_proj"``) when the model's hidden size differs
      from ``embedding_dim``.

    Fields whose processor matches none of the above get no layer; a warning
    is printed and ``forward`` passes their tensors through unchanged.

    Args:
        dataset: Sample dataset; its ``input_processors`` mapping decides
            which layer is built for each field.
        embedding_dim: Output width of every created layer.
        pretrained_emb_path: Optional pretrained text-vector file used to
            initialize the nn.Embedding layers. Either a single path applied
            to every categorical field, or a mapping from field name to path
            (fields missing from the mapping keep their initialization).
        freeze_pretrained: Passed as ``freeze`` to
            ``init_embedding_with_pretrained``.
        normalize_pretrained: Passed as ``normalize`` to
            ``init_embedding_with_pretrained``.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        pretrained_emb_path: Optional[Union[str, Dict[str, str]]] = None,
        freeze_pretrained: bool = False,
        normalize_pretrained: bool = False,
    ):
        super().__init__(dataset)
        self.embedding_dim = embedding_dim
        self.embedding_layers = nn.ModuleDict()

        # NOTE(review): `self.dataset` is presumably set by
        # BaseModel.__init__ — confirm against base_model.
        for field_name, processor in self.dataset.input_processors.items():
            # Categorical sequences -> nn.Embedding (adds an embedding dim).
            if isinstance(
                processor,
                (
                    SequenceProcessor,
                    StageNetProcessor,
                    NestedSequenceProcessor,
                    DeepNestedSequenceProcessor,
                ),
            ):
                # Flat and nested processors are configured identically, so a
                # single construction covers all of them.
                self.embedding_layers[field_name] = nn.Embedding(
                    num_embeddings=len(processor.code_vocab),
                    embedding_dim=embedding_dim,
                    padding_idx=0,
                )

                # Optional pretrained initialization (e.g., GloVe).
                if pretrained_emb_path is not None:
                    if isinstance(pretrained_emb_path, str):
                        path = pretrained_emb_path
                    else:
                        path = pretrained_emb_path.get(field_name)
                    # A missing or empty path leaves the layer as created.
                    if path:
                        init_embedding_with_pretrained(
                            self.embedding_layers[field_name],
                            processor.code_vocab,
                            path,
                            embedding_dim=embedding_dim,
                            normalize=normalize_pretrained,
                            freeze=freeze_pretrained,
                        )

            # Numeric features (including deep nested floats) -> nn.Linear
            # over the last dimension.
            elif isinstance(
                processor,
                (
                    TimeseriesProcessor,
                    StageNetTensorProcessor,
                    NestedFloatsProcessor,
                    DeepNestedFloatsProcessor,
                ),
            ):
                # Assuming processor.size() returns the last-dim size
                in_features = processor.size()
                self.embedding_layers[field_name] = nn.Linear(
                    in_features=in_features, out_features=embedding_dim
                )

            elif isinstance(processor, TensorProcessor):
                # Infer the input width from the first sample that carries
                # this field.
                sample_tensor = None
                for sample in dataset:
                    if field_name in sample:
                        sample_tensor = processor.process(sample[field_name])
                        break
                # If no sample provides the field, no layer is created and
                # forward() passes the field through unchanged.
                if sample_tensor is not None:
                    # A 0-dim (scalar) tensor is treated as one feature.
                    input_size = (
                        sample_tensor.shape[-1] if sample_tensor.dim() > 0 else 1
                    )
                    self.embedding_layers[field_name] = nn.Linear(
                        in_features=input_size, out_features=embedding_dim
                    )

            elif isinstance(processor, MultiHotProcessor):
                num_categories = processor.size()
                self.embedding_layers[field_name] = nn.Linear(
                    in_features=num_categories, out_features=embedding_dim
                )

            # Smart Processor (Token-based) -> Transformers
            elif hasattr(processor, "is_token") and processor.is_token():
                # Imported lazily so `transformers` stays an optional
                # dependency.
                try:
                    from transformers import AutoModel
                except ImportError as err:
                    raise ImportError(
                        "Please install `transformers` to use token-based processors."
                    ) from err

                # Load the model
                encoder = AutoModel.from_pretrained(processor.tokenizer_model)
                self.embedding_layers[field_name] = encoder

                # Project to embedding_dim only when the widths differ.
                hidden_size = encoder.config.hidden_size
                if hidden_size != self.embedding_dim:
                    self.embedding_layers[f"{field_name}_proj"] = nn.Linear(
                        hidden_size,
                        self.embedding_dim,
                    )

            else:
                print(
                    "Warning: No embedding created for field due to lack of compatible processor:",
                    field_name,
                )

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        masks: Optional[Dict[str, torch.Tensor]] = None,
        output_mask: bool = False,
    ) -> (
        Dict[str, torch.Tensor]
        | tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
    ):
        """Embed every input field with the layer created for it.

        Fields without a layer are copied to the output unchanged (and get no
        entry in the returned masks). Transformer-backed fields accept either
        2-D ``(batch, seq_len)`` token ids, producing one vector per token, or
        3-D ``(batch, num_notes, seq_len)`` token ids, which are flattened to
        ``(batch * num_notes, seq_len)``, encoded, pooled by taking sequence
        position 0 (the CLS position for BERT-style models) and reshaped to
        ``(batch, num_notes, dim)``.

        Args:
            inputs: Mapping from field name to input tensor.
            masks: Optional mapping from field name to mask tensor. For
                transformer-backed fields it is used as ``attention_mask``;
                when ``output_mask`` is True it is also returned as that
                field's mask.
            output_mask: If True, additionally return a mask per embedded
                field.

        Returns:
            The mapping of embedded tensors, or ``(embedded, masks)`` when
            ``output_mask`` is True. A mask not supplied by the caller is
            computed as ``input != pad_idx`` on the input in its original
            shape, where ``pad_idx`` is the processor's ``"<pad>"`` index when
            it has a ``code_vocab`` and 0 otherwise.
        """
        embedded: Dict[str, torch.Tensor] = {}
        out_masks: Dict[str, torch.Tensor] = {} if output_mask else None

        for field_name, tensor in inputs.items():
            processor = self.dataset.input_processors.get(field_name, None)

            if field_name not in self.embedding_layers:
                # No embedding layer -> passthrough
                embedded[field_name] = tensor
                continue

            layer = self.embedding_layers[field_name]

            # Transformer models are detected by duck typing (they expose a
            # `config` attribute) so `transformers` need not be imported here.
            if hasattr(layer, "config") and hasattr(layer, "forward"):
                # NOTE(review): `self.device` is presumably provided by
                # BaseModel — confirm against base_model.
                tensor = tensor.to(self.device).long()  # token ids must be long
                # Kept in the caller's shape so that a fallback mask lines up
                # with the (batch, ...) layout of the output instead of the
                # flattened (batch * num_notes, seq_len) layout used below.
                mask_source = tensor

                mask = None
                if masks is not None and field_name in masks:
                    mask = masks[field_name].to(self.device)

                # Handle 3D input (Batch, Num_Notes, Seq_Len)
                is_3d = tensor.dim() == 3
                if is_3d:
                    b, n, l = tensor.shape
                    tensor = tensor.view(b * n, l)
                    if mask is not None:
                        mask = mask.view(b * n, l)

                # Forward pass through transformer
                output = layer(input_ids=tensor, attention_mask=mask)
                x = output.last_hidden_state  # (Batch, Seq, Hidden)

                proj_key = f"{field_name}_proj"
                if is_3d:
                    # One vector per note: pool the sequence dim by taking
                    # position 0 -> (B*N, H).
                    x = x[:, 0, :]
                    if proj_key in self.embedding_layers:
                        x = self.embedding_layers[proj_key](x)
                    x = x.view(b, n, -1)
                else:
                    # 2D input (Batch, Seq) -> (Batch, Seq, Hidden); no
                    # pooling, every token keeps its own embedding.
                    if proj_key in self.embedding_layers:
                        x = self.embedding_layers[proj_key](x)

                embedded[field_name] = x
            else:
                # Standard layers
                tensor = tensor.to(self.device)
                mask_source = tensor
                embedded[field_name] = layer(tensor)

            if output_mask:
                if masks is not None and field_name in masks:
                    # A caller-supplied mask takes precedence.
                    out_masks[field_name] = masks[field_name].to(self.device)
                elif hasattr(processor, "code_vocab"):
                    pad_idx = processor.code_vocab.get("<pad>", 0)
                    out_masks[field_name] = mask_source != pad_idx
                else:
                    # Default: treat 0 as padding. Be careful changing this
                    # behavior; it matches the previous implementation.
                    pad_idx = 0
                    out_masks[field_name] = mask_source != pad_idx

        if output_mask:
            return embedded, out_masks
        return embedded

    def __repr__(self) -> str:
        """Return a summary string listing the created embedding layers."""
        return f"EmbeddingModel(embedding_layers={self.embedding_layers})"