from collections.abc import Sequence
from pathlib import Path
import pickle
import shutil
import tempfile
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
import inspect
import random
from bisect import bisect_right
import litdata
from litdata.utilities.train_test_split import deepcopy_dataset
import copy
from ..processors import get_processor, IgnoreProcessor
from ..processors.base_processor import FeatureProcessor
class SampleBuilder:
"""Fit feature processors and transform pickled samples without materializing a dataset.
SampleBuilder is a lightweight helper used to:
- Fit feature processors from provided `input_schema` and `output_schema` on an
iterable of raw Python sample dictionaries.
- Build mappings from patient IDs and record IDs to sample indices.
- Transform pickled sample records into processed feature dictionaries using
the fitted processors.
Typical usage:
builder = SampleBuilder(input_schema, output_schema)
builder.fit(samples)
builder.save(path) # writes a schema.pkl metadata file
After saving the schema, `litdata.optimize` can be used with `builder.transform`
to serialize and chunk pickled sample items into a directory that can be
loaded via SampleDataset.
"""
def __init__(
self,
input_schema: Dict[str, Any],
output_schema: Dict[str, Any],
input_processors: Optional[Dict[str, FeatureProcessor]] = None,
output_processors: Optional[Dict[str, FeatureProcessor]] = None,
) -> None:
self.input_schema = input_schema
self.output_schema = output_schema
self._input_processors = (
input_processors if input_processors is not None else {}
)
self._output_processors = (
output_processors if output_processors is not None else {}
)
self._patient_to_index: Dict[str, List[int]] = {}
self._record_to_index: Dict[str, List[int]] = {}
self._fitted = False
@property
def input_processors(self) -> Dict[str, FeatureProcessor]:
if not self._fitted:
raise RuntimeError(
"SampleBuilder.fit must be called before accessing input_processors."
)
return self._input_processors
@property
def output_processors(self) -> Dict[str, FeatureProcessor]:
if not self._fitted:
raise RuntimeError(
"SampleBuilder.fit must be called before accessing output_processors."
)
return self._output_processors
@property
def patient_to_index(self) -> Dict[str, List[int]]:
if not self._fitted:
raise RuntimeError(
"SampleBuilder.fit must be called before accessing patient_to_index."
)
return self._patient_to_index
@property
def record_to_index(self) -> Dict[str, List[int]]:
if not self._fitted:
raise RuntimeError(
"SampleBuilder.fit must be called before accessing record_to_index."
)
return self._record_to_index
def _get_processor_instance(self, processor_spec):
"""Instantiate a processor using the same resolution logic as SampleDataset."""
if isinstance(processor_spec, tuple):
spec, kwargs = processor_spec
if isinstance(spec, str):
return get_processor(spec)(**kwargs)
if inspect.isclass(spec) and issubclass(spec, FeatureProcessor):
return spec(**kwargs)
raise ValueError(
"Processor spec in tuple must be either a string alias or a "
f"FeatureProcessor class, got {type(spec)}"
)
if isinstance(processor_spec, str):
return get_processor(processor_spec)()
if inspect.isclass(processor_spec) and issubclass(
processor_spec, FeatureProcessor
):
return processor_spec()
if isinstance(processor_spec, FeatureProcessor):
return processor_spec
raise ValueError(
"Processor spec must be either a string alias, a FeatureProcessor "
f"class, or a tuple (spec, kwargs_dict), got {type(processor_spec)}"
)
def fit(
self,
samples: Iterable[Dict[str, Any]],
) -> None:
"""Fit processors and build mapping indices from an iterator of samples.
Args:
samples: Iterable of sample dictionaries (e.g., python dicts). Each
sample should contain keys covering both the configured
`input_schema` and `output_schema`. These samples are not
required to be pickled; `fit` operates on in-memory dicts.
Behavior:
- Validates the samples contain all keys specified by the input
and output schemas.
- Builds `patient_to_index` and `record_to_index` mappings by
recording the sample indices associated with `patient_id` and
`record_id`/`visit_id` fields.
- Instantiates and fits input/output processors from the provided
schemas (unless pre-fitted processors were supplied to the
constructor).
"""
# Validate the samples
input_keys = set(self.input_schema.keys())
output_keys = set(self.output_schema.keys())
for sample in samples:
assert input_keys.issubset(
sample.keys()
), "Input schema does not match samples."
assert output_keys.issubset(
sample.keys()
), "Output schema does not match samples."
# Build index mappings
self._patient_to_index = {}
self._record_to_index = {}
for i, sample in enumerate(samples):
patient_id = sample.get("patient_id")
if patient_id is not None:
self._patient_to_index.setdefault(patient_id, []).append(i)
record_id = sample.get("record_id", sample.get("visit_id"))
if record_id is not None:
self._record_to_index.setdefault(record_id, []).append(i)
# Fit processors if they were not provided
if not self._input_processors:
for key, spec in self.input_schema.items():
processor = self._get_processor_instance(spec)
processor.fit(samples, key)
self._input_processors[key] = processor
if not self._output_processors:
for key, spec in self.output_schema.items():
processor = self._get_processor_instance(spec)
processor.fit(samples, key)
self._output_processors[key] = processor
self._fitted = True
def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]:
"""Transform a single serialized (pickled) sample using fitted processors.
Args:
sample: A mapping with a single key `"sample"` whose value is a
pickled Python dictionary (produced by `pickle.dumps`). The
pickled dictionary should mirror the schema that was used to
fit this builder.
Returns:
A Python dictionary where each key is either an input or output
feature name. Values for keys present in the corresponding fitted
processors have been processed through their FeatureProcessor and
are returned as the output of that processor. Keys not covered by
the input/output processors are returned unchanged.
"""
if not self._fitted:
raise RuntimeError("SampleBuilder.fit must be called before transform().")
transformed: Dict[str, Any] = {}
for key, value in pickle.loads(sample["sample"]).items():
if key in self._input_processors:
# Skip ignored features
if isinstance(self._input_processors[key], IgnoreProcessor):
continue
transformed[key] = self._input_processors[key].process(value)
elif key in self._output_processors:
# Skip ignored features
if isinstance(self._output_processors[key], IgnoreProcessor):
continue
transformed[key] = self._output_processors[key].process(value)
else:
transformed[key] = value
return transformed
def save(self, path: str) -> None:
"""Save fitted metadata to the given path as a pickled file.
Args:
path: Location where the builder will write a pickled metadata file
(commonly named `schema.pkl`). The saved metadata contains
the fitted input/output schemas, processors, and index
mappings. This file is read by `SampleDataset` during
construction.
"""
if not self._fitted:
raise RuntimeError("SampleBuilder.fit must be called before save().")
metadata = {
"input_schema": self.input_schema,
"output_schema": self.output_schema,
"input_processors": self._input_processors,
"output_processors": self._output_processors,
"patient_to_index": self._patient_to_index,
"record_to_index": self._record_to_index,
}
with open(path, "wb") as f:
pickle.dump(metadata, f)
@staticmethod
def load(path: str) -> "SampleBuilder":
"""Load a SampleBuilder from a pickled metadata file.
Args:
path: Location of the pickled metadata file (commonly named `schema.pkl`).
Returns:
A SampleBuilder instance with loaded metadata.
"""
with open(path, "rb") as f:
metadata = pickle.load(f)
builder = SampleBuilder(
input_schema=metadata["input_schema"],
output_schema=metadata["output_schema"],
)
builder._input_processors = metadata["input_processors"]
builder._output_processors = metadata["output_processors"]
builder._patient_to_index = metadata["patient_to_index"]
builder._record_to_index = metadata["record_to_index"]
builder._fitted = True
return builder
[docs]class SampleDataset(litdata.StreamingDataset):
"""A streaming dataset that loads sample metadata and processors from disk.
SampleDataset expects the `path` directory to contain a `schema.pkl`
file created by a `SampleBuilder.save(...)` call. The `schema.pkl` must
include the fitted `input_schema`, `output_schema`, `input_processors`,
`output_processors`, `patient_to_index` and `record_to_index` mappings.
Attributes:
input_schema: The configuration used to instantiate processors for
input features (string aliases or processor specs).
output_schema: The configuration used to instantiate processors for
output features.
input_processors: A mapping of input feature names to fitted
FeatureProcessor instances.
output_processors: A mapping of output feature names to fitted
FeatureProcessor instances.
patient_to_index: Dictionary mapping patient IDs to the list of
sample indices associated with that patient.
record_to_index: Dictionary mapping record/visit IDs to the list of
sample indices associated with that record.
dataset_name: Optional human friendly dataset name.
task_name: Optional human friendly task name.
"""
def __init__(
self,
path: str,
dataset_name: Optional[str] = None,
task_name: Optional[str] = None,
**kwargs,
) -> None:
"""Initialize a SampleDataset pointing at a directory created by SampleBuilder.
Args:
path: Path to a directory containing a `schema.pkl` produced by
`SampleBuilder.save` and associated pickled sample files.
dataset_name: Optional human-friendly dataset name.
task_name: Optional human-friendly task name.
**kwargs: Extra keyword arguments forwarded to
`litdata.StreamingDataset` (such as streaming options).
"""
super().__init__(path, **kwargs)
self.path = path
self.dataset_name = "" if dataset_name is None else dataset_name
self.task_name = "" if task_name is None else task_name
with open(f"{path}/schema.pkl", "rb") as f:
metadata = pickle.load(f)
self.input_schema: dict[str, Any] = metadata["input_schema"]
self.output_schema: dict[str, Any] = metadata["output_schema"]
self.input_processors: dict[str, FeatureProcessor] = metadata["input_processors"]
self.output_processors: dict[str, FeatureProcessor]= metadata["output_processors"]
self._remove_ignored_processors()
self.patient_to_index = metadata["patient_to_index"]
self.record_to_index = metadata["record_to_index"]
def _remove_ignored_processors(self):
"""Remove any processors that are IgnoreProcessor instances."""
for key in [
key
for key, proc in self.input_processors.items()
if isinstance(proc, IgnoreProcessor)
]:
del self.input_processors[key]
del self.input_schema[key]
for key in [
key
for key, proc in self.output_processors.items()
if isinstance(proc, IgnoreProcessor)
]:
del self.output_processors[key]
del self.output_schema[key]
def __str__(self) -> str:
"""Returns a string representation of the dataset.
Returns:
str: A string with the dataset and task names.
"""
return f"Sample dataset {self.dataset_name} {self.task_name}"
[docs] def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
"""Create a StreamingDataset restricted to the provided indices."""
new_dataset = deepcopy_dataset(self)
if len(new_dataset.subsampled_files) != len(new_dataset.region_of_interest):
raise ValueError(
"The provided dataset has mismatched subsampled_files and region_of_interest lengths."
)
dataset_length = sum(
end - start for start, end in new_dataset.region_of_interest
)
if isinstance(indices, slice):
indices = range(*indices.indices(dataset_length))
if any(idx < 0 or idx >= dataset_length for idx in indices):
raise ValueError(
f"Subset indices must be in [0, {dataset_length - 1}] for the provided dataset."
)
# Build chunk boundaries so we can translate global indices into
# chunk-local (start, end) pairs that litdata understands.
chunk_starts: List[int] = []
chunk_boundaries: List[Tuple[str, int, int, int, int]] = []
cursor = 0
for filename, (roi_start, roi_end) in zip(
new_dataset.subsampled_files, new_dataset.region_of_interest
):
chunk_len = roi_end - roi_start
if chunk_len <= 0:
continue
chunk_starts.append(cursor)
chunk_boundaries.append(
(filename, roi_start, roi_end, cursor, cursor + chunk_len)
)
cursor += chunk_len
new_subsampled_files: List[str] = []
new_roi: List[Tuple[int, int]] = []
prev_chunk_idx: Optional[int] = None
for idx in indices:
chunk_idx = bisect_right(chunk_starts, idx) - 1
if chunk_idx < 0 or idx >= chunk_boundaries[chunk_idx][4]:
raise ValueError(f"Index {idx} is out of bounds for the dataset.")
filename, roi_start, _, global_start, _ = chunk_boundaries[chunk_idx]
offset_in_chunk = roi_start + (idx - global_start)
if (
new_roi
and prev_chunk_idx == chunk_idx
and offset_in_chunk == new_roi[-1][1]
):
new_roi[-1] = (new_roi[-1][0], new_roi[-1][1] + 1)
else:
new_subsampled_files.append(filename)
new_roi.append((offset_in_chunk, offset_in_chunk + 1))
prev_chunk_idx = chunk_idx
new_dataset.subsampled_files = new_subsampled_files
new_dataset.region_of_interest = new_roi
new_dataset.reset()
return new_dataset
[docs] def close(self) -> None:
"""Cleans up any temporary directories used by the dataset."""
if self.input_dir.path is not None and Path(self.input_dir.path).exists():
shutil.rmtree(self.input_dir.path)
# --------------------------------------------------------------
# Context manager support
# --------------------------------------------------------------
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
self.close()
class InMemorySampleDataset(SampleDataset):
"""A SampleDataset that loads all samples into memory for fast access.
InMemorySampleDataset extends SampleDataset by eagerly loading and
transforming all samples into memory during initialization. This allows
for fast, repeated access to samples without disk I/O, at the cost of
higher memory usage.
Note:
This class is intended for testing and debugging purposes where
dataset sizes are small enough to fit into memory.
"""
def __init__(
self,
samples: List[Dict[str, Any]],
input_schema: Dict[str, Any],
output_schema: Dict[str, Any],
dataset_name: Optional[str] = None,
task_name: Optional[str] = None,
input_processors: Optional[Dict[str, FeatureProcessor]] = None,
output_processors: Optional[Dict[str, FeatureProcessor]] = None,
) -> None:
"""Initialize an InMemorySampleDataset from in-memory samples.
This constructor fits a SampleBuilder on the provided samples,
transforms all samples into memory, and sets up the dataset attributes.
Args:
samples: A list of sample dictionaries (in-memory).
input_schema: Schema describing how input keys should be handled.
output_schema: Schema describing how output keys should be handled.
dataset_name: Optional human-friendly dataset name.
task_name: Optional human-friendly task name.
input_processors: Optional pre-fitted input processors to use instead
of creating new ones from the input_schema.
output_processors: Optional pre-fitted output processors to use
instead of creating new ones from the output_schema.
"""
builder = SampleBuilder(
input_schema=input_schema,
output_schema=output_schema,
input_processors=input_processors,
output_processors=output_processors,
)
builder.fit(samples)
self.dataset_name = "" if dataset_name is None else dataset_name
self.task_name = "" if task_name is None else task_name
self.input_schema = builder.input_schema
self.output_schema = builder.output_schema
self.input_processors = builder.input_processors
self.output_processors = builder.output_processors
self._remove_ignored_processors()
self.patient_to_index = builder.patient_to_index
self.record_to_index = builder.record_to_index
self._data = [builder.transform({"sample": pickle.dumps(s)}) for s in samples]
self._shuffle = False
def set_shuffle(self, shuffle: bool) -> None:
self._shuffle = shuffle
def __len__(self) -> int:
"""Returns the number of samples in the dataset.
Returns:
int: The total number of samples.
"""
return len(self._data)
def __getitem__(self, index: int) -> Dict[str, Any]: # type: ignore
"""Retrieve a processed sample by index.
Args:
index: The index of the sample to retrieve.
Returns:
A dictionary containing processed input and output features.
"""
return self._data[index]
def __iter__(self) -> Iterable[Dict[str, Any]]: # type: ignore
"""Returns an iterator over all samples in the dataset.
Returns:
An iterator yielding processed sample dictionaries.
"""
if self._shuffle:
shuffled_data = self._data[:]
random.shuffle(shuffled_data)
return iter(shuffled_data)
else:
return iter(self._data)
def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset:
if isinstance(indices, slice):
samples = self._data[indices]
else:
samples = [self._data[i] for i in indices]
new_dataset = copy.deepcopy(self)
new_dataset._data = samples
return new_dataset
def close(self) -> None:
pass # No temporary directories to clean up for in-memory dataset
def create_sample_dataset(
samples: List[Dict[str, Any]],
input_schema: Dict[str, Any],
output_schema: Dict[str, Any],
dataset_name: Optional[str] = None,
task_name: Optional[str] = None,
input_processors: Optional[Dict[str, FeatureProcessor]] = None,
output_processors: Optional[Dict[str, FeatureProcessor]] = None,
in_memory: bool = True,
):
"""Convenience helper to create an on-disk SampleDataset from in-memory samples.
This helper will:
- Create a temporary directory for the dataset output.
- Fit a `SampleBuilder` with the provided schemas and samples.
- Save the fitted `schema.pkl` to the temporary directory.
- Use `litdata.optimize` with `builder.transform` to write serialized
and chunked sample files into the directory.
- Return a `SampleDataset` instance pointed at the temporary directory.
Args:
samples: A list of Python dictionaries representing raw samples.
input_schema: Schema describing how input keys should be handled.
output_schema: Schema describing how output keys should be handled.
dataset_name: Optional dataset name to attach to the returned
SampleDataset instance.
task_name: Optional task name to attach to the returned SampleDataset
instance.
input_processors: Optional pre-fitted input processors to use instead
of creating new ones from the input_schema.
output_processors: Optional pre-fitted output processors to use
instead of creating new ones from the output_schema.
in_memory: If True, returns an InMemorySampleDataset instead of
a disk-backed SampleDataset.
Returns:
An instance of `SampleDataset` loaded from the temporary directory
containing the optimized, chunked samples and `schema.pkl` metadata.
"""
if in_memory:
return InMemorySampleDataset(
samples=samples,
input_schema=input_schema,
output_schema=output_schema,
dataset_name=dataset_name,
task_name=task_name,
input_processors=input_processors,
output_processors=output_processors,
)
else:
path = Path(tempfile.mkdtemp())
builder = SampleBuilder(
input_schema=input_schema, # type: ignore
output_schema=output_schema, # type: ignore
input_processors=input_processors,
output_processors=output_processors,
)
builder.fit(samples)
builder.save(str(path / "schema.pkl"))
litdata.optimize(
fn=builder.transform,
inputs=[{"sample": pickle.dumps(x)} for x in samples],
output_dir=str(path),
chunk_bytes="64MB",
num_workers=0,
)
return SampleDataset(
path=str(path),
dataset_name=dataset_name,
task_name=task_name,
)