from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Iterable
import torch
class ModalityType(str, Enum):
    """Standard modality identifiers for routing in UnifiedMultimodalEmbeddingModel.

    Using ``str, Enum`` so values serialise cleanly (e.g. JSON / pickle).
    Because every member is also a ``str``, a member compares equal to its
    plain string value (``ModalityType.CODE == "code"``).
    """

    CODE = "code"        # Discrete ICD / medication / procedure codes
    TEXT = "text"        # Clinical notes, reports (tokenised to int tensors)
    IMAGE = "image"      # Medical images (X-ray, CT slice, etc.)
    NUMERIC = "numeric"  # Lab values, vitals, continuous measurements
    AUDIO = "audio"      # Heart/lung sounds, speech waveforms
    SIGNAL = "signal"    # ECG, EEG time-series waveforms
class Processor(ABC):
    """Abstract base processor class.

    Defines optional hooks for saving/loading state to/from disk. Both hooks
    default to no-ops, so subclasses without persistent state do not need to
    override them.
    """

    def save(self, path: str) -> None:
        """Optional: Save processor state to disk.

        The base implementation does nothing.

        Args:
            path: File path to save processor state.
        """
        pass

    def load(self, path: str) -> None:
        """Optional: Load processor state from disk.

        The base implementation does nothing.

        Args:
            path: File path to load processor state from.
        """
        pass
class FeatureProcessor(Processor):
    """Processor for individual fields (features).

    Example: Tokenization, image loading, normalization.

    Subclasses must implement :meth:`process`. The descriptor methods
    (:meth:`is_token`, :meth:`schema`, :meth:`dim`, :meth:`spatial`) raise
    ``NotImplementedError`` here and are meant to be overridden by
    processors that support them.
    """

    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
        """Fit the processor to the samples.

        The base implementation does nothing.

        Args:
            samples: Iterable of sample dictionaries.
            field: Presumably the key of the field within each sample that
                this processor is fitted on -- confirm against implementers.
        """
        pass

    @abstractmethod
    def process(self, value: Any) -> Any:
        """Process an individual field value.

        Args:
            value: Raw field value.

        Returns:
            Processed value.
        """
        pass

    def is_token(self) -> bool:
        """Returns whether the output (in particular, the value tensor) of the processor
        represents discrete token indices (True) or continuous values (False). This is used to
        determine whether to apply token-based transformations (e.g. `nn.Embedding`) or
        value-based augmentations (e.g. `nn.Linear`).

        Returns:
            True if the output of the processor represents discrete token indices, False otherwise.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("is_token method is not implemented for this processor.")

    def schema(self) -> tuple[str, ...]:
        """Returns the schema of the processed feature. For a processor that emits a single tensor,
        this should just return `("value",)`. For a processor that emits a tuple of tensors,
        this should return a tuple of the same length as the tuple, with the semantic name of each tensor,
        such as `("time", "value")`, `("value", "mask")`, etc.

        Typical semantic names include:
            - "value": the main processed tensor output of the processor
            - "time": the time tensor output of the processor (mostly for StageNet)
            - "mask": the mask tensor output of the processor (if applicable)

        Returns:
            Tuple of semantic names corresponding to the output of the processor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Schema method is not implemented for this processor.")

    def dim(self) -> tuple[int, ...]:
        """Number of dimensions (`Tensor.dim()`) for each output
        tensor, in the same order as the output tuple.

        Returns:
            Tuple of integers corresponding to the number of dimensions of each output tensor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("dim method is not implemented for this processor.")

    def spatial(self) -> tuple[bool, ...]:
        """Whether each dimension (axis) of the value tensor is spatial (i.e. corresponds to a spatial
        axis like time, height, width, etc.) or not. This is used to determine how to apply
        augmentations and other transformations that should only be applied to spatial dimensions.
        E.g. for CNN or RNN features, this would help determine which dimensions to apply spatial augmentations to,
        and which dimensions to treat as channels or features.

        Returns:
            Tuple of booleans corresponding to whether each axis of the value tensor is spatial or not.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("spatial method is not implemented for this processor.")
class SampleProcessor(Processor):
    """Processor for individual samples (dict of fields).

    Example: Imputation, sample-level augmentation, label smoothing.

    Subclasses must implement :meth:`process`.
    """

    @abstractmethod
    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single sample dictionary.

        Args:
            sample: Sample dictionary.

        Returns:
            Processed sample dictionary.
        """
        pass
class DatasetProcessor(Processor):
    """Processor for the full dataset.

    Example: Global normalization, train/val/test splitting, dataset caching.

    Subclasses must implement :meth:`process`.
    """

    @abstractmethod
    def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Process the entire dataset.

        Args:
            samples: List of sample dictionaries.

        Returns:
            List of processed sample dictionaries.
        """
        pass
class TokenProcessorInterface(ABC):
    """Base class for feature processors that build a vocabulary.

    Provides a common interface for accessing vocabulary-related information.
    Every method is abstract, so concrete subclasses must implement all of
    them before they can be instantiated.
    """

    # NOTE(review): presumably the reserved indices of the padding and
    # unknown tokens -- confirm against the concrete processors.
    PAD = 0
    UNK = 1

    @abstractmethod
    def remove(self, tokens: set[str]) -> None:
        """Remove specified vocabularies from the processor.

        Args:
            tokens: Tokens to remove from the vocabulary.
        """
        pass

    @abstractmethod
    def retain(self, tokens: set[str]) -> None:
        """Retain only the specified vocabularies in the processor.

        Args:
            tokens: Tokens to keep in the vocabulary.
        """
        pass

    @abstractmethod
    def add(self, tokens: set[str]) -> None:
        """Add specified vocabularies to the processor.

        Args:
            tokens: Tokens to add to the vocabulary.
        """
        pass

    @abstractmethod
    def tokens(self) -> set[str]:
        """Return the set of tokens in the processor's vocabulary."""
        pass

    @abstractmethod
    def vocab_size(self) -> int:
        """Return the size of the processor's vocabulary."""
        pass
class TemporalFeatureProcessor(FeatureProcessor):
    """Abstract base class for processors whose features are paired with timestamps.

    **Contract** — every subclass must implement:

    - ``modality() -> ModalityType`` — what kind of data this processor handles.
    - ``value_dim() -> int`` — size of the raw value vector *before* any
      learned embedding (e.g. vocab_size for codes, n_features for numerics).
    - ``process(value) -> dict[str, torch.Tensor]`` — must return a dict with at
      least the keys ``"value"`` and ``"time"``, and optionally ``"mask"``.

    **Backward compatibility** — the existing ``FeatureProcessor`` API
    (``is_token``, ``schema``, ``dim``, ``spatial``) is *kept* on the parent class
    and continues to work for all non-temporal processors. Subclasses of
    ``TemporalFeatureProcessor`` should still implement those methods if they want
    to remain compatible with the existing ``EmbeddingModel`` / ``MultimodalRNN``
    pipeline. The new ``modality()`` / ``value_dim()`` API is *additive* — used
    exclusively by ``UnifiedMultimodalEmbeddingModel``.

    **Why dict output?**

    ================  ========================  ==================================
    Concern           Tuple (current)           Dict (this class)
    ================  ========================  ==================================
    Collation         Custom per arity          Generic: stack/pad per key
    litdata           List[str] breaks          All values tensors/scalars ✓
    Schema            Positional, fragile       Named keys, self-documenting
    Extensibility     Adding field = new arity  Adding key = backward-compat
    ================  ========================  ==================================
    """

    # ── New API (required by UnifiedMultimodalEmbeddingModel) ─────────────────

    @abstractmethod
    def modality(self) -> ModalityType:
        """Return the modality type of the data this processor handles."""
        ...

    @abstractmethod
    def value_dim(self) -> int:
        """Dimensionality of the raw value vector *before* learned embedding.

        For codes:    ``vocab_size``  (used with ``nn.Embedding``)
        For images:   ``C * H * W``   (used with CNN encoder)
        For numerics: ``n_features``  (used with ``nn.Linear``)
        For text:     ``vocab_size``  (used with transformer encoder)
        """
        ...

    @abstractmethod
    def process(self, value: Any) -> dict[str, torch.Tensor]:
        """Process raw input and return a dict of tensors.

        Args:
            value: Raw field value.

        Returns:
            Dict of tensors.

            Required keys:
                ``"value"`` — main feature tensor.
                ``"time"``  — 1-D float32 tensor, one timestamp per event.

            Optional keys:
                ``"mask"``  — validity / attention mask for ``"value"``.
        """
        ...

    def schema(self) -> tuple[str, ...]:
        """Standardised schema: at minimum ``('value', 'time')``.

        Returns:
            The tuple ``("value", "time")``.
        """
        return ("value", "time")