Source code for pyhealth.datasets.tuab

import os

import numpy as np

from pyhealth.datasets import BaseSignalDataset


class TUABDataset(BaseSignalDataset):
    """Base EEG dataset for the TUH Abnormal EEG Corpus.

    Dataset is available at
    https://isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml

    The TUAB dataset (or Temple University Hospital EEG Abnormal Corpus) is a
    collection of EEG data acquired at the Temple University Hospital. The
    dataset contains both normal and abnormal EEG readings.

    Files are named in the form aaaaamye_s001_t000.edf. This includes the
    subject identifier ("aaaaamye"), the session number ("s001") and a token
    number ("t000"). EEGs are split into a series of files starting with
    *t000.edf, *t001.edf, ...

    Args:
        dataset_name: name of the dataset.
        root: root directory of the raw data. It must contain the
            sub-directories ``train/abnormal/01_tcp_ar``,
            ``train/normal/01_tcp_ar``, ``eval/abnormal/01_tcp_ar`` and
            ``eval/normal/01_tcp_ar``, which are the folders listed by
            :meth:`process_EEG_data`.
        dev: whether to enable dev mode (only use a small subset of the data).
            Default is False.
        refresh_cache: whether to refresh the cache; if true, the dataset will
            be processed from scratch and the cache will be updated.
            Default is False.

    Attributes:
        task: Optional[str], name of the task (e.g., "EEG_abnormal").
            Default is None.
        samples: Optional[List[Dict]], a list of samples, each sample is a
            dict with patient_id, record_id, and other task-specific
            attributes as key. Default is None.
        patient_to_index: Optional[Dict[str, List[int]]], a dict mapping
            patient_id to a list of sample indices. Default is None.
        visit_to_index: Optional[Dict[str, List[int]]], a dict mapping
            visit_id to a list of sample indices. Default is None.

    Examples:
        >>> from pyhealth.datasets import TUABDataset
        >>> dataset = TUABDataset(
        ...     root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
        ... )
        >>> dataset.stat()
        >>> dataset.info()
    """

    @staticmethod
    def _visit_id(file_name, subject):
        """Return the "<session>_<token>" part of a recording file name.

        ``str.strip`` removes *characters*, not a prefix/suffix, so it must
        not be used here: it would also eat trailing/leading characters of
        the session or token that happen to occur in ".edf" or in the
        subject identifier.
        """
        suffix = ".edf"
        stem = file_name[: -len(suffix)] if file_name.endswith(suffix) else file_name
        # Drop the subject identifier and the "_" separator that follows it.
        return stem[len(subject) + 1:]

    def process_EEG_data(self):
        """Index the raw TUAB recordings under ``self.root`` by patient.

        The four ``<split>/<label>/01_tcp_ar`` folders are listed and every
        file is attributed to a patient id of the form
        ``"<split code>_<subject>"``, where the split code is the value in
        ``data_map`` and the subject is the file name up to the first "_".

        Returns:
            Dict[str, List[Dict]]: maps each patient id to one record per
            file, with keys ``load_from_path``, ``patient_id``, ``visit_id``,
            ``signal_file``, ``label_file`` and ``save_to_path``.
        """
        montage_dir = "01_tcp_ar"

        # create a map for data sets for latter mapping patients
        data_map = {
            "train/abnormal": "0",
            "train/normal": "1",
            "eval/abnormal": "2",
            "eval/normal": "3",
        }
        data_map_reverse = {code: field for field, code in data_map.items()}

        # get all file names, one listing per split/label folder
        all_files = {
            field: os.listdir(os.path.join(self.root, field, montage_dir))
            for field in data_map
        }

        # group the files of each patient in a single pass (listing order is
        # preserved inside every group)
        files_by_patient = {}
        for field, sub_data in all_files.items():
            for file_name in sub_data:
                pid = "{}_{}".format(data_map[field], file_name.split("_")[0])
                files_by_patient.setdefault(pid, []).append(file_name)

        # get all patient ids
        # NOTE: set iteration order is arbitrary, so the dev subset below is
        # not guaranteed to be the same from one run to the next.
        patient_ids = list(set(files_by_patient))
        if self.dev:
            patient_ids = patient_ids[:20]

        # get patient to record maps
        #    - key: pid
        #    - value: [{"load_from_path": None, "patient_id": None,
        #               "visit_id": None, "signal_file": None,
        #               "label_file": None, "save_to_path": None}, ...]
        patients = {}
        for pid in patient_ids:
            split_code, subject = pid.split("_")[0], pid.split("_")[1]
            data_field = data_map_reverse[split_code]
            load_from_path = os.path.join(self.root, data_field, montage_dir)
            patients[pid] = [
                {
                    "load_from_path": load_from_path,
                    "patient_id": pid,
                    "visit_id": self._visit_id(visit, subject),
                    "signal_file": visit,
                    # the same file is recorded as both signal and label source
                    "label_file": visit,
                    "save_to_path": self.filepath,
                }
                for visit in files_by_patient[pid]
            ]

        return patients

if __name__ == "__main__":
    # Manual smoke run: build the dataset in dev mode with a fresh cache,
    # print its summary, then show the first (patient_id, records) entry.
    dataset = TUABDataset(
        root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
        dev=True,
        refresh_cache=True,
    )
    dataset.stat()
    dataset.info()

    patient_entries = list(dataset.patients.items())
    print(patient_entries[0])