from functools import partial
from pathlib import Path
from typing import Any, List, Optional, Union
import torchvision.transforms as transforms
from PIL import Image
from . import register_processor
from .base_processor import FeatureProcessor
@register_processor("image")
class ImageProcessor(FeatureProcessor):
    """Feature processor for loading images from disk and converting them to tensors.

    The transform pipeline is assembled once at construction time and applies,
    in order: optional PIL mode conversion, optional resize to a square image,
    optional conversion to a tensor, and optional mean/std normalization.

    Args:
        image_size: Desired output image size (the image is resized to a
            square of ``image_size`` x ``image_size``). If None, no resizing
            is performed. Defaults to 224.
        to_tensor: Whether to convert the image to a tensor with
            ``torchvision.transforms.ToTensor``. Defaults to True.
        normalize: Whether to normalize the tensor with ``mean`` and ``std``
            using ``torchvision.transforms.Normalize``. Requires
            ``to_tensor=True`` as well as both ``mean`` and ``std``.
            Defaults to False.
        mean: Precomputed per-channel mean for normalization.
        std: Precomputed per-channel std for normalization.
        mode: PIL image mode to convert the image to before processing.
            Common modes: 'RGB', 'RGBA', 'L' (grayscale), 'P' (palette).
            If None, keeps the original mode. Defaults to None.

    Raises:
        ValueError: If normalization parameters are inconsistent: ``normalize``
            is set without both ``mean`` and ``std``, ``mean``/``std`` are
            given while ``normalize`` is False, or ``normalize`` is set while
            ``to_tensor`` is False.
    """

    def __init__(
        self,
        image_size: Optional[int] = 224,
        to_tensor: bool = True,
        normalize: bool = False,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        mode: Optional[str] = None,
    ) -> None:
        self.image_size = image_size
        self.to_tensor = to_tensor
        self.normalize = normalize
        self.mean = mean
        self.std = std
        self.mode = mode

        if self.normalize and (self.mean is None or self.std is None):
            raise ValueError(
                "Normalization requires both mean and std to be provided."
            )
        if not self.normalize and (self.mean is not None or self.std is not None):
            raise ValueError(
                "Mean and std are provided but normalize is set to False. "
                "Either provide normalize=True, or remove mean and std."
            )
        # Normalize only accepts tensors, so without ToTensor the pipeline
        # would fail on the first image; reject the combination up front.
        if self.normalize and not self.to_tensor:
            raise ValueError(
                "Normalization requires to_tensor=True because it operates "
                "on tensors, not PIL images."
            )

        self.transform = self._build_transform()

    def _build_transform(self) -> transforms.Compose:
        """Assemble the transform pipeline from the configured options.

        Returns:
            A ``transforms.Compose`` holding the enabled steps in the order:
            mode conversion, resize, tensor conversion, normalization.
        """
        transform_list = []
        if self.mode is not None:
            # A partial over a module-level function is used here rather than
            # an inline lambda.
            transform_list.append(
                transforms.Lambda(partial(_convert_mode, mode=self.mode))
            )
        if self.image_size is not None:
            transform_list.append(
                transforms.Resize((self.image_size, self.image_size))
            )
        if self.to_tensor:
            transform_list.append(transforms.ToTensor())
        if self.normalize:
            transform_list.append(
                transforms.Normalize(mean=self.mean, std=self.std)
            )
        return transforms.Compose(transform_list)

    def process(self, value: Union[str, Path]) -> Any:
        """Process a single image path into a transformed image.

        Args:
            value: Path to image file as string or Path object.

        Returns:
            The result of applying the configured transform pipeline to the
            loaded image (a tensor when ``to_tensor`` is True).

        Raises:
            FileNotFoundError: If the image file does not exist.
        """
        image_path = Path(value)
        if not image_path.exists():
            raise FileNotFoundError(f"Image file not found: {image_path}")
        with Image.open(image_path) as img:
            # Read the pixel data eagerly so the file handle can be released
            # when the ``with`` block exits.
            img.load()  # Avoid "too many open files" errors
            return self.transform(img)

    def is_token(self) -> bool:
        """Image data is continuous (float-valued pixel intensities), not discrete tokens.

        Returns:
            False.
        """
        return False

    def schema(self) -> tuple[str, ...]:
        """Single tensor output.

        Returns:
            ("value",)
        """
        return ("value",)

    def dim(self) -> tuple[int, ...]:
        """Output tensor has 3 dimensions: (C, H, W).

        Returns:
            (3,)
        """
        return (3,)

    def spatial(self) -> tuple[bool, ...]:
        """Spatial axes for the output tensor (C, H, W).

        Channels are not spatial; height and width are.

        Returns:
            (False, True, True)
        """
        return (False, True, True)

    def __repr__(self) -> str:
        # Derive the name from the class so the repr matches the actual type.
        return (
            f"{type(self).__name__}(image_size={self.image_size}, "
            f"to_tensor={self.to_tensor}, normalize={self.normalize}, "
            f"mean={self.mean}, std={self.std}, mode={self.mode})"
        )
def _convert_mode(img: Image.Image, mode: str) -> Image.Image:
    """Return ``img`` converted to the given PIL mode.

    Thin module-level wrapper around ``img.convert``; it is bound to a
    specific mode with ``functools.partial`` in
    ``ImageProcessor._build_transform``.

    Args:
        img: The PIL image to convert.
        mode: Target PIL mode string passed straight to ``img.convert``.

    Returns:
        The image returned by ``img.convert(mode)``.
    """
    converted = img.convert(mode)
    return converted