# Source code for pyhealth.datasets.tuab
import os
import logging
import pandas as pd
from pathlib import Path
from typing import Optional
from .base_dataset import BaseDataset
from pyhealth.tasks import EEGAbnormalTUAB
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TUABDataset(BaseDataset):
    """Base EEG dataset for the TUH Abnormal EEG Corpus (TUAB).

    Dataset is available at
    https://isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml

    The TUAB dataset (Temple University Hospital EEG Abnormal Corpus) is a
    collection of EEG data acquired at the Temple University Hospital. The
    dataset contains both normal and abnormal EEG readings.

    Files are named in the form ``aaaaamye_s001_t000.edf``. This includes the
    subject identifier (``aaaaamye``), the session number (``s001``) and a
    token number (``t000``). EEGs are split into a series of files starting
    with ``*t000.edf``, ``*t001.edf``, ...

    Directory layout scanned under ``root`` (see :meth:`prepare_metadata`)::

        <root>/{train,eval}/{normal,abnormal}/01_tcp_ar/*.edf

    Args:
        root: root directory of the raw data, i.e. the directory containing
            the ``train`` and ``eval`` sub-directories.
        dataset_name: name of the dataset. Defaults to ``"tuab"`` when None.
        config_path: path to the dataset config file. When None, the
            ``configs/tuab.yaml`` file next to this module is used.
        subset: which split(s) to load: ``'train'``, ``'eval'`` or
            ``'both'``. Default is ``'both'``.
        **kwargs: forwarded unchanged to ``BaseDataset.__init__``.

    Raises:
        ValueError: if ``subset`` is not ``'train'``, ``'eval'`` or ``'both'``.

    Attributes:
        root: the ``root`` argument as given by the caller; assigned before
            the base-class initializer runs (which may reassign it --
            NOTE(review): confirm against ``BaseDataset``).

    Examples:
        >>> from pyhealth.datasets import TUABDataset
        >>> dataset = TUABDataset(
        ...     root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
        ... )
        >>> dataset.stat()
        >>> dataset.info()
    """

    def __init__(
        self,
        root: str,
        dataset_name: Optional[str] = None,
        config_path: Optional[str] = None,
        subset: Optional[str] = 'both',
        **kwargs
    ) -> None:
        if config_path is None:
            logger.info("No config path provided, using default config")
            config_path = Path(__file__).parent / "configs" / "tuab.yaml"

        # prepare_metadata() reads self.root, so it must be set before the
        # base-class initializer has run.
        self.root = root

        if subset in ('train', 'eval'):
            logger.info("Using subset: %s", subset)
            tables = [subset]
        elif subset == 'both':
            logger.info("Using both train and eval subsets")
            tables = ["train", "eval"]
        else:
            raise ValueError("subset must be one of 'train', 'eval', or 'both'")

        self.prepare_metadata()

        # The metadata CSVs live either next to the data (shared directory)
        # or, when that location was not writable, in the per-user cache.
        # If any requested table exists only in the cache, load from there.
        # NOTE(review): if one table is only in <root> and another only in
        # the cache, the table that lives in <root> may not be found after
        # switching -- confirm how BaseDataset resolves table paths.
        root_path = Path(root)
        cache_dir = self._tuab_cache_dir()
        use_cache = any(
            not (root_path / csv_name).exists()
            and (cache_dir / csv_name).exists()
            for csv_name in map(self._tuab_csv_name, tables)
        )
        if use_cache:
            logger.info("Using cached metadata from %s", cache_dir)
            root = str(cache_dir)

        super().__init__(
            root=root,
            tables=tables,
            dataset_name=dataset_name or "tuab",
            config_path=config_path,
            **kwargs
        )

    @staticmethod
    def _tuab_cache_dir() -> Path:
        """Return the per-user fallback directory for the metadata CSVs."""
        return Path.home() / ".cache" / "pyhealth" / "tuab"

    @staticmethod
    def _tuab_csv_name(split: str) -> str:
        """Return the metadata CSV file name for ``split``."""
        return f"tuab-{split}-pyhealth.csv"

    @staticmethod
    def _tuab_scan_split(root: Path, split: str) -> list:
        """Collect one metadata row (a dict) per well-named EDF file of ``split``."""
        rows: list[dict] = []
        for label in ("normal", "abnormal"):
            edf_dir = root / split / label / "01_tcp_ar"
            if not edf_dir.is_dir():
                logger.warning("EDF directory not found: %s", edf_dir)
                continue
            for edf_path in sorted(edf_dir.glob("*.edf")):
                # Expected stem: <subject>_<session>_<token>
                parts = edf_path.stem.split("_")
                if len(parts) != 3:
                    logger.warning("Invalid filename format: %s", edf_path)
                    continue
                subject_id, session_id, token_id = parts
                rows.append(
                    {
                        "patient_id": subject_id,
                        "record_id": f"{session_id}_{token_id}",
                        "signal_file": str(edf_path),
                        "label": label,
                    }
                )
        return rows

    def prepare_metadata(self) -> None:
        """Build and save processed metadata CSVs for TUAB train/eval separately.

        For each split this writes ``tuab-<split>-pyhealth.csv`` to
        ``<root>``; if that location cannot be written, the file is written
        to ``~/.cache/pyhealth/tuab`` instead. A split is skipped when its
        CSV already exists in either location, or when no valid EDF file is
        found for it. Both splits are always processed, regardless of the
        ``subset`` chosen at construction time.

        Train and eval filenames look like ``aaaaalkt_s001_t000.edf``:
            - subject_id = aaaaalkt
            - session_id = s001
            - token_id = t000

        Columns written: ``patient_id`` (the subject id), ``record_id``
        (``<session_id>_<token_id>``), ``signal_file`` (path of the EDF
        file) and ``label``. The label is the name of the directory the
        file was found in, stored as the string ``"normal"`` or
        ``"abnormal"``; any numeric encoding is presumably left to the task
        (TODO confirm against ``EEGAbnormalTUAB``).
        """
        root = Path(self.root)
        cache_dir = self._tuab_cache_dir()

        for split in ("train", "eval"):
            csv_name = self._tuab_csv_name(split)
            csv_shared = root / csv_name
            csv_cache = cache_dir / csv_name

            # Metadata already built (in either location): nothing to do.
            if csv_shared.exists() or csv_cache.exists():
                continue

            rows = self._tuab_scan_split(root, split)
            if not rows:
                continue

            df = (
                pd.DataFrame(rows)
                .sort_values(["patient_id", "record_id"], na_position="last")
                .reset_index(drop=True)
            )

            # Try the shared location first; fall back to the per-user cache
            # when it cannot be written. PermissionError is an OSError
            # subclass, so catching OSError covers both.
            try:
                csv_shared.parent.mkdir(parents=True, exist_ok=True)
                df.to_csv(csv_shared, index=False)
            except OSError as err:
                logger.warning(
                    "Could not write %s (%s); falling back to cache",
                    csv_shared,
                    err,
                )
                # Create the cache directory only when it is actually
                # needed, so an unwritable home directory cannot break the
                # normal (shared-location) path.
                cache_dir.mkdir(parents=True, exist_ok=True)
                df.to_csv(csv_cache, index=False)
                logger.info("Wrote %s metadata to cache: %s", split, csv_cache)
            else:
                logger.info("Wrote %s metadata to %s", split, csv_shared)

    @property
    def default_task(self) -> EEGAbnormalTUAB:
        """Returns the default task for the TUAB dataset: EEGAbnormalTUAB.

        Returns:
            EEGAbnormalTUAB: The default task instance (a new instance is
            created on every access).
        """
        return EEGAbnormalTUAB()