# Source code for pyhealth.processors.sequence_processor

from typing import Any, Dict, List, Iterable, Optional, Tuple

import torch

from . import register_processor
from .base_processor import FeatureProcessor, TokenProcessorInterface


@register_processor("sequence")
class SequenceProcessor(FeatureProcessor, TokenProcessorInterface):
    """Feature processor for encoding categorical sequences.

    Encodes medical codes (e.g., diagnoses, procedures) into numerical
    indices. Supports single or multiple tokens and can build vocabulary
    on the fly if not provided.

    Vocabulary indices are kept dense: every method that mutates the
    vocabulary (``fit``, ``add``, ``remove``, ``retain``) assigns or
    renumbers indices so that they run contiguously from the special tokens
    up to ``vocab_size() - 1``. Any table sized by ``vocab_size()`` can
    therefore be indexed by every value ``process`` returns.

    Args:
        code_mapping: optional tuple of (source_vocabulary,
            target_vocabulary) to map raw codes to a grouped vocabulary
            before tokenizing. Uses ``pyhealth.medcode.CrossMap``
            internally. For example, ``("ICD9CM", "CCSCM")`` maps ~128K
            ICD-9 diagnosis codes to ~280 CCS categories, and
            ``("NDC", "ATC")`` maps ~940K drug codes to ~5K ATC
            categories. When None (default), codes are used as-is with no
            change to existing behavior.

    Examples:
        >>> proc = SequenceProcessor()  # no mapping, same as before
        >>> proc = SequenceProcessor(code_mapping=("ICD9CM", "CCSCM"))
    """

    def __init__(self, code_mapping: Optional[Tuple[str, str]] = None):
        # PAD / UNK are provided by TokenProcessorInterface; presumably 0 and
        # 1, consistent with the first free index below being 2 — confirm.
        self.code_vocab: Dict[Any, int] = {"<pad>": self.PAD, "<unk>": self.UNK}
        # Next free index. Kept equal to len(self.code_vocab) by _append and
        # _reindex so that indices stay dense.
        self._next_index = 2
        self._mapper = None
        if code_mapping is not None:
            # Local import, presumably so medcode is only loaded when a
            # mapping is actually requested.
            from pyhealth.medcode import CrossMap

            self._mapper = CrossMap.load(code_mapping[0], code_mapping[1])

    def _map(self, token: str) -> List[str]:
        """Map a single token through the code mapping, if configured.

        Returns the token unchanged (as a single-element list) when no
        mapping is configured or when the token has no mapping.
        """
        if self._mapper is None:
            return [token]
        mapped = self._mapper.map(token)
        return mapped if mapped else [token]

    @staticmethod
    def _as_tokens(value: Any) -> Iterable[Any]:
        """Normalize a field value to an iterable of tokens.

        A bare string is treated as ONE token. Iterating it directly would
        split it into characters.
        """
        if isinstance(value, str):
            return (value,)
        return value

    def _append(self, token: Any) -> None:
        """Give a token not yet in the vocabulary the next dense index."""
        index = len(self.code_vocab)
        self.code_vocab[token] = index
        self._next_index = index + 1

    def _reindex(self, keep: set) -> None:
        """Keep only ``keep`` (plus special tokens) and renumber densely.

        Surviving tokens keep their relative order. ``<pad>`` and ``<unk>``
        are always kept.
        """
        keep = keep | {"<pad>", "<unk>"}
        ordered = [
            token
            for token, _ in sorted(self.code_vocab.items(), key=lambda kv: kv[1])
            if token in keep
        ]
        self.code_vocab = {token: i for i, token in enumerate(ordered)}
        # Resync so a later fit()/add() continues right after the last index.
        self._next_index = len(self.code_vocab)

    def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
        """Build vocabulary from samples, applying code mapping if set.

        Args:
            samples: iterable of sample dicts.
            field: key whose values are token lists (a single token string
                is also accepted).
        """
        for sample in samples:
            for token in self._as_tokens(sample[field]):
                if token is None:
                    continue  # skip missing values
                for mapped in self._map(token):
                    if mapped not in self.code_vocab:
                        self._append(mapped)

    def process(self, value: Any) -> torch.Tensor:
        """Process token value(s) into tensor of indices.

        Args:
            value: Raw token string or list of token strings. ``None``
                entries inside a list are skipped.

        Returns:
            1D ``torch.long`` tensor of indices. Tokens not in the
            vocabulary map to the ``<unk>`` index.
        """
        unk = self.code_vocab["<unk>"]
        indices = []
        for token in self._as_tokens(value):
            if token is None:
                continue  # skip missing values, consistent with fit()
            for mapped in self._map(token):
                indices.append(self.code_vocab.get(mapped, unk))
        return torch.tensor(indices, dtype=torch.long)

    def remove(self, tokens: set[str]):
        """Remove specified vocabularies from the processor.

        ``<pad>`` and ``<unk>`` are never removed. The remaining tokens are
        renumbered densely, keeping their relative order.
        """
        self._reindex(self.tokens() - set(tokens))

    def retain(self, tokens: set[str]):
        """Retain only the specified vocabularies in the processor.

        ``<pad>`` and ``<unk>`` are always retained. The remaining tokens are
        renumbered densely, keeping their relative order.
        """
        self._reindex(self.tokens() & set(tokens))

    def add(self, tokens: set[str]):
        """Add specified vocabularies to the processor.

        Tokens already present keep their index. New tokens are appended
        after the current last index.
        """
        for token in tokens:
            if token not in self.code_vocab:
                self._append(token)

    def tokens(self) -> set[str]:
        """Return the set of tokens in the processor's vocabulary."""
        return set(self.code_vocab.keys())

    def vocab_size(self) -> int:
        """Return the size of the processor's vocabulary."""
        return len(self.code_vocab)

    def size(self):
        """Return the vocabulary size (alias of ``vocab_size``)."""
        return self.vocab_size()

    def is_token(self) -> bool:
        """Sequence codes are discrete token indices."""
        return True

    def schema(self) -> tuple[str, ...]:
        """Name of the single output produced by ``process``: ``"value"``."""
        return ("value",)

    def dim(self) -> tuple[int, ...]:
        """Output is a 1D tensor of code indices."""
        return (1,)

    def spatial(self) -> tuple[bool, ...]:
        """Per-dimension spatial flags for the output (single dimension: True).

        NOTE(review): the meaning of "spatial" is defined by FeatureProcessor
        and is not visible here.
        """
        return (True,)

    def __repr__(self):
        return f"SequenceProcessor(code_vocab_size={len(self.code_vocab)})"