Source code for pyhealth.tasks.temple_university_EEG_tasks

import os
import pickle
import pkg_resources
import mne
import pandas as pd
import numpy as np


def EEG_isAbnormal_fn(record):
    """Processes a single patient for the abnormal EEG detection task on TUAB.

    Abnormal EEG detection aims at determining whether an EEG is abnormal.

    Each recording is loaded with MNE, resampled to 200 Hz, converted into 16
    bipolar channels (differences between pairs of ``*-REF`` electrodes) and
    cut into consecutive, non-overlapping epochs of 2000 samples (10 seconds
    at 200 Hz). A trailing partial epoch is discarded. Every epoch is pickled
    to ``save_to_path`` as ``{"signal": signal, "label": label}``.

    Recordings that lack any of the required electrodes are skipped and their
    path is appended to ``tuab-process-error-files.txt`` in the current
    working directory.

    Args:
        record: a singleton list of one subject from the TUABDataset.
            The (single) record is a dictionary with the following keys:
                load_from_path, patient_id, visit_id, signal_file,
                label_file, save_to_path

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, record_id (the integer epoch index), epoch_path (the
            path to the saved epoch ``{"signal": signal, "label": label}``)
            and label.

    Note that we define the task as a binary classification task.

    Examples:
        >>> from pyhealth.datasets import TUABDataset
        >>> isabnormal = TUABDataset(
        ...     root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
        ...     download=True,
        ... )
        >>> from pyhealth.tasks import EEG_isAbnormal_fn
        >>> EEG_abnormal_ds = isabnormal.set_task(EEG_isAbnormal_fn)
        >>> EEG_abnormal_ds.samples[0]  # values are illustrative
        {
            'patient_id': '<patient_id>',
            'visit_id': '<visit_id>',
            'record_id': 0,
            'epoch_path': '<save_to_path>/<patient_id>_<visit_id>_0.pkl',
            'label': 1
        }
    """
    # (minuend, subtrahend) electrode pairs, listed in output-channel order.
    bipolar_pairs = (
        ("EEG FP1-REF", "EEG F7-REF"),
        ("EEG F7-REF", "EEG T3-REF"),
        ("EEG T3-REF", "EEG T5-REF"),
        ("EEG T5-REF", "EEG O1-REF"),
        ("EEG FP2-REF", "EEG F8-REF"),
        ("EEG F8-REF", "EEG T4-REF"),
        ("EEG T4-REF", "EEG T6-REF"),
        ("EEG T6-REF", "EEG O2-REF"),
        ("EEG FP1-REF", "EEG F3-REF"),
        ("EEG F3-REF", "EEG C3-REF"),
        ("EEG C3-REF", "EEG P3-REF"),
        ("EEG P3-REF", "EEG O1-REF"),
        ("EEG FP2-REF", "EEG F4-REF"),
        ("EEG F4-REF", "EEG C4-REF"),
        ("EEG C4-REF", "EEG P4-REF"),
        ("EEG P4-REF", "EEG O2-REF"),
    )
    sampling_rate = 200  # Hz, target rate after resampling
    epoch_len = 2000  # samples per saved epoch (10 s at 200 Hz)

    samples = []
    for visit in record:
        root = visit["load_from_path"]
        pid = visit["patient_id"]
        visit_id = visit["visit_id"]
        signal = visit["signal_file"]
        save_path = visit["save_to_path"]

        edf_path = os.path.join(root, signal)
        raw = mne.io.read_raw_edf(edf_path, preload=True)
        raw.resample(sampling_rate)
        ch_name = raw.ch_names
        raw_data = raw.get_data()

        try:
            channeled_data = np.stack(
                [
                    raw_data[ch_name.index(first)] - raw_data[ch_name.index(second)]
                    for first, second in bipolar_pairs
                ]
            )
        except ValueError:
            # list.index raises ValueError when a required electrode is not
            # present in this recording: log the file and move on.
            with open("tuab-process-error-files.txt", "a") as f:
                f.write(edf_path + "\n")
            continue

        # get the label: it is encoded in the prefix of the patient id.
        # NOTE(review): prefixes "0" and "2" presumably mark abnormal
        # recordings -- confirm against TUABDataset.
        data_field = pid.split("_")[0]
        label = 1 if data_field in ("0", "2") else 0

        # save every full epoch and register it as a sample
        for i in range(channeled_data.shape[1] // epoch_len):
            dump_path = os.path.join(
                save_path, pid + "_" + visit_id + "_" + str(i) + ".pkl"
            )
            epoch = channeled_data[:, i * epoch_len : (i + 1) * epoch_len]
            # Context manager so the file handle is closed after each epoch.
            with open(dump_path, "wb") as f:
                pickle.dump({"signal": epoch, "label": label}, f)

            samples.append(
                {
                    "patient_id": pid,
                    "visit_id": visit_id,
                    "record_id": i,
                    "epoch_path": dump_path,
                    "label": label,
                }
            )

    return samples
def EEG_events_fn(record):
    """Processes a single patient for the EEG events task on TUEV.

    This task aims at annotating of EEG segments as one of six classes:
    (1) spike and sharp wave (SPSW), (2) generalized periodic epileptiform
    discharges (GPED), (3) periodic lateralized epileptiform discharges
    (PLED), (4) eye movement (EYEM), (5) artifact (ARTF) and
    (6) background (BCKG).

    For every recording the EDF file and its sibling ``.rec`` annotation file
    are read with ``readEDF``, the signals are converted to the 16-channel
    bipolar montage with ``convert_signals`` and one window per annotated
    event is extracted with ``BuildEvents``. Every window is pickled to
    ``save_to_path`` as ``{"signal": signal, "label": label}``.

    Recordings for which loading or montage conversion raises ``ValueError``
    or ``KeyError`` are reported on stdout and skipped.

    Args:
        record: a singleton list of one subject from the TUEVDataset.
            The (single) record is a dictionary with the following keys:
                load_from_path, patient_id, visit_id, signal_file,
                label_file, save_to_path

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, record_id, label, offending_channel, and epoch_path
            (the path to the saved epoch
            ``{"signal": signal, "label": label}``).

    Note that we define the task as a multiclass classification task.

    Examples:
        >>> from pyhealth.datasets import TUEVDataset
        >>> EEGevents = TUEVDataset(
        ...     root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/",
        ...     download=True,
        ... )
        >>> from pyhealth.tasks import EEG_events_fn
        >>> EEG_events_ds = EEGevents.set_task(EEG_events_fn)
        >>> EEG_events_ds.samples[0]
        {
            'patient_id': '0_00002265',
            'visit_id': '00000001',
            'record_id': 0,
            'epoch_path': '/Users/liyanjing/.cache/pyhealth/datasets/d8f3cb92cc444d481444d3414fb5240c/0_00002265_00000001_0.pkl',
            'label': 6,
            'offending_channel': array([4.])
        }
    """
    samples = []
    for visit in record:
        root = visit["load_from_path"]
        pid = visit["patient_id"]
        visit_id = visit["visit_id"]
        signal_file = visit["signal_file"]
        save_path = visit["save_to_path"]

        edf_path = os.path.join(root, signal_file)

        # load data
        try:
            # event is the .rec file in the form of an array
            [signals, times, event, Rawdata] = readEDF(edf_path)
            signals = convert_signals(signals, Rawdata)
        except (ValueError, KeyError):
            print("something funky happened in " + edf_path)
            continue

        signals, offending_channels, labels = BuildEvents(signals, times, event)

        for idx, (epoch, offending_channel, event_label) in enumerate(
            zip(signals, offending_channels, labels)
        ):
            dump_path = os.path.join(
                save_path, pid + "_" + visit_id + "_" + str(idx) + ".pkl"
            )
            # Each row of labels holds a single value; store it as a plain int.
            label = int(event_label[0])

            # Context manager so the file handle is closed after each event.
            with open(dump_path, "wb") as f:
                pickle.dump({"signal": epoch, "label": label}, f)

            samples.append(
                {
                    "patient_id": pid,
                    "visit_id": visit_id,
                    "record_id": idx,
                    "epoch_path": dump_path,
                    "label": label,
                    "offending_channel": offending_channel,
                }
            )

    return samples
def BuildEvents(signals, times, EventData):
    """Extract one fixed-length signal window per annotated event.

    Each row of ``EventData`` is read as
    ``(channel, start_time, end_time, label)``. The start and end times are
    located in ``times`` (first sample whose time is >= the requested time)
    and the window ``[start - 2 * fs, end + 2 * fs)`` is copied for all
    channels, with ``fs`` hard-coded to 250. The signal is tiled three times
    along the time axis so that windows reaching past either end of the
    recording wrap around instead of being truncated.

    The output buffer holds ``5 * fs = 1250`` samples per event, so every
    event must span exactly ``fs`` samples between its start and end index.

    Args:
        signals: 2-D array of shape (num_channels, num_points).
        times: 1-D array with the time stamp of every sample in ``signals``
            (expected to be in the same unit as the event times).
        EventData: array with one row per event (the parsed ``.rec`` file).
            A single event given as a 1-D array is accepted, and an empty
            array yields empty outputs.

    Returns:
        A list ``[features, offending_channel, labels]`` where ``features``
        has shape (num_events, num_channels, 1250), and ``offending_channel``
        and ``labels`` both have shape (num_events, 1).

    Raises:
        IndexError: if an event start/end time is later than every entry of
            ``times``.
        ValueError: if an extracted window is not exactly 1250 samples long.
    """
    fs = 250.0
    half_pad = 2 * int(fs)  # samples kept before the start and after the end
    window_len = 5 * int(fs)  # samples stored per event

    # np.genfromtxt returns a 1-D array for a one-line .rec file; normalise
    # so that each event is always one row.
    events = np.atleast_2d(np.asarray(EventData))
    numEvents = 0 if events.size == 0 else events.shape[0]

    numChan, numPoints = signals.shape
    features = np.zeros([numEvents, numChan, window_len])
    offending_channel = np.zeros([numEvents, 1])  # channel that had the detected thing
    labels = np.zeros([numEvents, 1])

    # Index arithmetic below is relative to the middle copy of the signal.
    offset = numPoints
    tiled = np.concatenate([signals, signals, signals], axis=1)

    for i in range(numEvents):  # for each event
        chan = int(events[i, 0])
        start = np.where(times >= events[i, 1])[0][0]
        end = np.where(times >= events[i, 2])[0][0]
        features[i, :] = tiled[
            :, offset + start - half_pad : offset + end + half_pad
        ]
        offending_channel[i, :] = chan
        labels[i, :] = int(events[i, 3])

    return [features, offending_channel, labels]
def convert_signals(signals, Rawdata):
    """Convert referential recordings into a 16-channel bipolar montage.

    Every output row is the difference between two rows of ``signals``,
    selected by channel name through ``Rawdata.info["ch_names"]``.

    Args:
        signals: array indexed by channel along its first axis, in the same
            order as ``Rawdata.info["ch_names"]``.
        Rawdata: object exposing ``info["ch_names"]``, the list of channel
            names of the recording.

    Returns:
        Array with 16 rows, one per electrode pair listed below, in order.

    Raises:
        KeyError: if a required channel name is missing from the recording.
    """
    # Map every channel name to its row index in ``signals``.
    row_of = {name: row for row, name in enumerate(Rawdata.info["ch_names"])}

    # (minuend, subtrahend) electrode pairs, in output-row order.
    montage = [
        ("EEG FP1-REF", "EEG F7-REF"),  # 0
        ("EEG F7-REF", "EEG T3-REF"),  # 1
        ("EEG T3-REF", "EEG T5-REF"),  # 2
        ("EEG T5-REF", "EEG O1-REF"),  # 3
        ("EEG FP2-REF", "EEG F8-REF"),  # 4
        ("EEG F8-REF", "EEG T4-REF"),  # 5
        ("EEG T4-REF", "EEG T6-REF"),  # 6
        ("EEG T6-REF", "EEG O2-REF"),  # 7
        ("EEG FP1-REF", "EEG F3-REF"),  # 8
        ("EEG F3-REF", "EEG C3-REF"),  # 9
        ("EEG C3-REF", "EEG P3-REF"),  # 10
        ("EEG P3-REF", "EEG O1-REF"),  # 11
        ("EEG FP2-REF", "EEG F4-REF"),  # 12
        ("EEG F4-REF", "EEG C4-REF"),  # 13
        ("EEG C4-REF", "EEG P4-REF"),  # 14
        ("EEG P4-REF", "EEG O2-REF"),  # 15
    ]

    bipolar_rows = [
        signals[row_of[first]] - signals[row_of[second]]
        for first, second in montage
    ]
    return np.vstack(bipolar_rows)
def readEDF(fileName):
    """Read an EDF recording together with its ``.rec`` annotation file.

    The annotation file is expected next to the EDF file, under the same
    name with the last three characters replaced by ``rec``
    (``foo.edf`` -> ``foo.rec``), and is parsed as comma-separated numbers.

    Args:
        fileName: path to the EDF file.

    Returns:
        A list ``[signals, times, eventData, Rawdata]``: the two values
        obtained from ``Rawdata[:]``, the parsed annotation array and the
        MNE raw object. The raw object has already been closed when it is
        returned.
        NOTE(review): callers in this file only read ``Rawdata.info`` from
        it afterwards; confirm nothing else needs the open reader.
    """
    Rawdata = mne.io.read_raw_edf(fileName)
    try:
        signals, times = Rawdata[:]
        RecFile = fileName[0:-3] + "rec"
        eventData = np.genfromtxt(RecFile, delimiter=",")
    finally:
        # Release the reader even when reading the data or parsing the
        # annotation file fails.
        Rawdata.close()
    return [signals, times, eventData, Rawdata]
if __name__ == "__main__":
    # Manual smoke test: build the TUEV events task on the development
    # subset and print the first sample plus the inferred input schema.
    #
    # To try the TUAB abnormal-detection task instead, build a TUABDataset
    # rooted at "/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/" with the
    # same dev=True / refresh_cache=True options and call
    # set_task(EEG_isAbnormal_fn) on it.
    from pyhealth.datasets import TUABDataset, TUEVDataset

    tuev_root = "/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/"
    tuev_dataset = TUEVDataset(root=tuev_root, dev=True, refresh_cache=True)

    events_task_ds = tuev_dataset.set_task(EEG_events_fn)
    print(events_task_ds.samples[0])
    print(events_task_ds.input_info)