import os
import numpy as np
from typing import Optional, List
from itertools import islice
from pyhealth.datasets import BaseSignalDataset
class CardiologyDataset(BaseSignalDataset):
    """Base ECG dataset for Cardiology.

    Dataset is available at https://physionet.org/content/challenge-2020/1.0.2/

    Args:
        dataset_name: name of the dataset.
        root: root directory of the raw data.
        dev: whether to enable dev mode (only use a small subset of the data;
            here, the first 5 patients). Default is False.
        refresh_cache: whether to refresh the cache; if true, the dataset will
            be processed from scratch and the cache will be updated.
            Default is False.
        chosen_dataset: a list of 0/1 flags of length 6 indicating which
            datasets will be used. Default: [1, 1, 1, 1, 1, 1].
            The flags map, in order, to "cpsc_2018", "cpsc_2018_extra",
            "georgia", "ptb", "ptb-xl", "st_petersburg_incart".
            e.g. [0, 1, 1, 1, 1, 1] indicates that "cpsc_2018_extra",
            "georgia", "ptb", "ptb-xl" and "st_petersburg_incart" will be used.

    Attributes:
        chosen_dataset: List[int], the per-dataset 0/1 selection flags.
        task: Optional[str], name of the task. Default is None.
        samples: Optional[List[Dict]], a list of samples, each sample is a dict
            with patient_id, record_id, and other task-specific attributes as
            key. Default is None.
        patient_to_index: Optional[Dict[str, List[int]]], a dict mapping
            patient_id to a list of sample indices. Default is None.
        visit_to_index: Optional[Dict[str, List[int]]], a dict mapping visit_id
            to a list of sample indices. Default is None.

    Examples:
        >>> from pyhealth.datasets import CardiologyDataset
        >>> dataset = CardiologyDataset(
        ...         root="/srv/local/data/physionet.org/files/challenge-2020/1.0.2/training",
        ...     )
        >>> dataset.stat()
        >>> dataset.info()
    """

    def __init__(
        self,
        root: str,
        chosen_dataset: Optional[List[int]] = None,
        dataset_name: Optional[str] = None,
        dev: bool = False,
        refresh_cache: bool = False,
    ):
        """Store the dataset selection and delegate setup to the base class.

        Args:
            root: root directory of the raw data.
            chosen_dataset: six 0/1 flags selecting the sub-datasets; ``None``
                (the default) selects all six.
            dataset_name: name of the dataset, forwarded to the base class.
            dev: whether to enable dev mode.
            refresh_cache: whether to rebuild the cache from scratch.
        """
        # A fresh list per instance: a list literal as the default value would
        # be shared (and mutable) across every call of the constructor.
        if chosen_dataset is None:
            chosen_dataset = [1, 1, 1, 1, 1, 1]
        # Set before super().__init__() -- presumably the base constructor
        # triggers process_EEG_data(), which reads this attribute
        # (TODO confirm against BaseSignalDataset).
        self.chosen_dataset = list(chosen_dataset)
        super().__init__(
            dataset_name=dataset_name,
            root=root,
            dev=dev,
            refresh_cache=refresh_cache,
        )
        self.root = root
        self.dev = dev
        self.refresh_cache = refresh_cache

    def process_EEG_data(self) -> dict:
        """Build the patient -> records map for the selected ECG datasets.

        Each selected dataset is expected under ``<root>/<dataset name>`` and
        to hold patient folders named ``g1``, ``g2``, ... next to optional
        ``RECORDS`` / ``index.html`` bookkeeping files.

        Returns:
            A dict keyed by patient id ``"<dataset index>_<patient index>"``
            (both zero-based). Each value is a list with one dict per record::

                {
                    "load_from_path": <patient folder>,
                    "patient_id": <patient id>,
                    "signal_file": "<record>.mat",
                    "label_file": "<record>.hea",
                    "save_to_path": self.filepath,
                }

            In dev mode only the first 5 patients are returned.

        Raises:
            IndexError: if ``self.chosen_dataset`` has fewer than 6 entries.
            OSError: if a selected dataset or patient folder cannot be listed.
        """
        dataset_lists = [
            "cpsc_2018",
            "cpsc_2018_extra",
            "georgia",
            "ptb",
            "ptb-xl",
            "st_petersburg_incart",
        ]
        # Bookkeeping entries that are neither patient folders nor records.
        ignored = {"RECORDS", "index.html"}
        dev_limit = 5

        patients = {}
        for dataset_idx, dataset_name in enumerate(dataset_lists):
            if self.chosen_dataset[dataset_idx] == 0:
                continue

            dataset_root = os.path.join(self.root, dataset_name)
            # Count patient folders with the same filter used for record files,
            # so a stray index.html does not make us look for a missing g<N+1>.
            n_patients = sum(
                1 for entry in os.listdir(dataset_root) if entry not in ignored
            )

            for patient_idx in range(n_patients):
                # Dev mode keeps only the first few patients; stop before
                # listing directories whose contents would be discarded.
                if self.dev and len(patients) >= dev_limit:
                    return patients

                patient_root = os.path.join(
                    dataset_root, "g{}".format(patient_idx + 1)
                )
                # A record stem shows up once per file carrying it (the
                # entries below reference both <stem>.mat and <stem>.hea), so
                # collapse to unique stems; sorting makes the order
                # independent of os.listdir's arbitrary ordering.
                record_stems = sorted(
                    {
                        file_name.split(".")[0]
                        for file_name in os.listdir(patient_root)
                        if file_name not in ignored
                    }
                )

                pid = "{}_{}".format(dataset_idx, patient_idx)
                patients[pid] = [
                    {
                        "load_from_path": patient_root,
                        "patient_id": pid,
                        "signal_file": stem + ".mat",
                        "label_file": stem + ".hea",
                        "save_to_path": self.filepath,
                    }
                    for stem in record_stems
                ]

        return patients
if __name__ == "__main__":
    # Smoke-test run: build a dev-mode dataset from scratch and print a summary.
    dataset = CardiologyDataset(
        refresh_cache=True,
        dev=True,
        root="/srv/local/data/physionet.org/files/challenge-2020/1.0.2/training",
    )
    dataset.stat()
    dataset.info()

    # Report how many records belong to the first patient.
    patient_keys = list(dataset.patients.keys())
    first_patient_records = dataset.patients[patient_keys[0]]
    print(len(first_patient_records))