import logging
from typing import Any, Dict, Iterable
import torch
from . import register_processor
from .base_processor import FeatureProcessor
logger = logging.getLogger(__name__)
[docs]@register_processor("binary")
class BinaryLabelProcessor(FeatureProcessor):
"""
Processor for binary classification labels.
"""
def __init__(self):
super().__init__()
self.label_vocab: Dict[Any, int] = {}
[docs] def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
all_labels = set([sample[field] for sample in samples])
if len(all_labels) != 2:
raise ValueError(f"Expected 2 unique labels, got {len(all_labels)}")
if all_labels == {0, 1}:
self.label_vocab = {0: 0, 1: 1}
elif all_labels == {False, True}:
self.label_vocab = {False: 0, True: 1}
else:
all_labels = list(all_labels)
all_labels.sort()
self.label_vocab = {label: i for i, label in enumerate(all_labels)}
logger.info(f"Label {field} vocab: {self.label_vocab}")
[docs] def process(self, value: Any) -> torch.Tensor:
index = self.label_vocab[value]
return torch.tensor([index], dtype=torch.float32)
[docs] def size(self):
return 1
[docs] def is_token(self) -> bool:
"""Binary labels are continuous float targets for BCE loss."""
return False
[docs] def schema(self) -> tuple[str, ...]:
return ("value",)
[docs] def dim(self) -> tuple[int, ...]:
"""Output shape is (1,), so 1 dimension."""
return (1,)
[docs] def spatial(self) -> tuple[bool, ...]:
return (False,)
def __repr__(self):
return f"BinaryLabelProcessor(label_vocab_size={len(self.label_vocab)})"
[docs]@register_processor("multiclass")
class MultiClassLabelProcessor(FeatureProcessor):
"""
Processor for multi-class classification labels.
"""
def __init__(self):
super().__init__()
self.label_vocab: Dict[Any, int] = {}
[docs] def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
all_labels = set([sample[field] for sample in samples])
num_classes = len(all_labels)
if all_labels == set(range(num_classes)):
self.label_vocab = {i: i for i in range(num_classes)}
else:
all_labels = list(all_labels)
all_labels.sort()
self.label_vocab = {label: i for i, label in enumerate(all_labels)}
logger.info(f"Label {field} vocab: {self.label_vocab}")
[docs] def process(self, value: Any) -> torch.Tensor:
index = self.label_vocab[value]
return torch.tensor(index, dtype=torch.long)
[docs] def size(self):
return len(self.label_vocab)
[docs] def is_token(self) -> bool:
"""Multi-class labels are discrete token indices."""
return True
[docs] def schema(self) -> tuple[str, ...]:
return ("value",)
[docs] def dim(self) -> tuple[int, ...]:
"""Output is a scalar tensor (dim 0)."""
return (0,)
[docs] def spatial(self) -> tuple[bool, ...]:
return ()
def __repr__(self):
return f"MultiClassLabelProcessor(label_vocab_size={len(self.label_vocab)})"
[docs]@register_processor("multilabel")
class MultiLabelProcessor(FeatureProcessor):
"""
Processor for multi-label classification labels.
Args:
num_classes (int): Number of classes.
"""
def __init__(self):
super().__init__()
self.label_vocab: Dict[Any, int] = {}
[docs] def fit(self, samples: Iterable[Dict[str, Any]], field: str) -> None:
all_labels = set()
for sample in samples:
for label in sample[field]:
all_labels.add(label)
num_classes = len(all_labels)
if all_labels == set(range(num_classes)):
self.label_vocab = {i: i for i in range(num_classes)}
else:
all_labels = list(all_labels)
all_labels.sort()
self.label_vocab = {label: i for i, label in enumerate(all_labels)}
logger.info(f"Label {field} vocab: {self.label_vocab}")
[docs] def process(self, value: Any) -> torch.Tensor:
if not isinstance(value, list):
raise ValueError("Expected a list of labels for multilabel task.")
target = torch.zeros(len(self.label_vocab), dtype=torch.float32)
for label in value:
index = self.label_vocab[label]
target[index] = 1.0
return target
[docs] def size(self):
return len(self.label_vocab)
[docs] def is_token(self) -> bool:
"""Multi-label indicators are continuous float targets for BCE loss."""
return False
[docs] def schema(self) -> tuple[str, ...]:
return ("value",)
[docs] def dim(self) -> tuple[int, ...]:
"""Output shape is (num_classes,), so 1 dimension."""
return (1,)
[docs] def spatial(self) -> tuple[bool, ...]:
return (False,)
def __repr__(self):
return f"MultiLabelProcessor(label_vocab_size={len(self.label_vocab)})"
[docs]@register_processor("regression")
class RegressionLabelProcessor(FeatureProcessor):
"""
Processor for regression labels.
"""
[docs] def process(self, value: Any) -> torch.Tensor:
return torch.tensor([float(value)], dtype=torch.float32)
[docs] def size(self):
return 1
[docs] def is_token(self) -> bool:
"""Regression labels are continuous, not discrete tokens."""
return False
[docs] def schema(self) -> tuple[str, ...]:
return ("value",)
[docs] def dim(self) -> tuple[int, ...]:
"""Output shape is (1,), so 1 dimension."""
return (1,)
[docs] def spatial(self) -> tuple[bool, ...]:
return (False,)
def __repr__(self):
return "RegressionLabelProcessor()"