# Source code for pyhealth.models.tfm_tokenizer

import math
from typing import Dict, Optional, Tuple, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# =============================================================================
# LAZY IMPORT FOR OPTIONAL DEPENDENCY
# =============================================================================
# The linear_attention_transformer package is only needed for TFMTokenizer and
# TFM_TOKEN_Classifier classes. However, this module is imported at package
# load time via pyhealth.models.__init__.py.
#
# Problem: If we import LinearAttentionTransformer at module level, users who
# don't need TFMTokenizer still get ImportError when they `import pyhealth.models`.
#
# Solution: Lazy import - only load the dependency when actually instantiating
# a class that needs it. This keeps the package importable for users who do
# not use the TFM classes, while still giving a clear error message (with
# install instructions) to users who do but have not installed the dependency.
#
# Consequences:
#   - Package imports work correctly (pyhealth.models loads without error)
#   - Only code that instantiates a TFM class sees the ImportError
#   - The error message provides clear install instructions
#   - Tests in tests/core/test_tfm_tokenizer.py require the optional
#     dependency and are expected to fail with ImportError when it is missing
# =============================================================================

LinearAttentionTransformer = None


def _get_linear_attention_transformer():
    """Lazily import LinearAttentionTransformer on first use.

    This function implements a lazy import pattern to avoid breaking the
    PyHealth package when the optional `linear_attention_transformer`
    dependency is not installed.

    Returns:
        The LinearAttentionTransformer class from the external package.

    Raises:
        ImportError: If the package is not installed, with a helpful
            message explaining how to install it.

    Why This Pattern:
        - pyhealth.models.__init__.py imports from this file at package load
        - A top-level `from linear_attention_transformer import ...` would
          cause ImportError for ALL users of pyhealth.models, even those
          who don't need TFMTokenizer
        - By deferring the import to class instantiation time, we ensure
          the error only occurs for users who actually try to use the
          affected classes
    """
    global LinearAttentionTransformer
    if LinearAttentionTransformer is None:
        try:
            from linear_attention_transformer import LinearAttentionTransformer as LAT
            LinearAttentionTransformer = LAT
        except ImportError:
            raise ImportError(
                "linear_attention_transformer is required for TFMTokenizer. "
                "Install it with: pip install linear-attention-transformer"
            )
    return LinearAttentionTransformer


from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


def get_stft_torch(X, resampling_rate=200):
    """Return the magnitude STFT of every channel of a batched signal.

    ``resampling_rate`` is used both as the FFT size and as the length of a
    Hann window; frames advance by ``resampling_rate // 2`` samples with no
    centering/padding. Only the first ``resampling_rate // 2`` bins of the
    one-sided spectrum are kept.

    Args:
        X: signal tensor of shape (B, C, T).
        resampling_rate: FFT size / window length. Default is 200.

    Returns:
        Tensor of shape (B, C, resampling_rate // 2, n_frames) holding the
        spectral magnitudes.
    """
    n_batch, n_chan, n_samples = X.shape
    n_fft = resampling_rate
    half = n_fft // 2

    # Fold channels into the batch so torch.stft sees a 2-D input.
    flat = X.reshape(n_batch * n_chan, n_samples)
    hann = torch.hann_window(n_fft).to(flat.device)
    spectrum = torch.stft(
        flat,
        n_fft=n_fft,
        hop_length=half,
        onesided=True,
        return_complex=True,
        center=False,
        window=hann,
    )
    magnitude = torch.abs(spectrum[:, :half, :])

    # Unfold the channels again: (B*C, F, frames) -> (B, C, F, frames).
    return magnitude.reshape(n_batch, n_chan, *magnitude.shape[1:])

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer models.

    A fixed (non-learned) table is added to the input, followed by dropout.
    Even embedding columns hold ``sin(pos * w_k)`` and odd columns
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: dimension of the model embedding. Odd values are supported.
        dropout: dropout probability. Default is 0.1.
        max_len: maximum sequence length. Default is 1000.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd (cos) column than even (sin)
        # column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer (not a parameter): moves with .to(device), saved in
        # state_dict, never trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Forward propagation.

        Args:
            x: input embeddings of shape (batch, seq_len, d_model) with
                ``seq_len <= max_len``.

        Returns:
            output tensor of shape (batch, seq_len, d_model).
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerEncoder(nn.Module):
    """Stack of linear-attention transformer layers.

    Thin wrapper around ``LinearAttentionTransformer`` (imported lazily on
    construction) with both the attention-layer dropout and the attention
    dropout fixed at 0.2.

    Args:
        emb_size: embedding size. Default is 64.
        num_heads: number of attention heads. Default is 8.
        depth: number of transformer layers. Default is 4.
        max_seq_len: maximum sequence length. Default is 1024.
    """

    def __init__(
        self,
        emb_size: int = 64,
        num_heads: int = 8,
        depth: int = 4,
        max_seq_len: int = 1024,
    ):
        super().__init__()
        transformer_cls = _get_linear_attention_transformer()
        layer_kwargs = dict(
            dim=emb_size,
            heads=num_heads,
            depth=depth,
            max_seq_len=max_seq_len,
            attn_layer_dropout=0.2,
            attn_dropout=0.2,
        )
        self.transformer = transformer_cls(**layer_kwargs)

    def forward(self, x):
        """Encode a sequence.

        Args:
            x: input tensor of shape (batch, seq_len, emb_size).

        Returns:
            output tensor of shape (batch, seq_len, emb_size).
        """
        return self.transformer(x)


def l2norm(t):
    """Scale every vector along the last dimension to unit L2 norm.

    Delegates to ``F.normalize``, whose default epsilon keeps all-zero
    vectors at zero instead of dividing by zero.
    """
    return F.normalize(t, dim=-1, p=2)


class EMAVectorQuantizer(nn.Module):
    """Exponential Moving Average Vector Quantizer.

    Each input vector is replaced by its nearest (squared-Euclidean) codebook
    entry. In training mode, per-code usage counts (``cluster_size``) and
    per-code summed inputs (``ema_w``) are tracked with exponential moving
    averages, and the codebook is overwritten with their Laplace-smoothed
    ratio. The codebook is therefore updated by EMA, not by gradients.

    Args:
        emb_size: dimensionality of embeddings.
        code_book_size: number of codebook entries.
        decay: exponential moving average decay factor. Default is 0.99.
        eps: small constant for numerical stability. Default is 1e-5.
    """

    def __init__(
        self, emb_size: int, code_book_size: int, decay: float = 0.99, eps: float = 1e-5
    ):
        super().__init__()
        self.emb_size = emb_size
        self.code_book_size = code_book_size
        self.decay = decay
        self.eps = eps

        self.embedding = nn.Embedding(code_book_size, emb_size)
        self.embedding.weight.data.uniform_(-1 / code_book_size, 1 / code_book_size)

        self.register_buffer("cluster_size", torch.zeros(code_book_size))
        self.register_buffer("ema_w", self.embedding.weight.data.clone())

    def forward(self, x):
        """Forward propagation.

        Args:
            x: input tensor of shape (..., emb_size), typically (B, T, emb_size).

        Returns:
            quantized: quantized vectors, same shape as ``x``.
            encoding_indices: indices of selected codebook entries, shape
                ``x.shape[:-1]`` (e.g. (B, T)).
        """
        flat_x = x.reshape(-1, self.emb_size)

        # The nearest-code search is non-differentiable (argmin), so skip
        # building an autograd graph for the (N, code_book_size) distances.
        with torch.no_grad():
            codebook = self.embedding.weight
            dist = (
                flat_x.pow(2).sum(dim=1, keepdim=True)
                - 2 * flat_x @ codebook.t()
                + codebook.pow(2).sum(dim=1, keepdim=True).t()
            )
            encoding_indices = torch.argmin(dist, dim=1)

        # Looked up before the EMA update, i.e. with the current codebook.
        quantized = self.embedding(encoding_indices).view_as(x)

        if self.training:
            self._ema_update(flat_x, encoding_indices)

        return quantized, encoding_indices.reshape(x.shape[:-1])

    @torch.no_grad()
    def _ema_update(self, flat_x, encoding_indices):
        """Update the EMA statistics and rewrite the codebook (no autograd)."""
        encodings_one_hot = F.one_hot(encoding_indices, self.code_book_size).type_as(
            flat_x
        )

        new_cluster_size = encodings_one_hot.sum(dim=0)
        self.cluster_size.data.mul_(self.decay).add_(
            new_cluster_size, alpha=1 - self.decay
        )

        # Sum of the inputs assigned to each code: (code_book_size, emb_size).
        dw = encodings_one_hot.t() @ flat_x
        self.ema_w.data.mul_(self.decay).add_(dw, alpha=1 - self.decay)

        # Laplace smoothing keeps unused codes from dividing by zero while
        # preserving the total count n.
        n = self.cluster_size.sum()
        cluster_size = (
            (self.cluster_size + self.eps)
            / (n + self.code_book_size * self.eps)
            * n
        )

        embed_normalized = self.ema_w / cluster_size.unsqueeze(1)
        self.embedding.weight.data.copy_(embed_normalized)


def freq_bin_temporal_masking(
    X,
    freq_mask_ratio: float = 0.5,
    freq_bin_size: int = 5,
    time_mask_ratio: float = 0.5,
    time_bin_size: int = 10,
):
    """Apply frequency-bin and temporal masking to spectrograms.

    Frequency rows are grouped into bins of ``freq_bin_size`` and time
    columns into bins of ``time_bin_size``. A random subset of each kind of
    bin is masked, using the same bins for every batch element. A position is
    "kept" only if neither its frequency bin nor its time bin is masked.

    Args:
        X: input spectrogram of shape (B, F, T). F must be divisible by
            ``freq_bin_size`` and T by ``time_bin_size``; otherwise a
            RuntimeError is raised. Non-contiguous inputs are accepted.
        freq_mask_ratio: ratio of frequency bins to mask. Default is 0.5.
        freq_bin_size: size of frequency bins. Default is 5.
        time_mask_ratio: ratio of time bins to mask. Default is 0.5.
        time_bin_size: size of time bins. Default is 10.

    Returns:
        X_masked: ``X`` with masked positions zeroed (kept regions only).
        X_masked_sym: ``X`` with kept positions zeroed (masked regions only).
        full_mask: boolean mask, True at kept positions.
        full_mask_sym: boolean mask, True at masked positions.
    """
    # n_freq / n_time instead of F / T so the module-level `F`
    # (torch.nn.functional) alias is not shadowed.
    B, n_freq, n_time = X.shape
    mask_kwargs = dict(dtype=X.dtype, device=X.device)

    num_freq_bins = n_freq // freq_bin_size
    freq_mask = torch.ones(B, num_freq_bins, freq_bin_size, n_time, **mask_kwargs)
    num_freq_bins_to_mask = int(num_freq_bins * freq_mask_ratio)
    freq_bins_to_mask = torch.randperm(num_freq_bins)[:num_freq_bins_to_mask]
    freq_mask[:, freq_bins_to_mask, ...] = 0
    # reshape (not view on X): works for non-contiguous inputs too.
    freq_mask = freq_mask.reshape(B, n_freq, n_time)

    num_time_bins = n_time // time_bin_size
    time_mask = torch.ones(B, n_freq, num_time_bins, time_bin_size, **mask_kwargs)
    num_time_bins_to_mask = int(num_time_bins * time_mask_ratio)
    time_bins_to_mask = torch.randperm(num_time_bins)[:num_time_bins_to_mask]
    time_mask[:, :, time_bins_to_mask, :] = 0
    time_mask = time_mask.reshape(B, n_freq, n_time)

    full_mask = freq_mask * time_mask
    full_mask_sym = 1 - full_mask
    full_mask = full_mask.to(torch.bool)
    full_mask_sym = full_mask_sym.to(torch.bool)
    X_masked = X * full_mask
    X_masked_sym = X * full_mask_sym

    return X_masked, X_masked_sym, full_mask, full_mask_sym


class TFM_VQVAE2_deep(nn.Module):
    """TFM-Tokenizer module with raw EEG and STFT as input.

    Pipeline (see ``tokenize``):

    1. Each STFT time frame is patch-embedded along frequency.
    2. The frame is encoded with a frequency transformer.
    3. It is aggregated with a sigmoid-gated convolution.
    4. The result is concatenated with a convolutional embedding of the raw
       signal.
    5. The sequence is encoded along time, L2-normalized and
       vector-quantized.

    ``forward`` additionally decodes the quantized sequence back to an
    ``n_freq``-bin spectrogram.

    Args:
        in_channels: number of input channels. Default is 1.
        n_freq: number of frequency bins in STFT. Default is 100.
        n_freq_patch: frequency patch size. Default is 5.
        emb_size: embedding dimension. Default is 64.
        code_book_size: size of the VQ codebook. Default is 8192.
        trans_freq_encoder_depth: depth of frequency encoder. Default is 4.
        trans_temporal_encoder_depth: depth of temporal encoder. Default is 4.
        trans_decoder_depth: depth of decoder. Default is 4.
        beta: weight for commitment loss. Default is 1.0.

    Raises:
        ValueError: if ``n_freq < n_freq_patch ** 2``, which would make the
            frequency-aggregation divider zero.
    """

    def __init__(
        self,
        in_channels: int = 1,
        n_freq: int = 100,
        n_freq_patch: int = 5,
        emb_size: int = 64,
        code_book_size: int = 8192,
        trans_freq_encoder_depth: int = 4,
        trans_temporal_encoder_depth: int = 4,
        trans_decoder_depth: int = 4,
        beta: float = 1.0,
    ):
        super().__init__()
        # Number of frequency positions left after two stride-n_freq_patch
        # convolutions. The gated aggregation gives each position
        # emb_size // (2 * divider) channels, so the flattened result is
        # ~emb_size // 2 wide, matching the temporal branch's output channels.
        global_freq_divider = n_freq // (n_freq_patch * n_freq_patch)
        if global_freq_divider == 0:
            raise ValueError(
                f"n_freq ({n_freq}) must be at least n_freq_patch ** 2 "
                f"({n_freq_patch * n_freq_patch})"
            )

        self.n_freq_patch = n_freq_patch
        self.emb_size = emb_size
        self.code_book_size = code_book_size

        # bin wise frequency embedding
        self.freq_patch_embedding = nn.Sequential(
            nn.Conv1d(in_channels, emb_size, kernel_size=n_freq_patch, stride=n_freq_patch),
            nn.GELU(),
            nn.GroupNorm(emb_size // 4, emb_size),
            nn.Conv1d(emb_size, emb_size, kernel_size=1, stride=1),
            nn.GELU(),
            nn.GroupNorm(emb_size // 4, emb_size),
            nn.Conv1d(emb_size, emb_size, kernel_size=1, stride=1),
            nn.GELU(),
            nn.GroupNorm(emb_size // 4, emb_size),
        )
        # Freq Encoder
        self.trans_freq_encoder = TransformerEncoder(
            emb_size=emb_size,
            num_heads=8,
            depth=trans_freq_encoder_depth,
            max_seq_len=n_freq // n_freq_patch,
        )
        # Temporal embedding of the raw signal (window 200, hop 100 samples)
        self.temporal_patch_embedding = nn.Sequential(
            nn.Conv1d(in_channels, emb_size, kernel_size=200, stride=100),
            nn.GELU(),
            nn.GroupNorm(emb_size // 4, emb_size),
            nn.Conv1d(emb_size, emb_size, kernel_size=1, stride=1),
            nn.GELU(),
            nn.GroupNorm(emb_size // 4, emb_size),
            nn.Conv1d(emb_size, emb_size // 2, kernel_size=1, stride=1),
            nn.GELU(),
            nn.GroupNorm(emb_size // 4, emb_size // 2),
        )
        # attention based aggregation: sigmoid gate x value projection
        self.freq_patch_embedding_2_atten = nn.Sequential(
            nn.Conv1d(
                emb_size,
                emb_size // (global_freq_divider * 2),
                kernel_size=n_freq_patch,
                stride=n_freq_patch,
            ),
            nn.Sigmoid(),
        )
        self.freq_patch_embedding_2 = nn.Sequential(
            nn.Conv1d(
                emb_size,
                emb_size // (global_freq_divider * 2),
                kernel_size=n_freq_patch,
                stride=n_freq_patch,
            ),
        )
        # Temporal Encoder
        self.trans_temporal_encoder = TransformerEncoder(
            emb_size=emb_size, num_heads=8, depth=trans_temporal_encoder_depth
        )
        # Vector quantization bottleneck
        self.quantizer = EMAVectorQuantizer(emb_size, code_book_size)
        self.beta = beta
        # Decoder
        self.trans_decoder = TransformerEncoder(
            emb_size=emb_size, num_heads=8, depth=trans_decoder_depth
        )
        self.decoder = nn.Sequential(
            nn.Linear(emb_size, emb_size), nn.Tanh(), nn.Linear(emb_size, n_freq)
        )

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay."""
        return {"quantizer.embedding.weight"}

    def _encode(self, x, x_temporal):
        """Run the frequency and temporal encoders (shared by tokenize/forward_ana).

        Args:
            x: STFT spectrogram of shape (B, F, T).
            x_temporal: raw temporal signal of shape (B, n_samples).

        Returns:
            freq_encoded: frequency-encoder output, one sequence per time
                frame, shape (B*T, F // n_freq_patch, emb_size).
            temporal_encoded: temporal-encoder output, shape (B, T, emb_size).
        """
        _, n_freq, n_time = x.shape
        # Treat every time frame as its own 1-channel frequency sequence.
        x_freq = x.permute(0, 2, 1).reshape(-1, 1, n_freq)
        x_freq = self.freq_patch_embedding(x_freq).permute(0, 2, 1)
        freq_encoded = self.trans_freq_encoder(x_freq)

        # Gated aggregation over frequency patches, then flatten per frame.
        x_agg = freq_encoded.permute(0, 2, 1)
        atten = self.freq_patch_embedding_2_atten(x_agg)
        x_agg = self.freq_patch_embedding_2(x_agg) * atten
        x_agg = x_agg.reshape(-1, x_agg.size(1) * x_agg.size(2))
        x_agg = rearrange(x_agg, "(B T) E -> B T E", T=n_time)

        # Raw-signal branch; its frame count must equal the STFT's T.
        x_temp = self.temporal_patch_embedding(x_temporal.unsqueeze(1))
        x_temp = rearrange(x_temp, "B E T -> B T E")

        x_combined = torch.cat((x_agg, x_temp), dim=-1)
        temporal_encoded = self.trans_temporal_encoder(x_combined)
        return freq_encoded, temporal_encoded

    def tokenize(self, x, x_temporal):
        """Tokenize EEG signals into discrete tokens.

        Args:
            x: STFT spectrogram of shape (B, F, T).
            x_temporal: raw temporal signal of shape (B, n_samples).

        Returns:
            quant_out: quantized output (nearest codebook vectors).
            indices: discrete token indices, shape (B, T).
            quant_in: L2-normalized input to the quantizer.
        """
        _, temporal_encoded = self._encode(x, x_temporal)
        quant_in = l2norm(temporal_encoded)
        quant_out, indices = self.quantizer(quant_in)
        return quant_out, indices, quant_in

    def forward(self, x, x_temporal):
        """Forward propagation.

        Args:
            x: STFT spectrogram of shape (B, F, T).
            x_temporal: raw temporal signal of shape (B, n_samples).

        Returns:
            x: reconstructed STFT spectrogram, shape (B, n_freq, T).
            indices: discrete token indices, shape (B, T).
            quant_out: straight-through quantized output (forward value of the
                codebook vectors, gradient of ``quant_in``).
            quant_in: input to quantizer.
        """
        quant_out, indices, quant_in = self.tokenize(x, x_temporal)
        # Straight-through estimator: decoder gradients flow to the encoder.
        quant_out = quant_in + (quant_out - quant_in).detach()
        x = self.trans_decoder(quant_out)
        x = self.decoder(x)
        x = x.permute(0, 2, 1)
        return x, indices, quant_out, quant_in

    def vec_quantizer_loss(self, quant_in, quant_out):
        """Compute vector quantizer losses.

        NOTE(review): with the raw quantizer output this is the standard VQ
        loss. If given the straight-through ``quant_out`` returned by
        ``forward``, both terms take their gradient through ``quant_in`` with
        opposite signs, so for ``beta == 1`` they cancel.

        Args:
            quant_in: input to quantizer.
            quant_out: output from quantizer.

        Returns:
            loss: total VQ loss (``code_book_loss + beta * commitment_loss``).
            code_book_loss: codebook loss component.
            commitment_loss: commitment loss component.
        """
        commitment_loss = torch.mean((quant_out.detach() - quant_in) ** 2)
        code_book_loss = torch.mean((quant_out - quant_in.detach()) ** 2)
        loss = code_book_loss + self.beta * commitment_loss
        return loss, code_book_loss, commitment_loss

    @torch.no_grad()
    def forward_ana(self, x, x_temporal):
        """Forward propagation with intermediate outputs for analysis.

        NOTE(review): in training mode the quantizer still performs its EMA
        codebook update here; call ``.eval()`` first to analyse without
        changing the codebook.

        Args:
            x: STFT spectrogram of shape (B, F, T).
            x_temporal: raw temporal signal of shape (B, n_samples).

        Returns:
            x_dec: reconstructed output.
            indices: quantizer indices.
            quant_out: quantized representation.
            quant_in: input to quantizer.
            freq_encoded: frequency encoder tokens.
            temporal_encoded: temporal encoder tokens.
        """
        freq_encoded, temporal_encoded = self._encode(x, x_temporal)
        quant_in = l2norm(temporal_encoded)
        quant_out, indices = self.quantizer(quant_in)
        quant_out = quant_in + (quant_out - quant_in).detach()
        x_dec = self.trans_decoder(quant_out)
        x_dec = self.decoder(x_dec)
        x_dec = x_dec.permute(0, 2, 1)
        return x_dec, indices, quant_out, quant_in, freq_encoded, temporal_encoded
class TFM_TOKEN_Classifier(nn.Module):
    """Downstream classifier using TFM tokens.

    The token indices of every channel are embedded. A learned per-channel
    embedding and a sinusoidal temporal positional encoding are added. All
    channels are then flattened into one sequence, a CLS token is prepended
    and the sequence goes through a linear-attention transformer.

    Args:
        emb_size: embedding dimension. Default is 256.
        code_book_size: size of the VQ codebook. Default is 8192.
        num_heads: number of attention heads. Default is 8.
        depth: number of transformer layers. Default is 12.
        max_seq_len: maximum sequence length. Default is 61.
        n_classes: number of output classes. Default is 5.

    Note:
        Only 16 channel embeddings exist, so inputs with more than 16
        channels fail with an indexing error.
    """

    def __init__(
        self,
        emb_size: int = 256,
        code_book_size: int = 8192,
        num_heads: int = 8,
        depth: int = 12,
        max_seq_len: int = 61,
        n_classes: int = 5,
    ):
        super().__init__()
        # One row beyond the codebook size; presumably reserved for a mask
        # token in masked training (TODO confirm).
        self.eeg_token_embedding = nn.Embedding(code_book_size + 1, emb_size)
        self.channel_embed = nn.Embedding(16, emb_size)
        # Channel ids 0..15 kept as a frozen parameter (part of state_dict).
        self.index = nn.Parameter(torch.LongTensor(range(16)), requires_grad=False)
        self.temporal_pos_embed = PositionalEncoding(emb_size)
        # Not used in forward; kept so existing attributes/configs still work.
        self.pos_drop = nn.Dropout(p=0.1)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))

        LAT = _get_linear_attention_transformer()
        self.LAT = LAT(
            dim=emb_size,
            heads=num_heads,
            depth=depth,
            max_seq_len=max_seq_len,
            attn_layer_dropout=0.2,
            attn_dropout=0.2,
        )
        self.classification_head = nn.Linear(emb_size, n_classes)

    def _encode(self, x):
        """Embed token indices (B, C, T) and run the transformer.

        Returns:
            Transformer output of shape (B, 1 + C*T, emb_size); position 0 is
            the CLS token.
        """
        x = self.eeg_token_embedding(x)
        for i in range(x.shape[1]):
            used_channel_embed = (
                self.channel_embed(self.index[i])
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(x.size(0), -1, -1)
            )
            # Channel embedding is broadcast over time, then the positional
            # encoding is applied along this channel's time axis.
            x[:, i] = self.temporal_pos_embed(x[:, i] + used_channel_embed)
        x = rearrange(x, "B C T E -> B (C T) E")
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        return self.LAT(x)

    def forward(self, x, num_ch: int = 16):
        """Forward propagation.

        Args:
            x: token indices of shape (B, C, T).
            num_ch: unused; the channel count is taken from ``x``. Kept for
                API compatibility. Default is 16.

        Returns:
            pred: class logits of shape (B, n_classes), read from the CLS token.
        """
        return self.classification_head(self._encode(x)[:, 0])

    def masked_prediction(self, x, num_ch: int = 16):
        """Forward propagation with per-token predictions.

        Args:
            x: token indices of shape (B, C, T).
            num_ch: unused; kept for API compatibility. Default is 16.

        Returns:
            pred: predictions for all tokens (excluding CLS), shape
                (B, C*T, n_classes).
        """
        return self.classification_head(self._encode(x)[:, 1:])

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters/modules to exclude from weight decay."""
        return {"temporal_pos_embed", "cls_token"}
def get_tfm_tokenizer_2x2x8(code_book_size: int = 8192, emb_size: int = 64):
    """Create TFM-Tokenizer with 2x2x8 architecture.

    The configuration is a depth-2 frequency encoder, a depth-2 temporal
    encoder and a depth-8 decoder, on 100-bin STFTs with frequency patches
    of size 5.

    Args:
        code_book_size: size of the VQ codebook. Default is 8192.
        emb_size: embedding dimension. Default is 64.

    Returns:
        TFM_VQVAE2_deep model instance.
    """
    return TFM_VQVAE2_deep(
        in_channels=1,
        n_freq=100,
        n_freq_patch=5,
        emb_size=emb_size,
        code_book_size=code_book_size,
        trans_freq_encoder_depth=2,
        trans_temporal_encoder_depth=2,
        trans_decoder_depth=8,
        beta=1.0,
    )
def get_tfm_token_classifier_64x4(
    n_classes: int = 5, code_book_size: int = 8192, emb_size: int = 64
):
    """Create TFM-Token classifier with 64x4 architecture.

    The classifier has 8 heads and 4 transformer layers, with sequences of up
    to 2048 positions (including the CLS token).

    Args:
        n_classes: number of output classes. Default is 5.
        code_book_size: size of the VQ codebook. Default is 8192.
        emb_size: embedding dimension. Default is 64.

    Returns:
        TFM_TOKEN_Classifier model instance.
    """
    return TFM_TOKEN_Classifier(
        emb_size=emb_size,
        code_book_size=code_book_size,
        num_heads=8,
        depth=4,
        max_seq_len=2048,
        n_classes=n_classes,
    )
def load_embedding_weights(source_model, target_model):
    """Copy the tokenizer's VQ codebook into the classifier's token embedding.

    If the source vocabulary is larger, its extra rows are dropped. If it is
    smaller, the remaining target rows are set to zero.

    Args:
        source_model: the tokenizer model (TFM_VQVAE2_deep); only
            ``quantizer.embedding`` is read.
        target_model: the classifier model (TFM_TOKEN_Classifier); only
            ``eeg_token_embedding`` is written.

    Raises:
        ValueError: if the embedding dimensions differ.
    """
    source_weights = source_model.quantizer.embedding.weight.data
    target_weights = target_model.eeg_token_embedding.weight.data

    src_vocab_size, src_emb_dim = source_weights.shape
    tgt_vocab_size, tgt_emb_dim = target_weights.shape
    print(f"Source Embedding Shape: {source_weights.shape}")
    print(f"Target Embedding Shape: {target_weights.shape}")

    if src_emb_dim != tgt_emb_dim:
        raise ValueError(
            f"Embedding size mismatch: {src_emb_dim} (source) vs {tgt_emb_dim} (target)"
        )

    if src_vocab_size > tgt_vocab_size:
        adapted_weights = source_weights[:tgt_vocab_size, :]
        print(f"Trimming source embeddings from {src_vocab_size} to {tgt_vocab_size}")
    elif src_vocab_size < tgt_vocab_size:
        # Same dtype *and device* as the source, so no hidden CPU round-trip.
        adapted_weights = source_weights.new_zeros((tgt_vocab_size, tgt_emb_dim))
        adapted_weights[:src_vocab_size, :] = source_weights
        print(f"Padding source embeddings from {src_vocab_size} to {tgt_vocab_size}")
    else:
        adapted_weights = source_weights

    # copy_ handles any remaining device/dtype conversion to the target.
    target_weights.copy_(adapted_weights)
    print("Successfully loaded embedding weights!")
class TFMTokenizer(BaseModel):
    """TFM-Tokenizer model.

    This model uses VQ-VAE with transformers to tokenize EEG signals. It can
    extract discrete tokens and continuous embeddings for downstream tasks.
    Channels are folded into the batch for tokenization and unfolded again
    for the outputs.

    The model expects two inputs:
        - STFT spectrogram: shape (batch, n_channels, n_freq, n_time)
        - Raw temporal signal: shape (batch, n_channels, n_samples)

    Args:
        dataset: the dataset to train the model.
        emb_size: embedding dimension. Default is 64.
        code_book_size: size of the VQ codebook. Default is 8192.
        n_freq: number of frequency bins in STFT. Default is 100.
        n_freq_patch: frequency patch size. Default is 5.
        trans_freq_encoder_depth: depth of frequency encoder. Default is 2.
        trans_temporal_encoder_depth: depth of temporal encoder. Default is 2.
        trans_decoder_depth: depth of decoder. Default is 8.
        use_classifier: whether to use the classifier head. Default is True.
        classifier_depth: depth of classifier transformer. Default is 4.
        classifier_heads: number of attention heads in classifier. Default is 8.

    Examples:
        >>> from pyhealth.datasets import TUEVDataset
        >>> from pyhealth.models import TFMTokenizer
        >>> dataset = TUEVDataset(root="/path/to/tuev")
        >>> sample_dataset = dataset.set_task()
        >>> model = TFMTokenizer(dataset=sample_dataset)
        >>> model.load_pretrained_weights("tfm_encoder_best_model.pth")
    """

    def __init__(
        self,
        dataset: SampleDataset,
        emb_size: int = 64,
        code_book_size: int = 8192,
        n_freq: int = 100,
        n_freq_patch: int = 5,
        trans_freq_encoder_depth: int = 2,
        trans_temporal_encoder_depth: int = 2,
        trans_decoder_depth: int = 8,
        use_classifier: bool = True,
        classifier_depth: int = 4,
        classifier_heads: int = 8,
        **kwargs,
    ):
        super().__init__(dataset=dataset)

        self.emb_size = emb_size
        self.code_book_size = code_book_size
        self.use_classifier = use_classifier

        self.tokenizer = TFM_VQVAE2_deep(
            in_channels=1,
            n_freq=n_freq,
            n_freq_patch=n_freq_patch,
            emb_size=emb_size,
            code_book_size=code_book_size,
            trans_freq_encoder_depth=trans_freq_encoder_depth,
            trans_temporal_encoder_depth=trans_temporal_encoder_depth,
            trans_decoder_depth=trans_decoder_depth,
            beta=1.0,
        )

        if use_classifier:
            output_size = self.get_output_size()
            self.classifier = TFM_TOKEN_Classifier(
                emb_size=emb_size,
                code_book_size=code_book_size,
                num_heads=classifier_heads,
                depth=classifier_depth,
                max_seq_len=2048,
                n_classes=output_size,
            )

    def _flatten_channels(self, stft, signal):
        """Move inputs to the model device and fold channels into the batch.

        Args:
            stft: tensor of shape (B, C, F, T).
            signal: tensor of shape (B, C, n_samples), or (C, n_samples).

        Returns:
            stft: tensor of shape (B*C, F, T).
            signal: tensor of shape (B*C, n_samples).
            num_channels: C, needed to unfold the tokenizer outputs.
        """
        # NOTE(review): a 2-D signal is treated as one un-batched multi-channel
        # sample (C, n_samples) -> (1, C, n_samples); the STFT is NOT
        # unsqueezed and must already be 4-D. Confirm this is intended.
        if len(signal.shape) == 2:
            signal = signal.unsqueeze(0)
        _, num_channels, _ = signal.shape
        stft = rearrange(stft.to(self.device), "B C F T -> (B C) F T")
        signal = rearrange(signal.to(self.device), "B C T -> (B C) T")
        return stft, signal, num_channels

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: keyword arguments containing 'stft', 'signal', the label
                key (when the classifier is used) and optionally 'embed'.

        Returns:
            a dictionary with the following entries.

            - Always: "loss", "recon_loss", "vq_loss", "tokens" (B, C, T) and
              "embeddings" (B, C, T, emb_size).
            - With the classifier: also "cls_loss", "y_prob", "y_true" and
              "logit".
            - When ``embed=True``: also "embed" (B, emb_size).

        Raises:
            ValueError: if 'stft' or 'signal' is missing.
        """
        stft = kwargs.get("stft")
        signal = kwargs.get("signal")
        if stft is None or signal is None:
            raise ValueError("Both 'stft' and 'signal' must be provided in inputs")

        stft, signal, C = self._flatten_channels(stft, signal)

        reconstructed, tokens, quant_out, quant_in = self.tokenizer(stft, signal)
        recon_loss = F.mse_loss(reconstructed, stft)
        # NOTE(review): `quant_out` is the straight-through tensor
        # (quant_in + (q - quant_in).detach()). With beta=1.0 the codebook and
        # commitment terms of vec_quantizer_loss then give equal and opposite
        # gradients w.r.t. quant_in, so vq_loss appears to add no gradient.
        # Kept as-is for parity with existing training; confirm intent.
        vq_loss, _, _ = self.tokenizer.vec_quantizer_loss(quant_in, quant_out)

        tokens_reshaped = rearrange(tokens, "(B C) T -> B C T", C=C)
        quant_out_reshaped = rearrange(quant_out, "(B C) T E -> B C T E", C=C)

        results = {
            "recon_loss": recon_loss,
            "vq_loss": vq_loss,
            "tokens": tokens_reshaped,
            "embeddings": quant_out_reshaped,
        }

        if self.use_classifier and len(self.label_keys) > 0:
            label_key = self.label_keys[0]
            y_true = kwargs[label_key].to(self.device)

            # Multi-channel classifier consumes tokens as (B, C, T).
            logits = self.classifier(tokens_reshaped, num_ch=C)
            loss_fn = self.get_loss_function()
            cls_loss = loss_fn(logits, y_true)
            total_loss = recon_loss + vq_loss + cls_loss
            y_prob = self.prepare_y_prob(logits)

            results.update(
                {
                    "loss": total_loss,
                    "cls_loss": cls_loss,
                    "y_prob": y_prob,
                    "y_true": y_true,
                    "logit": logits,
                }
            )
        else:
            results["loss"] = recon_loss + vq_loss

        if kwargs.get("embed", False):
            # Mean-pool over channels (C) and time steps (T) -> (B, emb_size).
            results["embed"] = quant_out_reshaped.mean(dim=(1, 2))

        return results

    def get_embeddings(self, dataloader) -> torch.Tensor:
        """Extract continuous embeddings for all samples in a dataloader.

        Runs without gradients in eval mode; the previous train/eval mode is
        restored afterwards.

        Args:
            dataloader: PyHealth dataloader yielding 'stft' and 'signal'.

        Returns:
            tensor of shape (n_samples, n_channels, n_time, emb_size) on CPU.
        """
        was_training = self.training
        self.eval()
        all_embeddings = []
        try:
            with torch.no_grad():
                for batch in dataloader:
                    stft, signal, C = self._flatten_channels(
                        batch.get("stft"), batch.get("signal")
                    )
                    _, _, quant_out, _ = self.tokenizer(stft, signal)
                    quant_out = rearrange(quant_out, "(B C) T E -> B C T E", C=C)
                    all_embeddings.append(quant_out.cpu())
        finally:
            self.train(was_training)
        return torch.cat(all_embeddings, dim=0)

    def get_tokens(self, dataloader) -> torch.Tensor:
        """Extract discrete tokens for all samples in a dataloader.

        Runs without gradients in eval mode; the previous train/eval mode is
        restored afterwards.

        Args:
            dataloader: PyHealth dataloader yielding 'stft' and 'signal'.

        Returns:
            tensor of shape (n_samples, n_channels, n_time) on CPU.
        """
        was_training = self.training
        self.eval()
        all_tokens = []
        try:
            with torch.no_grad():
                for batch in dataloader:
                    stft, signal, C = self._flatten_channels(
                        batch.get("stft"), batch.get("signal")
                    )
                    _, tokens, _, _ = self.tokenizer(stft, signal)
                    tokens = rearrange(tokens, "(B C) T -> B C T", C=C)
                    all_tokens.append(tokens.cpu())
        finally:
            self.train(was_training)
        return torch.cat(all_tokens, dim=0)

    def load_pretrained_weights(
        self,
        tokenizer_checkpoint_path: str,
        classifier_checkpoint_path: str = None,
        is_masked_training: bool = False,
        strict: bool = False,
        map_location: str = None,
    ):
        """Load pre-trained weights from checkpoint files.

        Args:
            tokenizer_checkpoint_path: path to the tokenizer state dict.
            classifier_checkpoint_path: optional path to a classifier state
                dict. Ignored when ``is_masked_training`` is True.
            is_masked_training: if True, skip the classifier checkpoint and
                copy the tokenizer codebook into the classifier's token
                embedding instead. Default is False.
            strict: key-matching strictness for the tokenizer state dict only
                (the classifier state dict is always loaded strictly).
                Default is False.
            map_location: device to map the loaded tensors to. Defaults to the
                model's device.

        Raises:
            RuntimeError: if the classifier checkpoint's head has a class count
                that cannot be adapted (only 1 -> 2 is supported).
        """
        if map_location is None:
            map_location = str(self.device)

        # Load tokenizer weights
        tokenizer_state = torch.load(
            tokenizer_checkpoint_path, map_location=map_location, weights_only=True
        )
        self.tokenizer.load_state_dict(tokenizer_state, strict=strict)

        if classifier_checkpoint_path is not None and not is_masked_training:
            ckpt = torch.load(
                classifier_checkpoint_path, map_location=map_location, weights_only=True
            )
            model_n = self.classifier.classification_head.weight.shape[0]
            ckpt_n = ckpt["classification_head.weight"].shape[0]
            if ckpt_n != model_n:
                if ckpt_n == 1 and model_n == 2:
                    # Checkpoint was trained with 1-class sigmoid BCE; model
                    # expects 2-class softmax. softmax([0, z]) equals
                    # [1 - sigmoid(z), sigmoid(z)] exactly, so class 0 gets a
                    # zero logit and class 1 the original one. (The previous
                    # [-z, z] mapping yielded sigmoid(2z), i.e. sharper probs.)
                    w = ckpt["classification_head.weight"]  # [1, D]
                    b = ckpt["classification_head.bias"]  # [1]
                    ckpt["classification_head.weight"] = torch.cat(
                        [torch.zeros_like(w), w], dim=0
                    )
                    ckpt["classification_head.bias"] = torch.cat(
                        [torch.zeros_like(b), b], dim=0
                    )
                    print(f" ℹ Adapted classifier head from 1-class sigmoid → 2-class softmax")
                else:
                    raise RuntimeError(
                        f"Classifier head shape mismatch: checkpoint has {ckpt_n} class(es) "
                        f"but model expects {model_n}. Cannot auto-adapt."
                    )
            self.classifier.load_state_dict(ckpt)
            print(f"✓ Successfully loaded weights from {classifier_checkpoint_path}")
        elif is_masked_training:
            load_embedding_weights(self.tokenizer, self.classifier)
            print("✓ Successfully loaded embedding weights!")
        else:
            print(
                "No classifier checkpoint path provided. "
                "Skipping classifier weight loading."
            )
if __name__ == "__main__":
    # Smoke test for the tokenizer and classifier on random data.
    print("Testing TFM-Tokenizer components...")

    tokenizer = get_tfm_tokenizer_2x2x8()
    print(f"✓ Created tokenizer: {tokenizer.__class__.__name__}")

    classifier = get_tfm_token_classifier_64x4(n_classes=6)
    print(f"✓ Created classifier: {classifier.__class__.__name__}")

    batch_size = 2
    # 1000 samples -> 9 STFT frames (n_fft=200, hop=100, center=False), which
    # matches the 9 frames of the temporal branch (kernel 200, stride 100);
    # the two branches are concatenated, so their frame counts must agree.
    n_samples = 1000
    dummy_signal = torch.randn(batch_size, n_samples)
    # get_stft_torch expects (B, C, T): add a channel axis, then drop it.
    dummy_stft = get_stft_torch(dummy_signal.unsqueeze(1))[:, 0]  # (B, 100, 9)

    recon, tokens, quant_out, quant_in = tokenizer(dummy_stft, dummy_signal)
    print(f"✓ Tokenizer forward pass:")
    print(f" Reconstructed shape: {recon.shape}")
    print(f" Tokens shape: {tokens.shape}")
    print(f" Embeddings shape: {quant_out.shape}")

    # The classifier expects token indices of shape (B, C, T); here C = 1.
    preds = classifier(tokens.unsqueeze(1))
    print(f"✓ Classifier forward pass:")
    print(f" Predictions shape: {preds.shape}")

    print("\n✓ All tests passed!")