pyhealth.datasets.SampleEHRDataset
This class takes a list of samples as input (either from BaseEHRDataset.set_task() or from user-provided JSON input) and provides a uniform interface for accessing those samples.
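The "uniform interface" is not spelled out on this page. As a rough sketch, and assuming it means the standard Python sequence protocol (len() plus integer indexing), a wrapper over the sample list could look like the following. ListSampleDataset is a hypothetical name, not pyhealth's implementation. The requirement that each sample carries patient_id and visit_id is our assumption, drawn from the example on this page.

```python
from typing import Any, Dict, List


class ListSampleDataset:
    """Hypothetical stand-in for a sample dataset: a thin wrapper over a list of dicts."""

    # Assumption: every sample carries these identifier keys, as in the example.
    REQUIRED_KEYS = {"patient_id", "visit_id"}

    def __init__(self, samples: List[Dict[str, Any]]) -> None:
        # Reject samples that lack the identifier keys up front.
        for i, sample in enumerate(samples):
            missing = self.REQUIRED_KEYS - set(sample)
            if missing:
                raise KeyError(f"sample {i} is missing {sorted(missing)}")
        self.samples = samples

    def __len__(self) -> int:
        # Number of samples, so len(dataset) works.
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Integer indexing returns the stored sample dict unchanged.
        return self.samples[index]
```

Because the wrapper only implements __len__ and __getitem__, downstream code that relies on len(dataset) and dataset[i] does not need to know whether the samples came from set_task() or from a JSON file.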
- class pyhealth.datasets.SampleEHRDataset(samples, code_vocs=None, dataset_name='', task_name='')
Bases: SampleBaseDataset
Sample EHR dataset class.
- This class inherits from SampleBaseDataset and is designed specifically for EHR datasets.
- Parameters:
  - samples: a list of samples, each a dict containing patient_id, visit_id, and task-specific attributes. Currently, the following attribute types are supported:
    - a single value. Type: int/float/str. Dim: 0.
    - a single vector. Type: int/float. Dim: 1.
    - a list of codes. Type: str. Dim: 2.
    - a list of vectors. Type: int/float. Dim: 2.
    - a list of lists of codes. Type: str. Dim: 3.
    - a list of lists of vectors. Type: int/float. Dim: 3.
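The Dim column follows a convention that is easy to misread: a flat list of numbers is Dim 1, but a flat list of codes is Dim 2, so string attributes count one extra level. The hypothetical helper below is a sketch that reproduces the table. It rests on two assumptions of ours: every element at a given level is nested equally deep, and a numeric attribute is reported as float if any leaf is a float. It is not pyhealth's own inference code.

```python
from typing import Any, List, Tuple, Type


def _nesting_depth(value: Any) -> int:
    """Number of list levels around the leaves (0 for a scalar).

    Assumes every element at a given level is nested equally deep, so only
    the first element of each list is inspected.
    """
    depth = 0
    while isinstance(value, list):
        if not value:
            raise ValueError("cannot infer type/dim from an empty list")
        depth += 1
        value = value[0]
    return depth


def _leaves(value: Any) -> List[Any]:
    """Flatten arbitrarily nested lists into a flat list of leaf values."""
    if isinstance(value, list):
        return [leaf for item in value for leaf in _leaves(item)]
    return [value]


def infer_type_and_dim(value: Any) -> Tuple[Type, int]:
    """Hypothetical helper reproducing the supported-attribute table."""
    depth = _nesting_depth(value)
    leaves = _leaves(value)
    if all(isinstance(x, str) for x in leaves):
        # Codes: a bare string is Dim 0; otherwise each list level counts
        # one more than for numbers (list of codes -> Dim 2).
        dim = 0 if depth == 0 else depth + 1
        elem_type: Type = str
    elif all(isinstance(x, (int, float)) for x in leaves):
        # Numbers: Dim equals the list nesting depth (vector -> Dim 1).
        dim = depth
        # Assumption: any float leaf makes the whole attribute float.
        elem_type = float if any(isinstance(x, float) for x in leaves) else int
    else:
        raise TypeError(f"unsupported or mixed leaf types in {value!r}")
    if dim > 3:
        raise ValueError(f"Dim {dim} exceeds the supported maximum of 3")
    return elem_type, dim
```

When the helper is applied to each attribute of the example samples on this page, it yields the same type and dim pairs as the input_info output shown there.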
- Attributes:
  - input_info: Dict, a dict whose keys are the same as the keys in the samples and whose values describe each attribute:
    - "type": the element type of the attribute, one of float, int, str.
    - "dim": the list dimension of the attribute, one of 0, 1, 2, 3.
    - "len": the length of the vector; only set for vector-based attributes.
  - patient_to_index: Dict[str, List[int]], a dict mapping each patient_id to a list of sample indices.
  - visit_to_index: Dict[str, List[int]], a dict mapping each visit_id to a list of sample indices.
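patient_to_index and visit_to_index are inverted indexes from an ID to sample positions. The following is a minimal sketch of how such a mapping can be built from the sample list. build_index is a hypothetical helper, not part of pyhealth, and it assumes every sample carries the key being indexed.

```python
from collections import defaultdict
from typing import Any, Dict, List


def build_index(samples: List[Dict[str, Any]], key: str) -> Dict[str, List[int]]:
    """Map each distinct value of `key` to the positions of the samples holding it."""
    index: Dict[str, List[int]] = defaultdict(list)
    for position, sample in enumerate(samples):
        # Positions are appended in sample order, so each list is sorted.
        index[sample[key]].append(position)
    # Return a plain dict; keys keep the order of their first appearance.
    return dict(index)
```

Given the two samples from the example on this page, build_index(samples, "patient_id") gives {'patient-0': [0, 1]} and build_index(samples, "visit_id") gives {'visit-0': [0], 'visit-1': [1]}. Both match the outputs shown there.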
Examples
>>> from pyhealth.datasets import SampleEHRDataset
>>> samples = [
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-0",
...         "single_vector": [1, 2, 3],
...         "list_codes": ["505800458", "50580045810", "50580045811"],  # NDC
...         "list_vectors": [[1.0, 2.55, 3.4], [4.1, 5.5, 6.0]],
...         "list_list_codes": [["A05B", "A05C", "A06A"], ["A11D", "A11E"]],  # ATC-4
...         "list_list_vectors": [
...             [[1.8, 2.25, 3.41], [4.50, 5.9, 6.0]],
...             [[7.7, 8.5, 9.4]],
...         ],
...         "label": 1,
...     },
...     {
...         "patient_id": "patient-0",
...         "visit_id": "visit-1",
...         "single_vector": [1, 5, 8],
...         "list_codes": [
...             "55154191800",
...             "551541928",
...             "55154192800",
...             "705182798",
...             "70518279800",
...         ],
...         "list_vectors": [[1.4, 3.2, 3.5], [4.1, 5.9, 1.7]],
...         "list_list_codes": [["A04A", "B035", "C129"], ["A07B", "A07C"]],
...         "list_list_vectors": [
...             [[1.0, 2.8, 3.3], [4.9, 5.0, 6.6]],
...             [[7.7, 8.4, 1.3]],
...         ],
...         "label": 0,
...     },
... ]
>>> dataset = SampleEHRDataset(samples=samples)
>>> dataset.input_info
{'patient_id': {'type': <class 'str'>, 'dim': 0}, 'visit_id': {'type': <class 'str'>, 'dim': 0}, 'single_vector': {'type': <class 'int'>, 'dim': 1, 'len': 3}, 'list_codes': {'type': <class 'str'>, 'dim': 2}, 'list_vectors': {'type': <class 'float'>, 'dim': 2, 'len': 3}, 'list_list_codes': {'type': <class 'str'>, 'dim': 3}, 'list_list_vectors': {'type': <class 'float'>, 'dim': 3, 'len': 3}, 'label': {'type': <class 'int'>, 'dim': 0}}
>>> dataset.patient_to_index
{'patient-0': [0, 1]}
>>> dataset.visit_to_index
{'visit-0': [0], 'visit-1': [1]}