import os
from typing import Optional, List, Dict, Tuple, Union
import pandas as pd
from pyhealth.data import Event, Visit, Patient
from pyhealth.datasets import BaseEHRDataset
from pyhealth.datasets.utils import strptime
# TODO: add other tables
class MIMIC3Dataset(BaseEHRDataset):
    """Base dataset for MIMIC-III dataset.

    The MIMIC-III dataset is a large dataset of de-identified health records of ICU
    patients. The dataset is available at https://mimic.physionet.org/.

    The basic information is stored in the following tables:
        - PATIENTS: defines a patient in the database, SUBJECT_ID.
        - ADMISSIONS: defines a patient's hospital admission, HADM_ID.

    We further support the following tables:
        - DIAGNOSES_ICD: contains ICD-9 diagnoses (ICD9CM code) for patients.
        - PROCEDURES_ICD: contains ICD-9 procedures (ICD9PROC code) for patients.
        - PRESCRIPTIONS: contains medication related order entries (NDC code)
            for patients.
        - LABEVENTS: contains laboratory measurements (MIMIC3_ITEMID code)
            for patients.

    Args:
        dataset_name: name of the dataset.
        root: root directory of the raw data (should contain many csv files).
        tables: list of tables to be loaded (e.g., ["DIAGNOSES_ICD", "PROCEDURES_ICD"]).
        code_mapping: a dictionary containing the code mapping information.
            The key is a str of the source code vocabulary and the value is of
            two formats:
                (1) a str of the target code vocabulary;
                (2) a tuple with two elements. The first element is a str of the
                    target code vocabulary and the second element is a dict with
                    keys "source_kwargs" or "target_kwargs" and values of the
                    corresponding kwargs for the `CrossMap.map()` method.
            Default is empty dict, which means the original code will be used.
        dev: whether to enable dev mode (only use a small subset of the data).
            Default is False.
        refresh_cache: whether to refresh the cache; if true, the dataset will
            be processed from scratch and the cache will be updated. Default is False.

    Attributes:
        task: Optional[str], name of the task (e.g., "mortality prediction").
            Default is None.
        samples: Optional[List[Dict]], a list of samples, each sample is a dict with
            patient_id, visit_id, and other task-specific attributes as key.
            Default is None.
        patient_to_index: Optional[Dict[str, List[int]]], a dict mapping patient_id to
            a list of sample indices. Default is None.
        visit_to_index: Optional[Dict[str, List[int]]], a dict mapping visit_id to a
            list of sample indices. Default is None.

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> dataset = MIMIC3Dataset(
        ...         root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        ...         tables=["DIAGNOSES_ICD", "PRESCRIPTIONS"],
        ...         code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        ...     )
        >>> dataset.stat()
        >>> dataset.info()
    """

    def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses PATIENTS and ADMISSIONS tables.

        Will be called in `self.parse_tables()`

        Docs:
            - PATIENTS: https://mimic.mit.edu/docs/iii/tables/patients/
            - ADMISSIONS: https://mimic.mit.edu/docs/iii/tables/admissions/

        Args:
            patients: a dict of `Patient` objects indexed by patient_id which is
                updated with the mimic-3 table result.

        Returns:
            The updated patients dict (the same object that was passed in),
            with one `Patient` per SUBJECT_ID found in both tables.
        """
        # read patients table; dev mode only reads the first 1000 rows
        patients_df = pd.read_csv(
            os.path.join(self.root, "PATIENTS.csv"),
            dtype={"SUBJECT_ID": str},
            nrows=1000 if self.dev else None,
        )
        # read admissions table (always in full; the inner merge below
        # restricts it to the patients that were actually loaded)
        admissions_df = pd.read_csv(
            os.path.join(self.root, "ADMISSIONS.csv"),
            dtype={"SUBJECT_ID": str, "HADM_ID": str},
        )
        # merge patient and admission tables
        df = pd.merge(patients_df, admissions_df, on="SUBJECT_ID", how="inner")
        # sort by admission and discharge time
        df = df.sort_values(["SUBJECT_ID", "ADMITTIME", "DISCHTIME"], ascending=True)
        # group by patient
        df_group = df.groupby("SUBJECT_ID")

        # parallel unit of basic information (per patient)
        def basic_unit(p_id, p_info):
            # every row of `p_info` belongs to the same patient, so the
            # patient-level fields are read from the first row
            patient = Patient(
                patient_id=p_id,
                birth_datetime=strptime(p_info["DOB"].values[0]),
                death_datetime=strptime(p_info["DOD_HOSP"].values[0]),
                gender=p_info["GENDER"].values[0],
                ethnicity=p_info["ETHNICITY"].values[0],
            )
            # load visits
            # sort=False keeps the admission-time order established by
            # `sort_values` above; the default (sort=True) would iterate the
            # visits in lexicographic HADM_ID order instead.
            # NOTE(review): presumably `Patient.add_visit` keeps insertion
            # order -- confirm against pyhealth.data.Patient.
            for v_id, v_info in p_info.groupby("HADM_ID", sort=False):
                visit = Visit(
                    visit_id=v_id,
                    patient_id=p_id,
                    encounter_time=strptime(v_info["ADMITTIME"].values[0]),
                    discharge_time=strptime(v_info["DISCHTIME"].values[0]),
                    discharge_status=v_info["HOSPITAL_EXPIRE_FLAG"].values[0],
                    insurance=v_info["INSURANCE"].values[0],
                    language=v_info["LANGUAGE"].values[0],
                    religion=v_info["RELIGION"].values[0],
                    marital_status=v_info["MARITAL_STATUS"].values[0],
                    ethnicity=v_info["ETHNICITY"].values[0],
                )
                # add visit
                patient.add_visit(visit)
            return patient

        # parallel apply
        # NOTE(review): `parallel_apply` is not a pandas method; presumably it
        # is provided by pandarallel, initialized elsewhere -- confirm.
        patient_series = df_group.parallel_apply(
            lambda x: basic_unit(x.SUBJECT_ID.unique()[0], x)
        )
        # summarize the results
        for pat_id, pat in patient_series.items():
            patients[pat_id] = pat
        return patients

    def parse_diagnoses_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses DIAGNOSES_ICD table.

        Will be called in `self.parse_tables()`

        Docs:
            - DIAGNOSES_ICD: https://mimic.mit.edu/docs/iii/tables/diagnoses_icd/

        Args:
            patients: a dict of `Patient` objects indexed by patient_id.

        Returns:
            The updated patients dict.

        Note:
            MIMIC-III does not provide specific timestamps in DIAGNOSES_ICD
                table, so the events are created without a timestamp.
        """
        table = "DIAGNOSES_ICD"
        self.code_vocs["conditions"] = "ICD9CM"
        # read table
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            dtype={"SUBJECT_ID": str, "HADM_ID": str, "ICD9_CODE": str},
        )
        # drop records of the other patients
        df = df[df["SUBJECT_ID"].isin(patients.keys())]
        # drop rows with missing values
        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ICD9_CODE"])
        # sort by sequence number (i.e., priority)
        df = df.sort_values(["SUBJECT_ID", "HADM_ID", "SEQ_NUM"], ascending=True)
        # group by patient (visits are split inside the per-patient unit)
        group_df = df.groupby("SUBJECT_ID")

        # parallel unit of diagnosis (per patient)
        def diagnosis_unit(p_id, p_info):
            events = []
            for v_id, v_info in p_info.groupby("HADM_ID"):
                # rows inside a visit keep the SEQ_NUM order from the sort above
                for code in v_info["ICD9_CODE"]:
                    event = Event(
                        code=code,
                        table=table,
                        vocabulary="ICD9CM",
                        visit_id=v_id,
                        patient_id=p_id,
                    )
                    events.append(event)
            return events

        # parallel apply: yields one list of events per patient
        group_df = group_df.parallel_apply(
            lambda x: diagnosis_unit(x.SUBJECT_ID.unique()[0], x)
        )
        # summarize the results
        patients = self._add_events_to_patient_dict(patients, group_df)
        return patients

    def parse_procedures_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses PROCEDURES_ICD table.

        Will be called in `self.parse_tables()`

        Docs:
            - PROCEDURES_ICD: https://mimic.mit.edu/docs/iii/tables/procedures_icd/

        Args:
            patients: a dict of `Patient` objects indexed by patient_id.

        Returns:
            The updated patients dict.

        Note:
            MIMIC-III does not provide specific timestamps in PROCEDURES_ICD
                table, so the events are created without a timestamp.
        """
        table = "PROCEDURES_ICD"
        self.code_vocs["procedures"] = "ICD9PROC"
        # read table
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            dtype={"SUBJECT_ID": str, "HADM_ID": str, "ICD9_CODE": str},
        )
        # drop records of the other patients
        df = df[df["SUBJECT_ID"].isin(patients.keys())]
        # drop rows with missing values
        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "SEQ_NUM", "ICD9_CODE"])
        # sort by sequence number (i.e., priority)
        df = df.sort_values(["SUBJECT_ID", "HADM_ID", "SEQ_NUM"], ascending=True)
        # group by patient (visits are split inside the per-patient unit)
        group_df = df.groupby("SUBJECT_ID")

        # parallel unit of procedure (per patient)
        def procedure_unit(p_id, p_info):
            events = []
            for v_id, v_info in p_info.groupby("HADM_ID"):
                # rows inside a visit keep the SEQ_NUM order from the sort above
                for code in v_info["ICD9_CODE"]:
                    event = Event(
                        code=code,
                        table=table,
                        vocabulary="ICD9PROC",
                        visit_id=v_id,
                        patient_id=p_id,
                    )
                    events.append(event)
            return events

        # parallel apply: yields one list of events per patient
        group_df = group_df.parallel_apply(
            lambda x: procedure_unit(x.SUBJECT_ID.unique()[0], x)
        )
        # summarize the results
        patients = self._add_events_to_patient_dict(patients, group_df)
        return patients

    def parse_prescriptions(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses PRESCRIPTIONS table.

        Will be called in `self.parse_tables()`

        Docs:
            - PRESCRIPTIONS: https://mimic.mit.edu/docs/iii/tables/prescriptions/

        Args:
            patients: a dict of `Patient` objects indexed by patient_id.

        Returns:
            The updated patients dict.
        """
        table = "PRESCRIPTIONS"
        self.code_vocs["drugs"] = "NDC"
        # read table
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            low_memory=False,
            dtype={"SUBJECT_ID": str, "HADM_ID": str, "NDC": str},
        )
        # drop records of the other patients
        df = df[df["SUBJECT_ID"].isin(patients.keys())]
        # drop rows with missing values
        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "NDC"])
        # sort by start date and end date
        df = df.sort_values(
            ["SUBJECT_ID", "HADM_ID", "STARTDATE", "ENDDATE"], ascending=True
        )
        # group by patient (visits are split inside the per-patient unit)
        group_df = df.groupby("SUBJECT_ID")

        # parallel unit for prescription (per patient)
        def prescription_unit(p_id, p_info):
            events = []
            for v_id, v_info in p_info.groupby("HADM_ID"):
                # the prescription start date is used as the event timestamp
                for timestamp, code in zip(v_info["STARTDATE"], v_info["NDC"]):
                    event = Event(
                        code=code,
                        table=table,
                        vocabulary="NDC",
                        visit_id=v_id,
                        patient_id=p_id,
                        timestamp=strptime(timestamp),
                    )
                    events.append(event)
            return events

        # parallel apply: yields one list of events per patient
        group_df = group_df.parallel_apply(
            lambda x: prescription_unit(x.SUBJECT_ID.unique()[0], x)
        )
        # summarize the results
        patients = self._add_events_to_patient_dict(patients, group_df)
        return patients

    def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses LABEVENTS table.

        Will be called in `self.parse_tables()`

        Docs:
            - LABEVENTS: https://mimic.mit.edu/docs/iii/tables/labevents/

        Args:
            patients: a dict of `Patient` objects indexed by patient_id.

        Returns:
            The updated patients dict.
        """
        table = "LABEVENTS"
        self.code_vocs["labs"] = "MIMIC3_ITEMID"
        # read table
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            dtype={"SUBJECT_ID": str, "HADM_ID": str, "ITEMID": str},
        )
        # drop records of the other patients
        df = df[df["SUBJECT_ID"].isin(patients.keys())]
        # drop rows with missing values
        df = df.dropna(subset=["SUBJECT_ID", "HADM_ID", "ITEMID"])
        # sort by charttime
        df = df.sort_values(["SUBJECT_ID", "HADM_ID", "CHARTTIME"], ascending=True)
        # group by patient (visits are split inside the per-patient unit)
        group_df = df.groupby("SUBJECT_ID")

        # parallel unit for lab (per patient)
        def lab_unit(p_id, p_info):
            events = []
            for v_id, v_info in p_info.groupby("HADM_ID"):
                # the chart time is used as the event timestamp
                for timestamp, code in zip(v_info["CHARTTIME"], v_info["ITEMID"]):
                    event = Event(
                        code=code,
                        table=table,
                        vocabulary="MIMIC3_ITEMID",
                        visit_id=v_id,
                        patient_id=p_id,
                        timestamp=strptime(timestamp),
                    )
                    events.append(event)
            return events

        # parallel apply: yields one list of events per patient
        group_df = group_df.parallel_apply(
            lambda x: lab_unit(x.SUBJECT_ID.unique()[0], x)
        )
        # summarize the results
        patients = self._add_events_to_patient_dict(patients, group_df)
        return patients
if __name__ == "__main__":
    # Manual smoke run: build the dataset from the hosted MIMIC-III demo files
    # in dev mode, forcing the cache to be rebuilt, then print its summaries.
    demo_tables = [
        "DIAGNOSES_ICD",
        "PROCEDURES_ICD",
        "PRESCRIPTIONS",
        "LABEVENTS",
    ]
    dataset = MIMIC3Dataset(
        root="https://storage.googleapis.com/pyhealth/mimiciii-demo/1.4/",
        tables=demo_tables,
        code_mapping={"NDC": "ATC"},
        dev=True,
        refresh_cache=True,
    )
    dataset.stat()
    dataset.info()

    # Alternative run against a local copy of the full dataset, kept for reference:
    # dataset = MIMIC3Dataset(
    #     root="/srv/local/data/physionet.org/files/mimiciii/1.4",
    #     tables=["DIAGNOSES_ICD", "PRESCRIPTIONS"],
    #     dev=True,
    #     code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
    #     refresh_cache=False,
    # )
    # print(dataset.stat())
    # print(dataset.available_tables)
    # print(list(dataset.patients.values())[4])