import numpy as np
import torch
import mne
from typing import Any, Dict, List, Tuple
import numpy as np
from pyhealth.tasks import BaseTask
class EEGEventsTUEV(BaseTask):
    """Multi-class classification task for EEG event detection on TUEV.

    For each EDF recording, this task:

    1) reads the EDF
    2) applies a bandpass filter (default 0.1-75 Hz) and a notch filter
       (default 50 Hz), then resamples to ``resample_rate`` (default 200 Hz)
    3) loads the paired .rec annotation file (same path, .edf -> .rec)
    4) converts the referential channels to a 16-channel bipolar montage and
       cuts one 5-second window around every annotated event
    5) returns one sample per event

    A recording is skipped with a ``RuntimeWarning`` when reading it, building
    the montage, or windowing its events raises ``ValueError``, ``KeyError`` or
    ``IndexError``. Examples are missing montage channels, malformed
    annotation rows, or events past the end of the recording.

    Each returned sample contains:

    - "signal": torch.FloatTensor, shape (16, 5 * int(resample_rate)), i.e.
      (16, 1000) with the default 200 Hz
    - "offending_channel": int, channel column of the .rec row
    - "label": int, the .rec class code minus one
    - "stft": tensor from ``get_stft_torch`` (only when ``compute_stft=True``)
    - "patient_id", "signal_file", "split": bookkeeping fields

    Args:
        resample_rate: target sampling rate in Hz.
        bandpass_filter: (low, high) filter cut-offs in Hz.
        notch_filter: notch frequency in Hz.
        normalization: optional per-window scaling; one of ``None``,
            ``"95th_percentile"`` or ``"div_by_100"``.
        compute_stft: whether to precompute an ``"stft"`` tensor per sample.

    Raises:
        ValueError: if ``normalization`` is not a supported value.

    Examples:
        >>> from pyhealth.datasets import TUEVDataset
        >>> from pyhealth.tasks import EEGEventsTUEV
        >>> dataset = TUEVDataset(
        ...     root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/",
        ... )
        >>> sample_dataset = dataset.set_task(EEGEventsTUEV())
        >>> sample = sample_dataset[0]
        >>> print(sample['label'])

    For a complete example, see `examples/conformal_eeg/tuev_eeg_quickstart.ipynb`.
    """

    task_name: str = "EEG_events"
    input_schema: Dict[str, str] = {"signal": "tensor", "stft": "tensor"}
    output_schema: Dict[str, str] = {"label": "multiclass"}

    # Accepted values of the ``normalization`` constructor argument.
    _NORMALIZATIONS = (None, "95th_percentile", "div_by_100")

    # 16-channel bipolar montage: output row k is the first channel minus the
    # second channel of the k-th pair, in exactly this order.
    _BIPOLAR_PAIRS: Tuple[Tuple[str, str], ...] = (
        ("EEG FP1-REF", "EEG F7-REF"),
        ("EEG F7-REF", "EEG T3-REF"),
        ("EEG T3-REF", "EEG T5-REF"),
        ("EEG T5-REF", "EEG O1-REF"),
        ("EEG FP2-REF", "EEG F8-REF"),
        ("EEG F8-REF", "EEG T4-REF"),
        ("EEG T4-REF", "EEG T6-REF"),
        ("EEG T6-REF", "EEG O2-REF"),
        ("EEG FP1-REF", "EEG F3-REF"),
        ("EEG F3-REF", "EEG C3-REF"),
        ("EEG C3-REF", "EEG P3-REF"),
        ("EEG P3-REF", "EEG O1-REF"),
        ("EEG FP2-REF", "EEG F4-REF"),
        ("EEG F4-REF", "EEG C4-REF"),
        ("EEG C4-REF", "EEG P4-REF"),
        ("EEG P4-REF", "EEG O2-REF"),
    )

    def __init__(
        self,
        resample_rate: float = 200,
        bandpass_filter: Tuple[float, float] = (0.1, 75.0),
        notch_filter: float = 50.0,
        normalization: str = None,  # None, '95th_percentile' or 'div_by_100'
        compute_stft: bool = True,
    ) -> None:
        super().__init__()
        # Fail fast on typos instead of silently skipping normalization.
        if normalization not in self._NORMALIZATIONS:
            raise ValueError(
                f"Unsupported normalization {normalization!r}; "
                f"expected one of {self._NORMALIZATIONS}."
            )
        self.resample_rate = resample_rate
        self.bandpass_filter = bandpass_filter
        self.notch_filter = notch_filter
        self.normalization = normalization
        self.compute_stft = compute_stft
        # input_schema must reflect whether stft is produced
        if not compute_stft:
            self.input_schema = {"signal": "tensor"}

    @staticmethod
    def BuildEvents(
        signals: np.ndarray,
        times: np.ndarray,
        EventData: np.ndarray,
        resample_rate: float = 200,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Cut one fixed-length 5-second window per annotated event.

        Each row of ``EventData`` is read as ``(channel, start, end, label)``.
        ``start`` and ``end`` are compared against ``times`` (seconds).

        The window is centred on the event midpoint. For a 1-second event it
        runs from 2 s before the start to 2 s after the end. Events of other
        durations still get a full-length window. Sample indices outside the
        recording wrap around circularly to the other end of the recording.

        Args:
            signals: montage signals, shape (n_channels, n_times).
            times: ascending sample times in seconds, one per column of
                ``signals``.
            EventData: annotation rows parsed from the .rec file; a single
                row may be given as a 1-D array.
            resample_rate: sampling rate of ``signals`` in Hz.

        Returns:
            A tuple of three float64 arrays:

            - features, shape (n_events, n_channels, 5 * int(resample_rate))
            - offending_channel, shape (n_events, 1)
            - labels, shape (n_events, 1)

            An empty ``EventData`` yields arrays with zero events.

        Raises:
            IndexError: if an event starts or ends after the last sample time,
                or if a row has fewer than four columns.
        """
        fs = int(resample_rate)
        window_len = 5 * fs
        # Ensure 2D in case a .rec has only one row
        EventData = np.atleast_2d(EventData)
        num_chan = signals.shape[0]
        if EventData.size == 0:
            # Empty .rec file: no events, so no windows.
            return (
                np.zeros([0, num_chan, window_len]),
                np.zeros([0, 1]),
                np.zeros([0, 1]),
            )

        num_events = EventData.shape[0]
        n_samples = len(times)
        features = np.zeros([num_events, num_chan, window_len])
        offending_channel = np.zeros([num_events, 1])
        labels = np.zeros([num_events, 1])
        # Relative sample offsets of a window around its centre sample.
        window_offsets = np.arange(window_len) - window_len // 2

        for i, row in enumerate(EventData):
            # First sample at or after each boundary time (times is sorted).
            start = np.searchsorted(times, row[1], side="left")
            end = np.searchsorted(times, row[2], side="left")
            if start >= n_samples or end >= n_samples:
                raise IndexError(
                    f"Event {i} ({row[1]}-{row[2]} s) lies beyond the last "
                    "sample of the recording."
                )
            centre = (start + end) // 2
            # mode="wrap" takes indices modulo n_times: circular padding
            # without materialising padded copies of the recording.
            features[i] = np.take(
                signals, centre + window_offsets, axis=1, mode="wrap"
            )
            offending_channel[i, 0] = int(row[0])
            labels[i, 0] = int(row[3])
        return features, offending_channel, labels

    @staticmethod
    def convert_signals(signals: np.ndarray, Rawdata: mne.io.BaseRaw) -> np.ndarray:
        """Re-reference referential EEG to the 16-channel bipolar montage.

        Args:
            signals: referential data with one row per entry of
                ``Rawdata.info["ch_names"]``.
            Rawdata: MNE Raw object supplying the channel names.

        Returns:
            np.ndarray of shape (16, n_times). Row ``k`` is the first channel
            minus the second channel of the ``k``-th pair in ``_BIPOLAR_PAIRS``.

        Raises:
            KeyError: if a required channel name is missing.
        """
        # Name -> row index; if a name repeats, its last index wins.
        index_of = {
            name: idx for idx, name in enumerate(Rawdata.info["ch_names"])
        }
        return np.vstack(
            [
                signals[index_of[first]] - signals[index_of[second]]
                for first, second in EEGEventsTUEV._BIPOLAR_PAIRS
            ]
        )

    @staticmethod
    def readEDF(
        fileName: str,
        resample_rate: float = 200,
        bandpass_filter: Tuple[float, float] = (0.1, 75.0),
        notch_filter: float = 50.0,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, mne.io.BaseRaw]:
        """Load and preprocess one EDF recording and its .rec annotations.

        Args:
            fileName: path (str or path-like) to the ``.edf`` file. The
                annotation file is expected at the same path with the last
                three characters replaced by ``rec``.
            resample_rate: target sampling rate in Hz.
            bandpass_filter: (low, high) filter cut-offs in Hz.
            notch_filter: notch frequency in Hz.

        Returns:
            A tuple ``(signals, times, eventData, Rawdata)``:

            - signals: data from ``get_data(units="uV")``, one row per channel
            - times: sample times of the resampled data
            - eventData: comma-separated .rec contents parsed by ``genfromtxt``
            - Rawdata: the closed MNE Raw object, kept for its channel metadata
        """
        Rawdata = mne.io.read_raw_edf(fileName, preload=True, verbose="error")
        Rawdata.filter(
            l_freq=bandpass_filter[0], h_freq=bandpass_filter[1], verbose="error"
        )
        Rawdata.notch_filter(notch_filter, verbose="error")
        Rawdata.resample(resample_rate, n_jobs=1, verbose="error")
        # Raw.times has the same values as the times returned by Rawdata[:],
        # but does not materialise a second copy of the whole data matrix.
        # .copy() because Raw.times is a read-only view.
        times = Rawdata.times.copy()
        signals = Rawdata.get_data(units="uV")
        # str() so path-like objects are accepted as well as plain strings.
        RecFile = str(fileName)[:-3] + "rec"
        eventData = np.genfromtxt(RecFile, delimiter=",")
        Rawdata.close()
        return signals, times, eventData, Rawdata

    def _normalize(self, signal: np.ndarray) -> np.ndarray:
        """Apply the configured per-window amplitude normalization."""
        if self.normalization == "95th_percentile":
            # Per-channel scale; the epsilon avoids division by zero on flat channels.
            scale = np.quantile(
                np.abs(signal), q=0.95, axis=-1, method="linear", keepdims=True
            )
            return signal / (scale + 1e-8)
        if self.normalization == "div_by_100":
            return signal / 100
        return signal

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process one patient, creating one sample per event in each .rec file.

        Iterates over both the ``"train"`` and ``"eval"`` splits and tags every
        sample with its split. A recording is skipped with a ``RuntimeWarning``
        when reading it, building the montage, or windowing its events raises
        ``ValueError``, ``KeyError`` or ``IndexError``.

        Args:
            patient: object exposing ``patient_id`` and ``get_events(split)``;
                each event must carry a ``signal_file`` attribute pointing to
                an ``.edf`` file.

        Returns:
            A list of sample dicts as described in the class docstring.
        """
        import warnings

        if self.compute_stft:
            # Lazy import, hoisted out of the per-sample loop.
            from pyhealth.models.tfm_tokenizer import get_stft_torch

        pid = patient.patient_id
        samples: List[Dict[str, Any]] = []
        for split in ("train", "eval"):
            for event in patient.get_events(split):
                edf_path = event.signal_file
                try:
                    signals, times, rec, raw = self.readEDF(
                        edf_path,
                        self.resample_rate,
                        self.bandpass_filter,
                        self.notch_filter,
                    )
                    signals = self.convert_signals(signals, raw)
                    feats, offending_channels, labels = self.BuildEvents(
                        signals, times, rec, self.resample_rate
                    )
                except (ValueError, KeyError, IndexError) as err:
                    warnings.warn(
                        f"Skipping TUEV recording {edf_path}: {err!r}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    continue

                for signal, offending_channel, label in zip(
                    feats, offending_channels, labels
                ):
                    signal = torch.FloatTensor(self._normalize(signal))
                    sample = {
                        "patient_id": pid,
                        "signal_file": edf_path,
                        "split": split,
                        "signal": signal,
                        "offending_channel": int(offending_channel.squeeze()),
                        # .rec class code shifted down by one.
                        "label": int(label.squeeze()) - 1,
                    }
                    if self.compute_stft:
                        # get_stft_torch expects (B, C, T); add/remove the batch dim.
                        sample["stft"] = get_stft_torch(signal.unsqueeze(0)).squeeze(0)
                    samples.append(sample)
        return samples
class EEGAbnormalTUAB(BaseTask):
    """Binary classification task for abnormal EEG detection on TUAB.

    For each EDF recording, this task:

    1) reads the EDF
    2) applies a bandpass filter (default 0.1-75 Hz) and a notch filter
       (default 50 Hz), then resamples to ``resample_rate`` (default 200 Hz)
    3) builds the 16-channel bipolar TCP montage (following BIOT) and splits
       it into non-overlapping 10-second windows; a trailing partial window is
       dropped
    4) assigns a binary label (1 = abnormal, 0 = normal) derived from the
       metadata ``label`` attribute

    A recording is skipped with a ``RuntimeWarning`` when reading it or
    building the montage raises ``ValueError`` or ``KeyError``, e.g. when a
    montage channel is missing.

    Each returned sample contains:

    - "signal": torch.FloatTensor, shape (16, int(10 * resample_rate)), i.e.
      (16, 2000) with the default 200 Hz
    - "label": int (0 = normal, 1 = abnormal); label values other than
      ``"normal"``/``"abnormal"`` are passed through unchanged
    - "stft": tensor from ``get_stft_torch`` (only when ``compute_stft=True``)
    - "segment_id": str, index of the window within its recording
    - "start_time"/"end_time": int, window bounds as sample indices into the
      resampled recording (not seconds)
    - "patient_id", "signal_file", "split": bookkeeping fields

    Args:
        resample_rate: target sampling rate in Hz.
        bandpass_filter: (low, high) filter cut-offs in Hz.
        notch_filter: notch frequency in Hz.
        normalization: optional per-window scaling; one of ``None``,
            ``"95th_percentile"`` or ``"div_by_100"``.
        compute_stft: whether to precompute an ``"stft"`` tensor per sample.

    Raises:
        ValueError: if ``normalization`` is not a supported value.

    Examples:
        >>> from pyhealth.datasets import TUABDataset
        >>> from pyhealth.tasks import EEGAbnormalTUAB
        >>> dataset = TUABDataset(
        ...     root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
        ... )
        >>> sample_dataset = dataset.set_task(EEGAbnormalTUAB())
        >>> sample = sample_dataset[0]
        >>> print(sample['signal'].shape)  # (16, 2000)
        >>> print(sample['label'])  # 0 or 1
    """

    task_name: str = "EEG_abnormal"
    input_schema: Dict[str, str] = {"signal": "tensor", "stft": "tensor"}
    # NOTE: TUAB is a binary classification task (normal=0 vs abnormal=1), but the
    # output schema is intentionally set to "multiclass" rather than "binary".
    # Reason: PyHealth's conformal prediction methods (LABEL, ClusterLabel,
    # NeighborhoodLabel, CovariateLabel) require multiclass mode — they calibrate
    # prediction sets by thresholding a full (n, K) probability matrix, which is
    # only produced by a softmax output (multiclass). Binary mode uses sigmoid and
    # outputs (n, 1), which is incompatible with the CP calibration math.
    # For a 2-class problem, 2-class softmax is mathematically equivalent to
    # sigmoid, so there is no loss of correctness, just a different representation.
    output_schema: Dict[str, str] = {"label": "multiclass"}

    # Accepted values of the ``normalization`` constructor argument.
    _NORMALIZATIONS = (None, "95th_percentile", "div_by_100")

    # 16-channel bipolar TCP montage (following BIOT): output row k is the
    # first channel minus the second channel of the k-th pair, in this order.
    _BIPOLAR_PAIRS: Tuple[Tuple[str, str], ...] = (
        ("EEG FP1-REF", "EEG F7-REF"),
        ("EEG F7-REF", "EEG T3-REF"),
        ("EEG T3-REF", "EEG T5-REF"),
        ("EEG T5-REF", "EEG O1-REF"),
        ("EEG FP2-REF", "EEG F8-REF"),
        ("EEG F8-REF", "EEG T4-REF"),
        ("EEG T4-REF", "EEG T6-REF"),
        ("EEG T6-REF", "EEG O2-REF"),
        ("EEG FP1-REF", "EEG F3-REF"),
        ("EEG F3-REF", "EEG C3-REF"),
        ("EEG C3-REF", "EEG P3-REF"),
        ("EEG P3-REF", "EEG O1-REF"),
        ("EEG FP2-REF", "EEG F4-REF"),
        ("EEG F4-REF", "EEG C4-REF"),
        ("EEG C4-REF", "EEG P4-REF"),
        ("EEG P4-REF", "EEG O2-REF"),
    )

    def __init__(
        self,
        resample_rate: float = 200,
        bandpass_filter: Tuple[float, float] = (0.1, 75.0),
        notch_filter: float = 50.0,
        normalization: str = None,  # None, '95th_percentile' or 'div_by_100'
        compute_stft: bool = True,
    ) -> None:
        super().__init__()
        # Fail fast on typos instead of silently skipping normalization.
        if normalization not in self._NORMALIZATIONS:
            raise ValueError(
                f"Unsupported normalization {normalization!r}; "
                f"expected one of {self._NORMALIZATIONS}."
            )
        self.resample_rate = resample_rate
        self.bandpass_filter = bandpass_filter
        self.notch_filter = notch_filter
        self.normalization = normalization
        self.compute_stft = compute_stft
        # input_schema must reflect whether stft is produced
        if not compute_stft:
            self.input_schema = {"signal": "tensor"}

    @staticmethod
    def read_and_process_edf(
        fileName: str,
        resample_rate: float = 200,
        bandpass_filter: Tuple[float, float] = (0.1, 75.0),
        notch_filter: float = 50.0,
    ) -> Tuple[np.ndarray, List[str]]:
        """Load one EDF recording, filter it and resample it.

        Args:
            fileName: path to the ``.edf`` file.
            resample_rate: target sampling rate in Hz.
            bandpass_filter: (low, high) filter cut-offs in Hz.
            notch_filter: notch frequency in Hz.

        Returns:
            A tuple ``(raw_data, ch_name)``. ``raw_data`` comes from
            ``get_data(units="uV")`` with one row per channel. ``ch_name`` is
            the matching list of channel names.
        """
        Rawdata = mne.io.read_raw_edf(fileName, preload=True, verbose="error")
        Rawdata.filter(
            l_freq=bandpass_filter[0], h_freq=bandpass_filter[1], verbose="error"
        )
        Rawdata.notch_filter(notch_filter, verbose="error")
        Rawdata.resample(resample_rate, n_jobs=1, verbose="error")
        raw_data = Rawdata.get_data(units="uV")
        ch_name = Rawdata.ch_names
        return raw_data, ch_name

    @staticmethod
    def convert_to_bipolar(raw_data: np.ndarray, ch_name: List[str]) -> np.ndarray:
        """Convert raw EEG channels to 16 bipolar montage channels following BIOT.

        Args:
            raw_data: referential data, shape (n_channels, n_samples), rows in
                the order of ``ch_name``.
            ch_name: channel names of ``raw_data``; the first occurrence of a
                repeated name is used.

        Returns:
            np.ndarray of shape (16, n_samples), dtype float64. Row ``k`` is
            the first channel minus the second channel of the ``k``-th pair in
            ``_BIPOLAR_PAIRS``.

        Raises:
            ValueError: if a required channel name is missing from ``ch_name``.
        """
        pairs = EEGAbnormalTUAB._BIPOLAR_PAIRS
        channeled_data = np.zeros((len(pairs), raw_data.shape[1]))
        for row, (first, second) in enumerate(pairs):
            channeled_data[row] = (
                raw_data[ch_name.index(first)] - raw_data[ch_name.index(second)]
            )
        return channeled_data

    def _normalize(self, signal: np.ndarray) -> np.ndarray:
        """Apply the configured per-window amplitude normalization."""
        if self.normalization == "95th_percentile":
            # Per-channel scale; the epsilon avoids division by zero on flat channels.
            scale = np.quantile(
                np.abs(signal), q=0.95, axis=-1, method="linear", keepdims=True
            )
            return signal / (scale + 1e-8)
        if self.normalization == "div_by_100":
            return signal / 100
        return signal

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process one patient, creating one sample per 10-second window.

        Iterates over both the ``"train"`` and ``"eval"`` splits and tags every
        sample with its split. A recording is skipped with a ``RuntimeWarning``
        when reading it or building the montage raises ``ValueError`` or
        ``KeyError``.

        Args:
            patient: object exposing ``patient_id`` and ``get_events(split)``;
                each event must carry ``signal_file`` and ``label`` attributes.

        Returns:
            A list of sample dicts as described in the class docstring.
        """
        import warnings

        if self.compute_stft:
            # Lazy import, hoisted out of the per-sample loop.
            from pyhealth.models.tfm_tokenizer import get_stft_torch

        pid = patient.patient_id
        samples: List[Dict[str, Any]] = []
        # Window length in samples. int() keeps slice bounds integral when
        # resample_rate is passed as a float (e.g. 200.0).
        window = int(self.resample_rate * 10)
        for split in ("train", "eval"):
            for event in patient.get_events(split):
                edf_path = event.signal_file
                label = event.label
                if label == "normal":
                    label = 0
                elif label == "abnormal":
                    label = 1
                try:
                    raw_data, ch_name = self.read_and_process_edf(
                        edf_path,
                        self.resample_rate,
                        self.bandpass_filter,
                        self.notch_filter,
                    )
                    # Inside the try: a missing montage channel raises
                    # ValueError (list.index) and must skip only this file.
                    bipolar_data = self.convert_to_bipolar(raw_data, ch_name)
                except (ValueError, KeyError) as err:
                    warnings.warn(
                        f"Skipping TUAB recording {edf_path}: {err!r}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    continue

                n_windows = bipolar_data.shape[1] // window
                for i in range(n_windows):
                    start = i * window
                    end = start + window
                    signal = torch.FloatTensor(
                        self._normalize(bipolar_data[:, start:end])
                    )
                    sample = {
                        "patient_id": pid,
                        "signal_file": edf_path,
                        "split": split,
                        "signal": signal,
                        "label": label,
                        "segment_id": f"{i}",
                        "start_time": start,
                        "end_time": end,
                    }
                    if self.compute_stft:
                        # get_stft_torch expects (B, C, T); add/remove the batch dim.
                        sample["stft"] = get_stft_torch(signal.unsqueeze(0)).squeeze(0)
                    samples.append(sample)
        return samples