Source code for pyhealth.datasets.sample_dataset

from collections import Counter
from typing import Dict, List
import pickle

from torch.utils.data import Dataset

from pyhealth.datasets.utils import list_nested_levels, flatten_list


class SampleBaseDataset(Dataset):
    """Sample base dataset class.

    This class takes a list of samples as input (either from
    `BaseDataset.set_task()` or user-provided input), and provides a uniform
    interface for accessing the samples.

    Args:
        samples: a list of samples, each sample is a dict with
            patient_id, visit_id, and other task-specific attributes as key.
        dataset_name: the name of the dataset. Default is "".
        task_name: the name of the task. Default is "".
    """

    def __init__(self, samples: List[Dict], dataset_name="", task_name=""):
        self.samples = samples
        self.dataset_name: str = dataset_name
        self.task_name: str = task_name
        self.type_ = "base"

    def __getitem__(self, index) -> Dict:
        """Returns a sample by index.

        Returns:
            Dict, a dict with patient_id, visit_id/record_id, and other
                task-specific attributes as key. Conversion to index/tensor
                will be done in the model.
        """
        return self.samples[index]

    def __str__(self):
        """Returns a short description of the dataset."""
        return f"Sample dataset {self.dataset_name} {self.task_name}"

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.samples)

    def get_all_tokens(
        self, key: str, remove_duplicates: bool = True, sort: bool = True
    ) -> List[str]:
        """Gets all tokens with a specific key in the samples.

        This method reads `self.input_info`, which is not set by
        `SampleBaseDataset.__init__()`; the instance must provide it before
        this method is called.

        Args:
            key: the key of the tokens in the samples.
            remove_duplicates: whether to remove duplicates. Default is True.
                Duplicates are dropped keeping the first occurrence, so the
                result order is deterministic even when `sort` is False.
            sort: whether to sort the tokens by alphabet order. Default is True.

        Returns:
            tokens: a list of tokens.

        Raises:
            AssertionError: if the key holds int/float values with dim != 0.
            NotImplementedError: if the dim of the key is not 0, 2, or 3 (and
                there is at least one sample).
        """
        info = self.input_info[key]
        input_type = info["type"]
        input_dim = info["dim"]
        if input_type in [float, int]:
            assert input_dim == 0, f"Cannot get tokens for vector with key {key}"
        tokens = []
        for sample in self.samples:
            if input_dim == 0:
                # a single value
                tokens.append(sample[key])
            elif input_dim == 2:
                # a list of codes
                tokens.extend(sample[key])
            elif input_dim == 3:
                # a list of list of codes
                tokens.extend(flatten_list(sample[key]))
            else:
                raise NotImplementedError
        if remove_duplicates:
            # dict.fromkeys keeps first-seen order; a plain set would make
            # the unsorted result order arbitrary between runs.
            tokens = list(dict.fromkeys(tokens))
        if sort:
            tokens.sort()
        return tokens
class SampleSignalDataset(SampleBaseDataset):
    """Sample signal dataset class.

    This class takes a list of samples as input (either from
    `BaseDataset.set_task()` or user-provided input), and provides a uniform
    interface for accessing the samples. Each sample points to a pickled
    epoch file through its "epoch_path" key; the file is loaded lazily in
    `__getitem__()`.

    Args:
        samples: a list of samples, each sample is a dict with
            patient_id, record_id, epoch_path, and other task-specific
            attributes as key.
        dataset_name: the name of the dataset. Default is "".
        task_name: the name of the task. Default is "".
    """

    def __init__(self, samples: List[Dict], dataset_name="", task_name=""):
        super().__init__(samples, dataset_name, task_name)
        self.patient_to_index: Dict[str, List[int]] = self._index_patient()
        self.record_to_index: Dict[str, List[int]] = self._index_record()
        self.input_info: Dict = self._validate()
        self.type_ = "signal"

    @staticmethod
    def _load_epoch(path) -> Dict:
        """Helper function which loads one pickled epoch file.

        The file is opened in a `with` block so the handle is closed even if
        unpickling fails.

        NOTE: `pickle.load` can execute arbitrary code; only use epoch files
        that come from a trusted source.
        """
        with open(path, "rb") as f:
            return pickle.load(f)

    def _index_patient(self) -> Dict[str, List[int]]:
        """Helper function which indexes the samples by patient_id.

        Will be called in `self.__init__()`.

        Returns:
            patient_to_index: Dict[str, List[int]], a dict mapping patient_id
                to a list of sample indices.
        """
        patient_to_index = {}
        for idx, sample in enumerate(self.samples):
            patient_to_index.setdefault(sample["patient_id"], []).append(idx)
        return patient_to_index

    def _index_record(self) -> Dict[str, List[int]]:
        """Helper function which indexes the samples by record_id.

        Will be called in `self.__init__()`.

        Returns:
            record_to_index: Dict[str, List[int]], a dict mapping record_id
                to a list of sample indices.
        """
        record_to_index = {}
        for idx, sample in enumerate(self.samples):
            record_to_index.setdefault(sample["record_id"], []).append(idx)
        return record_to_index

    def _validate(self) -> Dict:
        """Helper function which gets the input information of each attribute.

        Only the epoch file of the first sample is loaded; its "signal" entry
        is expected to expose a two-element `.shape` that is read as
        (n_channels, length).

        Will be called in `self.__init__()`.

        Returns:
            input_info: Dict, with two entries:
                - "signal": {"length": ..., "n_channels": ...}, taken from
                    the first sample's signal shape.
                - "label": {"type": str, "dim": 0}.
        """
        input_info = {}
        # get signal info from the first sample only
        first_epoch = self._load_epoch(self.samples[0]["epoch_path"])
        n_channels, length = first_epoch["signal"].shape
        input_info["signal"] = {"length": length, "n_channels": n_channels}
        # get label signal info
        input_info["label"] = {"type": str, "dim": 0}
        return input_info

    def __getitem__(self, index) -> Dict:
        """Returns a sample by index.

        The stored sample is copied, merged with the content of its pickled
        epoch file (loaded keys override stored ones), and returned without
        the "epoch_path" key. The stored sample itself is not modified.

        Returns:
            Dict, a dict with patient_id, visit_id/record_id, and other
                task-specific attributes as key. Conversion to index/tensor
                will be done in the model.
        """
        sample = self.samples[index]
        cur_sample = sample.copy()
        cur_sample.update(self._load_epoch(sample["epoch_path"]))
        cur_sample.pop("epoch_path", None)
        return cur_sample

    def stat(self) -> str:
        """Prints and returns some statistics of the task-specific dataset."""
        num_patients = len({sample["patient_id"] for sample in self.samples})
        num_records = len({sample["record_id"] for sample in self.samples})
        lines = [
            "Statistics of sample dataset:",
            f"\t- Dataset: {self.dataset_name}",
            f"\t- Task: {self.task_name}",
            f"\t- Number of samples: {len(self)}",
            f"\t- Number of patients: {num_patients}",
            f"\t- Number of visits: {num_records}",
            f"\t- Number of samples per patient: {len(self) / num_patients:.4f}",
        ]
        report = "\n".join(lines)
        print(report)
        return report
class SampleEHRDataset(SampleBaseDataset):
    """Sample EHR dataset class.

    This class inherits from `SampleBaseDataset` and is specifically designed
    for EHR datasets.

    Args:
        samples: a list of samples, each sample is a dict with
            patient_id, visit_id, and other task-specific attributes as key.
        code_vocs: stored as-is on the instance. Default is None.
        dataset_name: the name of the dataset. Default is "".
        task_name: the name of the task. Default is "".

    Currently, the following types of attributes are supported:
        - a single value. Type: int/float/str. Dim: 0.
        - a single vector. Type: int/float. Dim: 1.
        - a list of codes. Type: str. Dim: 2.
        - a list of vectors. Type: int/float. Dim: 2.
        - a list of list of codes. Type: str. Dim: 3.
        - a list of list of vectors. Type: int/float. Dim: 3.

    Attributes:
        input_info: Dict, a dict whose keys are the same as the keys in the
            samples, and values are the corresponding input information:
            - "type": the element type of each key attribute, one of float,
                int, str.
            - "dim": the list dimension of each key attribute, one of 0, 1,
                2, 3.
            - "len": the length of the vector, only valid for vector-based
                attributes.
        patient_to_index: Dict[str, List[int]], a dict mapping patient_id to
            a list of sample indices.
        visit_to_index: Dict[str, List[int]], a dict mapping visit_id to a
            list of sample indices.

    Examples:
        >>> from pyhealth.datasets import SampleEHRDataset
        >>> samples = [
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-0",
        ...         "single_vector": [1, 2, 3],
        ...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        ...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        ...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        ...         "list_list_vectors": [
        ...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
        ...             [[7.7, 8.5, 9.4]],
        ...         ],
        ...         "label": 1,
        ...     },
        ...     {
        ...         "patient_id": "patient-0",
        ...         "visit_id": "visit-1",
        ...         "single_vector": [1, 5, 8],
        ...         "list_codes": [
        ...             "55154191800",
        ...             "551541928",
        ...             "55154192800",
        ...             "705182798",
        ...             "70518279800",
        ...         ],
        ...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7]],
        ...         "list_list_codes": [["A04A", "B035", "C129"], ["A07B", "A07C"]],
        ...         "list_list_vectors": [
        ...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6]],
        ...             [[7.7, 8.4, 1.3]],
        ...         ],
        ...         "label": 0,
        ...     },
        ... ]
        >>> dataset = SampleEHRDataset(samples=samples)
        >>> dataset.input_info
        {'patient_id': {'type': <class 'str'>, 'dim': 0}, 'visit_id': {'type': <class 'str'>, 'dim': 0}, 'single_vector': {'type': <class 'int'>, 'dim': 1, 'len': 3}, 'list_codes': {'type': <class 'str'>, 'dim': 2}, 'list_vectors': {'type': <class 'float'>, 'dim': 2, 'len': 3}, 'list_list_codes': {'type': <class 'str'>, 'dim': 3}, 'list_list_vectors': {'type': <class 'float'>, 'dim': 3, 'len': 3}, 'label': {'type': <class 'int'>, 'dim': 0}}
        >>> dataset.patient_to_index
        {'patient-0': [0, 1]}
        >>> dataset.visit_to_index
        {'visit-0': [0], 'visit-1': [1]}
    """

    def __init__(
        self, samples: List[Dict], code_vocs=None, dataset_name="", task_name=""
    ):
        # the parent constructor already stores `samples`
        super().__init__(samples, dataset_name, task_name)
        self.code_vocs = code_vocs
        self.input_info: Dict = self._validate()
        self.patient_to_index: Dict[str, List[int]] = self._index_patient()
        self.visit_to_index: Dict[str, List[int]] = self._index_visit()
        self.type_ = "ehr"

    def _validate(self) -> Dict:
        """Helper function which validates the samples.

        Will be called in `self.__init__()`.

        Returns:
            input_info: Dict, a dict whose keys are the same as the keys in
                the samples, and values are the corresponding input
                information:
                - "type": the element type of each key attribute, one of
                    float, int, str. A key that mixes int and float elements
                    is reported as float.
                - "dim": the list dimension of each key attribute, one of 0,
                    1, 2, 3.
                - "len": the length of the vector, only valid for
                    vector-based attributes.

        Raises:
            AssertionError: if any of the checks below fails.
        """
        # 1. Check if all samples are of type dict.
        assert all(
            isinstance(s, dict) for s in self.samples
        ), "Each sample should be a dict"
        keys = self.samples[0].keys()
        key_set = set(keys)

        # 2. Check if all samples have the same keys.
        assert all(
            set(s.keys()) == key_set for s in self.samples
        ), "All samples should have the same keys"

        # 3. Check if "patient_id" and "visit_id" are in the keys.
        assert "patient_id" in keys, "patient_id should be in the keys"
        assert "visit_id" in keys, "visit_id should be in the keys"

        # 4. For each key, check if it is either:
        #     - a single value
        #     - a single vector
        #     - a list of codes
        #     - a list of vectors
        #     - a list of list of codes
        #     - a list of list of vectors
        # Note that a value is either float, int, or str; a vector is a list
        # of float or int; and a code is str.

        # record input information for each key
        input_info = {}
        for key in keys:
            # 4.1. Check nested list level: all samples should either all be
            #     - a single value (level=0)
            #     - a single vector (level=1)
            #     - a list of codes (level=1)
            #     - a list of vectors (level=2)
            #     - a list of list of codes (level=2)
            #     - a list of list of vectors (level=3)
            levels = {list_nested_levels(s[key]) for s in self.samples}
            assert (
                len(levels) == 1 and len(list(levels)[0]) == 1
            ), f"Key {key} has mixed nested list levels across samples"
            level = levels.pop()[0]
            assert level in [
                0,
                1,
                2,
                3,
            ], f"Key {key} has unsupported nested list level across samples"

            # flatten the list down to its basic elements
            if level == 0:
                flattened_values = [s[key] for s in self.samples]
            elif level == 1:
                flattened_values = [i for s in self.samples for i in s[key]]
            elif level == 2:
                flattened_values = [
                    j for s in self.samples for i in s[key] for j in i
                ]
            else:
                flattened_values = [
                    k for s in self.samples for i in s[key] for j in i for k in j
                ]

            # 4.2. Check type: the basic type of each element should be
            # float, int, or str.
            types = {type(v) for v in flattened_values}
            assert (
                types == {str} or len(types.difference({int, float})) == 0
            ), f"Key {key} has mixed or unsupported types ({types}) across samples"
            if float in types:
                # Mixed int/float passes the check above; resolve it to float
                # rather than to an arbitrary element of the set.
                type_ = float
            else:
                type_ = types.pop()

            # 4.3. Combined level and type check.
            if level == 0:
                # a single value
                input_info[key] = {"type": type_, "dim": 0}
            elif level == 1:
                # a single vector or a list of codes
                if type_ in [float, int]:
                    # a single vector
                    lens = {len(s[key]) for s in self.samples}
                    assert len(lens) == 1, f"Key {key} has vectors of different lengths"
                    input_info[key] = {"type": type_, "dim": 1, "len": lens.pop()}
                else:
                    # a list of codes
                    # note that dim is different from level here
                    input_info[key] = {"type": type_, "dim": 2}
            elif level == 2:
                # a list of vectors or a list of list of codes
                if type_ in [float, int]:
                    lens = {len(i) for s in self.samples for i in s[key]}
                    assert len(lens) == 1, f"Key {key} has vectors of different lengths"
                    input_info[key] = {"type": type_, "dim": 2, "len": lens.pop()}
                else:
                    # a list of list of codes
                    # note that dim is different from level here
                    input_info[key] = {"type": type_, "dim": 3}
            else:
                # a list of list of vectors
                assert type_ in [
                    float,
                    int,
                ], f"Key {key} has unsupported type across samples"
                lens = {len(j) for s in self.samples for i in s[key] for j in i}
                assert len(lens) == 1, f"Key {key} has vectors of different lengths"
                input_info[key] = {"type": type_, "dim": 3, "len": lens.pop()}

        return input_info

    def _index_patient(self) -> Dict[str, List[int]]:
        """Helper function which indexes the samples by patient_id.

        Will be called in `self.__init__()`.

        Returns:
            patient_to_index: Dict[str, List[int]], a dict mapping patient_id
                to a list of sample indices.
        """
        patient_to_index = {}
        for idx, sample in enumerate(self.samples):
            patient_to_index.setdefault(sample["patient_id"], []).append(idx)
        return patient_to_index

    def _index_visit(self) -> Dict[str, List[int]]:
        """Helper function which indexes the samples by visit_id.

        Will be called in `self.__init__()`.

        Returns:
            visit_to_index: Dict[str, List[int]], a dict mapping visit_id to
                a list of sample indices.
        """
        visit_to_index = {}
        for idx, sample in enumerate(self.samples):
            visit_to_index.setdefault(sample["visit_id"], []).append(idx)
        return visit_to_index

    @property
    def available_keys(self) -> List[str]:
        """Returns a list of available keys for the dataset.

        The keys are read from the first sample.

        Returns:
            List of available keys.
        """
        return list(self.samples[0].keys())

    def get_distribution_tokens(self, key: str) -> Dict[str, int]:
        """Gets the distribution of tokens with a specific key in the samples.

        Args:
            key: the key of the tokens in the samples.

        Returns:
            distribution: a dict (a `collections.Counter`) mapping token to
                count.
        """
        tokens = self.get_all_tokens(key, remove_duplicates=False, sort=False)
        return Counter(tokens)

    def stat(self) -> str:
        """Prints and returns some statistics of the task-specific dataset."""
        num_patients = len({sample["patient_id"] for sample in self.samples})
        num_visits = len({sample["visit_id"] for sample in self.samples})
        lines = [
            "Statistics of sample dataset:",
            f"\t- Dataset: {self.dataset_name}",
            f"\t- Task: {self.task_name}",
            f"\t- Number of samples: {len(self)}",
            f"\t- Number of patients: {num_patients}",
            f"\t- Number of visits: {num_visits}",
            # visits (not samples) per patient, to match the label
            f"\t- Number of visits per patient: {num_visits / num_patients:.4f}",
        ]
        for key in self.samples[0]:
            if key in ["patient_id", "visit_id"]:
                continue
            input_type = self.input_info[key]["type"]
            input_dim = self.input_info[key]["dim"]

            if input_dim <= 1:
                # a single value or vector
                num_events = [1 for sample in self.samples]
            elif input_dim == 2:
                # a list
                num_events = [len(sample[key]) for sample in self.samples]
            elif input_dim == 3:
                # a list of list
                num_events = [len(flatten_list(sample[key])) for sample in self.samples]
            else:
                raise NotImplementedError

            lines.append(f"\t- {key}:")
            lines.append(
                f"\t\t- Number of {key} per sample: "
                f"{sum(num_events) / len(num_events):.4f}"
            )
            if input_type == str or input_dim == 0:
                # single value or code-based
                distribution = self.get_distribution_tokens(key)
                # one distribution entry per distinct token
                lines.append(f"\t\t- Number of unique {key}: {len(distribution)}")
                top10 = sorted(distribution.items(), key=lambda x: x[1], reverse=True)[
                    :10
                ]
                lines.append(f"\t\t- Distribution of {key} (Top-10): {top10}")
            else:
                # vector-based
                lines.append(f"\t\t- Length of {key}: {self.input_info[key]['len']}")

        report = "\n".join(lines)
        print(report)
        return report
if __name__ == "__main__":
    # Small demo: two visits of the same patient, covering every supported
    # attribute shape (single vector, list of codes, list of vectors,
    # list of list of codes, list of list of vectors, single value).
    first_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-0",
        "single_vector": [1, 2, 3],
        "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
        "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
        "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
        "list_list_vectors": [
            [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
            [[7.7, 8.5, 9.4]],
        ],
        "label": 1,
    }
    second_visit = {
        "patient_id": "patient-0",
        "visit_id": "visit-1",
        "single_vector": [1, 5, 8],
        "list_codes": [
            "55154191800",
            "551541928",
            "55154192800",
            "705182798",
            "70518279800",
        ],
        "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7]],
        "list_list_codes": [["A04A", "B035", "C129"], ["A07B", "A07C"]],
        "list_list_vectors": [
            [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6]],
            [[7.7, 8.4, 1.3]],
        ],
        "label": 0,
    }
    samples = [first_visit, second_visit]

    dataset = SampleEHRDataset(samples=samples)
    # print the dataset statistics, then the first sample
    dataset.stat()
    data = iter(dataset)
    print(next(data))