import logging
import os
import pickle
from abc import ABC
from pathlib import Path
from typing import Dict, Iterator, Iterable, List, Optional, Any, Callable
import functools
import operator
from urllib.parse import urlparse, urlunparse
from urllib.request import urlretrieve
import json
import uuid
import platformdirs
import multiprocessing
import multiprocessing.queues
import shutil
from filelock import FileLock
import litdata
from litdata.streaming.item_loader import ParquetLoader
from litdata.processing.data_processor import in_notebook
from litdata.streaming.writer import BinaryWriter
import pyarrow as pa
import pyarrow.csv as pv
import pyarrow.parquet as pq
import pandas as pd
import polars as pl
import requests
from tqdm import tqdm
import dask.dataframe as dd
from dask.distributed import (
Client as DaskClient,
LocalCluster as DaskCluster,
progress as dask_progress,
)
import narwhals as nw
import itertools
import numpy as np
import more_itertools
from ..data import Patient
from ..tasks import BaseTask
from ..processors.base_processor import FeatureProcessor
from .configs import load_yaml_config
from .sample_dataset import SampleDataset, SampleBuilder
from ..utils import set_env
# Set logging level for distributed to ERROR to reduce verbosity
logging.getLogger("distributed").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)
# Remove LitData version check to avoid unnecessary warnings
os.environ["LITDATA_DISABLE_VERSION_CHECK"] = "1"
def is_url(path: str) -> bool:
"""URL detection."""
result = urlparse(path)
# Both scheme and netloc must be present for a valid URL
return all([result.scheme, result.netloc])
def clean_path(path: str) -> str:
"""Clean a path string."""
if is_url(path):
parsed = urlparse(path)
cleaned_path = os.path.normpath(parsed.path)
# Rebuild the full URL
return urlunparse(parsed._replace(path=cleaned_path))
else:
# It's a local path — resolve and normalize
return str(Path(path).expanduser().resolve())
def path_exists(path: str) -> bool:
"""
Check if a path exists.
If the path is a URL, it will send a HEAD request.
If the path is a local file, it will use the Path.exists().
"""
if is_url(path):
try:
response = requests.head(path, timeout=5)
return response.status_code == 200
except requests.RequestException:
return False
else:
return Path(path).exists()
def _csv_tsv_gz_path(path: str) -> str:
"""
Get the path to the file, trying the original path first, then the alternative path
by switching between .csv.gz, .csv, .tsv.gz, and .tsv extensions.
Args:
path (str): Original file path.
Returns:
str: The file path that exists.
Raises:
FileNotFoundError: If neither the original nor the alternative path exists.
ValueError: If the path does not have an expected extension.
"""
if path_exists(path):
return path
if path.endswith(".csv.gz"):
alt_path = path[:-3] # Remove .gz -> try .csv
elif path.endswith(".csv"):
alt_path = f"{path}.gz" # Add .gz -> try .csv.gz
elif path.endswith(".tsv.gz"):
alt_path = path[:-3] # Remove .gz -> try .tsv
elif path.endswith(".tsv"):
alt_path = f"{path}.gz" # Add .gz -> try .tsv.gz
else:
raise ValueError(f"Path does not have expected extension: {path}")
if path_exists(alt_path):
return alt_path
raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
def _litdata_merge(cache_dir: Path) -> None:
"""
Merges LitData binary writer index files in the given cache directory.
Args:
cache_dir (Path): The cache directory containing LitData binary writer files.
"""
from litdata.streaming.writer import _INDEX_FILENAME
files = os.listdir(cache_dir)
# Return if the index already exists
if _INDEX_FILENAME in files:
return
index_files = [f for f in files if f.endswith(_INDEX_FILENAME)]
# Return if there are no index files to merge
if len(index_files) == 0:
raise ValueError(
"There are zero samples in the dataset, please check the task and processors."
)
BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(
num_workers=len(index_files)
)
class _ProgressContext:
def __init__(
self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs
):
"""
:param queue: An existing queue (e.g., from multiprocessing). If provided,
this class acts as a passthrough.
:param total: Total items for the progress bar (only used if queue is None).
:param kwargs: Extra arguments for tqdm (e.g., desc="Processing").
"""
self.queue = queue
self.total = total
self.kwargs = kwargs
self.progress = None
def put(self, n):
if self.progress:
self.progress.update(n)
def __enter__(self):
if self.queue:
return self.queue
self.progress = tqdm(total=self.total, **self.kwargs)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.progress:
self.progress.close()
_task_transform_progress: multiprocessing.queues.Queue | None = None
def _task_transform_init(queue: multiprocessing.queues.Queue) -> None:
"""
Initializer for worker processes to set up a global queue.
Args:
queue (multiprocessing.queues.Queue): The queue for progress tracking.
"""
global _task_transform_progress
_task_transform_progress = queue
def _task_transform_fn(
args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path],
) -> None:
"""
Worker function to apply task transformation on a chunk of patients.
Args:
args (tuple): A tuple containing:
worker_id (int): The ID of the worker.
task (BaseTask): The task to apply.
patient_ids (Iterable[str]): The patient IDs to process.
global_event_df (pl.LazyFrame): The global event dataframe.
output_dir (Path): The output directory to save results.
"""
BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%.
worker_id, task, patient_ids, global_event_df, output_dir = args
total_patients = len(list(patient_ids))
logger.info(
f"Worker {worker_id} started processing {total_patients} patients. (Polars threads: {pl.thread_pool_size()})"
)
with (
set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)),
_ProgressContext(_task_transform_progress, total=total_patients) as progress,
):
writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
write_index = 0
batches = itertools.batched(patient_ids, BATCH_SIZE)
for batch in batches:
complete = 0
patients = (
global_event_df.filter(pl.col("patient_id").is_in(batch))
.collect(engine="streaming")
.partition_by("patient_id", as_dict=True)
)
for patient_id, patient_df in patients.items():
patient_id = patient_id[0] # Extract string from single-element list
patient = Patient(patient_id=patient_id, data_source=patient_df)
for sample in task(patient):
writer.add_item(write_index, {"sample": pickle.dumps(sample)})
write_index += 1
complete += 1
progress.put(complete)
writer.done()
logger.info(f"Worker {worker_id} finished processing patients.")
_proc_transform_progress: multiprocessing.queues.Queue | None = None
def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None:
"""
Initializer for worker processes to set up a global queue.
Args:
queue (multiprocessing.queues.Queue): The queue for progress tracking.
"""
global _proc_transform_progress
_proc_transform_progress = queue
def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
"""
Worker function to apply processors on a chunk of samples.
Args:
args (tuple): A tuple containing:
worker_id (int): The ID of the worker.
task_df (Path): The path to the task dataframe.
start_idx (int): The start index of samples to process.
end_idx (int): The end index of samples to process.
output_dir (Path): The output directory to save results.
"""
BATCH_SIZE = 128
worker_id, task_df, start_idx, end_idx, output_dir = args
total_samples = end_idx - start_idx
logger.info(
f"Worker {worker_id} started processing {total_samples} samples. ({start_idx} to {end_idx})"
)
with (
set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)),
_ProgressContext(_proc_transform_progress, total=total_samples) as progress,
):
writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB")
dataset = litdata.StreamingDataset(str(task_df))
builder = SampleBuilder.load(f"{output_dir}/schema.pkl")
complete = 0
write_index = 0
for i in range(start_idx, end_idx):
transformed: Dict[str, Any] = builder.transform(dataset[i])
writer.add_item(write_index, transformed)
write_index += 1
complete += 1
if complete >= BATCH_SIZE:
progress.put(complete)
complete = 0
if complete > 0:
progress.put(complete)
writer.done()
logger.info(f"Worker {worker_id} finished processing samples.")
[docs]class BaseDataset(ABC):
"""Abstract base class for all PyHealth datasets.
Attributes:
root (Path): The root directory where dataset files are stored.
tables (List[str]): List of table names to load.
dataset_name (str): Name of the dataset.
config (dict): Configuration loaded from a YAML file.
global_event_df (pl.LazyFrame): The global event data frame.
dev (bool): Whether to enable dev mode (limit to 1000 patients).
"""
def __init__(
self,
root: str,
tables: List[str],
dataset_name: Optional[str] = None,
config_path: Optional[str] = None,
cache_dir: str | Path | None = None,
num_workers: int = 1,
dev: bool = False,
):
"""Initializes the BaseDataset.
Args:
root (str): The root directory where dataset files are stored.
tables (List[str]): List of table names to load.
dataset_name (Optional[str]): Name of the dataset. Defaults to class name.
config_path (Optional[str]): Path to the configuration YAML file.
cache_dir (Optional[str | Path]): Directory for caching processed data.
Behavior depends on the type passed:
- **None** (default): Auto-generates a cache path under the default
pyhealth cache directory.
- **str** or **Path**: Used as the root cache directory path. A UUID
is appended to the provided path to capture dataset configuration.
num_workers (int): Number of worker processes for parallel operations.
dev (bool): Whether to run in dev mode (limits to 1000 patients).
"""
if len(set(tables)) != len(tables):
logger.warning("Duplicate table names in tables list. Removing duplicates.")
tables = list(set(tables))
self.root = root
self.tables = tables
self.dataset_name = dataset_name or self.__class__.__name__
self.num_workers = num_workers
self.dev = dev
self.config = load_yaml_config(config_path) if config_path else None
logger.info(
f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})"
)
# Cached attributes
self.cache_dir = self._init_cache_dir(cache_dir)
self._global_event_df = None
self._unique_patient_ids = None
def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
"""Returns the cache directory path.
The cache directory is determined by the type of ``cache_dir`` passed
to ``__init__``:
- **None**: Auto-generated under default pyhealth cache directory.
- **str** or **Path: Used as the root cache directory path. A UUID
is appended to the provided path to capture dataset configuration.
The cache structure within the directory is::
{dataset_uuid}/ # Cache files for this dataset configuration
tmp/ # Temporary files during processing
global_event_df.parquet/ # Cached global event dataframe
tasks/ # Cached task-specific data
{task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args
task_df.ld/ # Intermediate task dataframe based on schema
samples_{proc_uuid}.ld/ # Final processed samples after applying processors
Returns:
Path: The resolved cache directory path.
"""
id_str = json.dumps(
{
"root": str(self.root),
"tables": sorted(self.tables),
"dataset_name": self.dataset_name,
"dev": self.dev,
},
sort_keys=True,
)
id = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_str))
if cache_dir is None:
cache_dir = Path(platformdirs.user_cache_dir(appname="pyhealth")) / id
cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}")
else:
# Ensure separate cache directories for different table configurations by appending a UUID suffix
cache_dir = Path(cache_dir) / id
cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Using provided cache_dir: {cache_dir}")
return Path(cache_dir)
[docs] def create_tmpdir(self) -> Path:
"""Creates and returns a new temporary directory within the cache.
Returns:
Path: The path to the new temporary directory.
"""
tmp_dir = self.cache_dir / "tmp" / str(uuid.uuid4())
tmp_dir.mkdir(parents=True, exist_ok=True)
return tmp_dir
[docs] def clean_tmpdir(self) -> None:
"""Cleans up the temporary directory within the cache."""
tmp_dir = self.cache_dir / "tmp"
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
def _scan_csv_tsv_gz(self, source_path: str) -> dd.DataFrame:
"""Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame.
If the cached Parquet file does not exist, it converts the source CSV/TSV file
to Parquet and saves it to the cache.
Args:
source_path (str): The source CSV/TSV file path.
Returns:
dd.DataFrame: The Dask DataFrame loaded from the cached Parquet file.
Raises:
FileNotFoundError: If source_path is None and the cached Parquet file does not exist;
or if neither the original nor the alternative path of source_path exists.
ValueError: If the path does not have an expected extension.
"""
# Ensure the tables cache directory exists
ret_path = self.create_tmpdir() / "table.parquet"
if not ret_path.exists():
source_path = _csv_tsv_gz_path(source_path)
if is_url(source_path):
local_filename = os.path.basename(source_path)
local_path = self.create_tmpdir() / local_filename
if not local_path.exists():
logger.info(f"Downloading {source_path} to {local_path}")
urlretrieve(source_path, local_path)
source_path = str(local_path)
# Determine delimiter based on file extension
delimiter = (
"\t"
if source_path.endswith(".tsv") or source_path.endswith(".tsv.gz")
else ","
)
# Always infer schema as string to avoid incorrect type inference
# Enable newlines_in_values for clinical notes with multi-line text
schema_reader = pv.open_csv(
source_path,
read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB
parse_options=pv.ParseOptions(
delimiter=delimiter, newlines_in_values=True
),
)
schema = pa.schema(
[pa.field(name, pa.string()) for name in schema_reader.schema.names]
)
# Convert CSV/TSV to Parquet
csv_reader = pv.open_csv(
source_path,
read_options=pv.ReadOptions(block_size=1 << 26), # 64 MB
parse_options=pv.ParseOptions(
delimiter=delimiter, newlines_in_values=True
),
convert_options=pv.ConvertOptions(column_types=schema),
)
with pq.ParquetWriter(ret_path, csv_reader.schema) as writer:
for batch in csv_reader:
writer.write_batch(batch)
df: dd.DataFrame = dd.read_parquet(
ret_path,
split_row_groups=True, # type: ignore
blocksize="64MB",
)
return df.replace("", pd.NA) # Replace empty strings with NaN
def _event_transform(self, output_dir: Path) -> None:
compute_ok = False
try:
df = self.load_data()
with DaskCluster(
n_workers=self.num_workers,
threads_per_worker=1,
processes=not in_notebook(),
# Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory
local_directory=str(self.create_tmpdir()),
) as cluster:
with DaskClient(cluster) as client:
if self.dev:
logger.info("Dev mode enabled: limiting to 1000 patients")
patients = df["patient_id"].unique().head(1000).tolist()
filter = df["patient_id"].isin(patients)
df = df[filter]
logger.info(f"Caching event dataframe to {output_dir}...")
collection = df.sort_values("patient_id").to_parquet(
output_dir,
write_index=False,
compute=False,
)
handle = client.compute(collection)
dask_progress(handle)
handle.result() # type: ignore
compute_ok = True # Data is fully written to disk
except TimeoutError:
if compute_ok:
# Cluster shutdown timed out after successful compute — data is intact
logger.warning(
"Dask cluster shutdown timed out, but data was written successfully. Continuing."
)
else:
if output_dir.exists():
logger.error(
f"Error during caching, removing incomplete file {output_dir}"
)
shutil.rmtree(output_dir)
raise
except Exception as e:
if output_dir.exists():
logger.error(
f"Error during caching, removing incomplete file {output_dir}"
)
shutil.rmtree(output_dir)
raise e
finally:
self.clean_tmpdir()
@property
def global_event_df(self) -> pl.LazyFrame:
"""Returns the path to the cached event dataframe.
Returns:
Path: The path to the cached event dataframe.
"""
self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore
if self._global_event_df is None:
ret_path = self.cache_dir / "global_event_df.parquet"
cache_valid = ret_path.is_dir() and any(ret_path.glob("*.parquet"))
if not cache_valid:
if ret_path.exists():
logger.warning(
f"Incomplete parquet cache at {ret_path} (directory exists but contains no parquet files). Removing and rebuilding."
)
shutil.rmtree(ret_path)
logger.info(f"No cached event dataframe found. Creating: {ret_path}")
self._event_transform(ret_path)
else:
logger.info(f"Found cached event dataframe: {ret_path}")
self._global_event_df = ret_path
return pl.scan_parquet(
self._global_event_df,
low_memory=True,
)
[docs] def load_data(self) -> dd.DataFrame:
"""Loads data from the specified tables.
Returns:
dd.DataFrame: A concatenated lazy frame of all tables.
"""
frames = [self.load_table(table.lower()) for table in self.tables]
return dd.concat(frames, axis=0, join="outer")
[docs] def load_table(self, table_name: str) -> dd.DataFrame:
"""Loads a table and processes joins if specified.
Args:
table_name (str): The name of the table to load.
Returns:
dd.DataFrame: The processed Dask dataframe for the table.
Raises:
ValueError: If the table is not found in the config.
FileNotFoundError: If the CSV file for the table or join is not found.
"""
assert self.config is not None, "Config must be provided to load tables"
if table_name not in self.config.tables:
raise ValueError(f"Table {table_name} not found in config")
table_cfg = self.config.tables[table_name]
csv_path = f"{self.root}/{table_cfg.file_path}"
csv_path = clean_path(csv_path)
logger.info(f"Scanning table: {table_name} from {csv_path}")
df = self._scan_csv_tsv_gz(csv_path)
# Convert column names to lowercase before calling preprocess_func
df = df.rename(columns=str.lower)
# Check if there is a preprocessing function for this table
preprocess_func: Optional[Callable[[nw.LazyFrame], nw.LazyFrame]]
preprocess_func = getattr(self, f"preprocess_{table_name}", None)
if preprocess_func is not None:
logger.info(
f"Preprocessing table: {table_name} with {preprocess_func.__name__}"
)
df = preprocess_func(nw.from_native(df)).to_native() # type: ignore
# Handle joins
for join_cfg in table_cfg.join:
other_csv_path = f"{self.root}/{join_cfg.file_path}"
other_csv_path = clean_path(other_csv_path)
logger.info(f"Joining with table: {other_csv_path}")
join_df = self._scan_csv_tsv_gz(other_csv_path)
join_df = join_df.rename(columns=str.lower)
join_key = join_cfg.on
columns = join_cfg.columns
how = join_cfg.how
df: dd.DataFrame = df.merge(
join_df[[join_key] + columns], on=join_key, how=how
)
patient_id_col = table_cfg.patient_id
timestamp_col = table_cfg.timestamp
timestamp_format = table_cfg.timestamp_format
attribute_cols = table_cfg.attributes
# Timestamp expression
# .astype(str) will convert `pd.NA` to "<NA>", which will raise error in to_datetime
# use .astype("string") instead, which keeps `pd.NA` as is.
if timestamp_col:
if isinstance(timestamp_col, list):
# Concatenate all timestamp parts in order with no separator
timestamp_series: dd.Series = functools.reduce(
operator.add, (df[col].astype("string") for col in timestamp_col)
)
else:
timestamp_series: dd.Series = df[timestamp_col].astype("string")
timestamp_series: dd.Series = dd.to_datetime(
timestamp_series,
format=timestamp_format,
errors="raise",
)
df: dd.DataFrame = df.assign(
timestamp=timestamp_series.astype("datetime64[ms]")
)
else:
df: dd.DataFrame = df.assign(timestamp=pd.NaT)
# If patient_id_col is None, use row index as patient_id
if patient_id_col:
df: dd.DataFrame = df.assign(patient_id=df[patient_id_col].astype("string"))
else:
df: dd.DataFrame = df.reset_index(drop=True)
df: dd.DataFrame = df.assign(patient_id=df.index.astype("string"))
df: dd.DataFrame = df.assign(event_type=table_name)
rename_attr = {attr.lower(): f"{table_name}/{attr}" for attr in attribute_cols}
df: dd.DataFrame = df.rename(columns=rename_attr)
attr_cols = [rename_attr[attr.lower()] for attr in attribute_cols]
final_cols = ["patient_id", "event_type", "timestamp"] + attr_cols
event_frame = df[final_cols]
return event_frame
@property
def unique_patient_ids(self) -> List[str]:
"""Returns a list of unique patient IDs.
Returns:
List[str]: List of unique patient IDs.
"""
if self._unique_patient_ids is None:
self._unique_patient_ids = (
self.global_event_df.select("patient_id")
.unique()
.collect(engine="streaming")
.to_series()
.to_list()
)
logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs")
return self._unique_patient_ids
[docs] def get_patient(self, patient_id: str) -> Patient:
"""Retrieves a Patient object for the given patient ID.
Args:
patient_id (str): The ID of the patient to retrieve.
Returns:
Patient: The Patient object for the given ID.
Raises:
AssertionError: If the patient ID is not found in the dataset.
"""
assert (
patient_id in self.unique_patient_ids
), f"Patient {patient_id} not found in dataset"
data_source = self.global_event_df.filter(
pl.col("patient_id") == patient_id
).collect(engine="streaming")
return Patient(patient_id=patient_id, data_source=data_source)
[docs] def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]:
"""Yields Patient objects for each unique patient in the dataset.
Yields:
Iterator[Patient]: An iterator over Patient objects.
"""
if df is None:
df = self.global_event_df
patient_ids = (
df.select("patient_id")
.unique(maintain_order=True)
.collect(engine="streaming")
.to_series()
)
for patient_id in patient_ids:
patient_df = df.filter(pl.col("patient_id") == patient_id).collect(
engine="streaming"
)
yield Patient(patient_id=patient_id, data_source=patient_df)
[docs] def stats(self) -> None:
"""Prints statistics about the dataset."""
stats = self.global_event_df.select(
pl.len().alias("n_events"),
pl.col("patient_id").n_unique().alias("n_patients"),
).collect(engine="streaming")
print(f"Dataset: {self.dataset_name}")
print(f"Dev mode: {self.dev}")
print(f"Number of patients: {stats['n_patients'][0]}")
print(f"Number of events: {stats['n_events'][0]}")
@property
def default_task(self) -> Optional[BaseTask]:
"""Returns the default task for the dataset.
Returns:
Optional[BaseTask]: The default task, if any.
"""
return None
def _task_transform(
self, task: BaseTask, output_dir: Path, num_workers: int
) -> None:
self._main_guard(self._task_transform.__name__)
logger.info(
f"Applying task transformations on data with {num_workers} workers..."
)
global_event_df = task.pre_filter(self.global_event_df)
patient_ids = (
global_event_df.select("patient_id")
.unique()
.collect(engine="streaming")
.to_series()
# .sort can reduce runtime by 5%.
.sort()
)
if in_notebook():
logger.info(
"Detected Jupyter notebook environment, setting num_workers to 1"
)
num_workers = 1
num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers
# This ensures worker's polars threads are limited to avoid oversubscription,
# which can lead to additional 75% speedup when num_workers is large.
threads_per_worker = max(1, (os.cpu_count() or 1) // num_workers)
try:
with set_env(
POLARS_MAX_THREADS=str(threads_per_worker),
DATA_OPTIMIZER_NUM_WORKERS=str(num_workers),
):
if num_workers == 1:
logger.info("Single worker mode, processing sequentially")
_task_transform_fn(
(0, task, patient_ids, global_event_df, output_dir)
)
_litdata_merge(output_dir)
return
# spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary
ctx = multiprocessing.get_context("spawn")
queue = ctx.Queue()
args_list = [
(
worker_id,
task,
pids,
global_event_df,
output_dir,
)
for worker_id, pids in enumerate(
itertools.batched(
patient_ids, len(patient_ids) // num_workers + 1
)
)
]
with ctx.Pool(
processes=num_workers,
initializer=_task_transform_init,
initargs=(queue,),
) as pool:
result = pool.map_async(_task_transform_fn, args_list) # type: ignore
with tqdm(total=len(patient_ids)) as progress:
while not result.ready():
try:
progress.update(queue.get(timeout=1))
except:
pass
# remaining items
while not queue.empty():
progress.update(queue.get())
result.get() # ensure exceptions are raised
_litdata_merge(output_dir)
logger.info(f"Task transformation completed and saved to {output_dir}")
except Exception as e:
logger.error(
f"Error during task transformation, cleaning up output directory: {output_dir}"
)
shutil.rmtree(output_dir)
raise e
def _proc_transform(
self, task_df: Path, output_dir: Path, num_workers: int
) -> None:
self._main_guard(self._proc_transform.__name__)
logger.info(f"Applying processors on data with {num_workers} workers...")
num_samples = len(litdata.StreamingDataset(str(task_df)))
if in_notebook():
logger.info(
"Detected Jupyter notebook environment, setting num_workers to 1"
)
num_workers = 1
num_workers = min(num_workers, num_samples) # Avoid spawning empty workers
try:
with set_env(DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)):
if num_workers == 1:
logger.info("Single worker mode, processing sequentially")
_proc_transform_fn((0, task_df, 0, num_samples, output_dir))
_litdata_merge(output_dir)
return
ctx = multiprocessing.get_context("spawn")
queue = ctx.Queue()
linspace = more_itertools.sliding_window(
np.linspace(0, num_samples, num_workers + 1, dtype=int), 2
)
args_list = [
(
worker_id,
task_df,
start,
end,
output_dir,
)
for worker_id, (start, end) in enumerate(linspace)
]
with ctx.Pool(
processes=num_workers,
initializer=_proc_transform_init,
initargs=(queue,),
) as pool:
result = pool.map_async(_proc_transform_fn, args_list) # type: ignore
with tqdm(total=num_samples) as progress:
while not result.ready():
try:
progress.update(queue.get(timeout=1))
except:
pass
# remaining items
while not queue.empty():
progress.update(queue.get())
result.get() # ensure exceptions are raised
_litdata_merge(output_dir)
logger.info(
f"Processor transformation completed and saved to {output_dir}"
)
except Exception as e:
logger.error(f"Error during processor transformation.")
shutil.rmtree(output_dir)
raise e
finally:
self.clean_tmpdir()
[docs] def set_task(
self,
task: Optional[BaseTask] = None,
num_workers: Optional[int] = None,
input_processors: Optional[Dict[str, FeatureProcessor]] = None,
output_processors: Optional[Dict[str, FeatureProcessor]] = None,
) -> SampleDataset:
"""Processes the base dataset to generate the task-specific sample dataset.
The cache structure is as follows::
{task_name}_{task_uuid}/ # Cached data for specific task based on task name, schema, and args
task_df.ld/ # Intermediate task dataframe based on schema
samples_{proc_uuid}.ld/ # Final processed samples after applying processors
schema.pkl # Saved SampleBuilder schema
*.bin # Processed sample files
Args:
task (Optional[BaseTask]): The task to set. Uses default task if None.
num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`.
input_processors (Optional[Dict[str, FeatureProcessor]]):
Pre-fitted input processors. If provided, these will be used
instead of creating new ones from task's input_schema. Defaults to None.
output_processors (Optional[Dict[str, FeatureProcessor]]):
Pre-fitted output processors. If provided, these will be used
instead of creating new ones from task's output_schema. Defaults to None.
Returns:
SampleDataset: The generated sample dataset.
Raises:
AssertionError: If no default task is found and task is None.
"""
self._main_guard(self.set_task.__name__)
if task is None:
assert self.default_task is not None, "No default tasks found"
task = self.default_task
if num_workers is None:
num_workers = self.num_workers
logger.info(
f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
)
task_params = json.dumps(
{
**vars(task),
"input_schema": task.input_schema,
"output_schema": task.output_schema,
},
sort_keys=True,
default=str,
)
cache_dir = (
self.cache_dir
/ "tasks"
/ f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}"
)
cache_dir.mkdir(parents=True, exist_ok=True)
proc_params = json.dumps(
{
"input_processors": (
{
f"{k}_{v.__class__.__name__}": vars(v)
for k, v in input_processors.items()
}
if input_processors
else None
),
"output_processors": (
{
f"{k}_{v.__class__.__name__}": vars(v)
for k, v in output_processors.items()
}
if output_processors
else None
),
},
sort_keys=True,
default=str,
)
task_df_path = Path(cache_dir) / "task_df.ld"
samples_path = (
Path(cache_dir)
/ f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld"
)
logger.info(f"Task cache paths: task_df={task_df_path}, samples={samples_path}")
task_df_path.mkdir(parents=True, exist_ok=True)
samples_path.mkdir(parents=True, exist_ok=True)
def _is_valid_litdata_cache(path: Path) -> bool:
"""Return True if index.json exists. litdata only writes index.json after
all .bin chunks are flushed, so its presence guarantees a complete cache."""
return (path / "index.json").exists()
# Fast path: cache already valid, no lock needed (reads are always safe).
# Slow path: acquire a per-cache-dir file lock so that concurrent processes
# (e.g. parallel hparam jobs) don't race to build the same litdata cache.
# The double-checked pattern inside the lock means the winner builds it
# once; all others wait, re-check, and skip.
if not _is_valid_litdata_cache(samples_path):
lock_path = Path(cache_dir) / "build.lock"
with FileLock(str(lock_path), timeout=7200):
# Re-check inside the lock — another process may have built it
# while we were waiting.
if _is_valid_litdata_cache(samples_path):
logger.info(
f"Found cached processed samples at {samples_path} (built by another process)."
)
else:
# Check if task_df cache is valid; rebuild if not
if not _is_valid_litdata_cache(task_df_path):
self._task_transform(
task,
task_df_path,
num_workers,
)
else:
logger.info(
f"Found cached task dataframe at {task_df_path}, skipping task transformation."
)
# Build processors and fit on the dataset
logger.info(f"Fitting processors on the dataset...")
dataset = litdata.StreamingDataset(
str(task_df_path),
transform=lambda x: pickle.loads(x["sample"]),
)
builder = SampleBuilder(
input_schema=task.input_schema, # type: ignore
output_schema=task.output_schema, # type: ignore
input_processors=input_processors,
output_processors=output_processors,
)
builder.fit(dataset)
builder.save(str(samples_path / "schema.pkl"))
# Apply processors and save final samples to cache_dir
logger.info(f"Processing samples and saving to {samples_path}...")
self._proc_transform(
task_df_path,
samples_path,
num_workers,
)
logger.info(f"Cached processed samples to {samples_path}")
else:
logger.info(
f"Found cached processed samples at {samples_path}, skipping processing."
)
return SampleDataset(
path=str(samples_path),
dataset_name=self.dataset_name,
task_name=task.task_name,
)
def _main_guard(self, func_name: str):
"""Warn if method is accessed from a non-main process."""
if not multiprocessing.current_process().name == "MainProcess":
logger.warning(
f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
+ "Consider use __name__ == '__main__' guard when using multiprocessing."
)
exit(1)