import os
from typing import Optional, List, Dict, Union, Tuple
import pandas as pd
from pyhealth.data import Event, Visit, Patient
from pyhealth.datasets import BaseEHRDataset
from pyhealth.datasets.utils import strptime
# TODO: add other tables
[docs]class MIMIC4Dataset(BaseEHRDataset):
"""Base dataset for MIMIC-IV dataset.
The MIMIC-IV dataset is a large dataset of de-identified health records of ICU
patients. The dataset is available at https://mimic.physionet.org/.
The basic information is stored in the following tables:
- patients: defines a patient in the database, subject_id.
- admission: define a patient's hospital admission, hadm_id.
We further support the following tables:
- diagnoses_icd: contains ICD diagnoses (ICD9CM and ICD10CM code)
for patients.
- procedures_icd: contains ICD procedures (ICD9PROC and ICD10PROC
code) for patients.
- prescriptions: contains medication related order entries (NDC code)
for patients.
- labevents: contains laboratory measurements (MIMIC4_ITEMID code)
for patients
Args:
dataset_name: name of the dataset.
root: root directory of the raw data (should contain many csv files).
tables: list of tables to be loaded (e.g., ["DIAGNOSES_ICD", "PROCEDURES_ICD"]).
code_mapping: a dictionary containing the code mapping information.
The key is a str of the source code vocabulary and the value is of
two formats:
(1) a str of the target code vocabulary;
(2) a tuple with two elements. The first element is a str of the
target code vocabulary and the second element is a dict with
keys "source_kwargs" or "target_kwargs" and values of the
corresponding kwargs for the `CrossMap.map()` method.
Default is empty dict, which means the original code will be used.
dev: whether to enable dev mode (only use a small subset of the data).
Default is False.
refresh_cache: whether to refresh the cache; if true, the dataset will
be processed from scratch and the cache will be updated. Default is False.
Attributes:
task: Optional[str], name of the task (e.g., "mortality prediction").
Default is None.
samples: Optional[List[Dict]], a list of samples, each sample is a dict with
patient_id, visit_id, and other task-specific attributes as key.
Default is None.
patient_to_index: Optional[Dict[str, List[int]]], a dict mapping patient_id to
a list of sample indices. Default is None.
visit_to_index: Optional[Dict[str, List[int]]], a dict mapping visit_id to a
list of sample indices. Default is None.
Examples:
>>> from pyhealth.datasets import MIMIC4Dataset
>>> dataset = MIMIC4Dataset(
... root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
... tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
... code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
... )
>>> dataset.stat()
>>> dataset.info()
"""
[docs] def parse_basic_info(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper functions which parses patients and admissions tables.
Will be called in `self.parse_tables()`
Docs:
- patients:https://mimic.mit.edu/docs/iv/modules/hosp/patients/
- admissions: https://mimic.mit.edu/docs/iv/modules/hosp/admissions/
Args:
patients: a dict of `Patient` objects indexed by patient_id.
Returns:
The updated patients dict.
"""
# read patients table
patients_df = pd.read_csv(
os.path.join(self.root, "patients.csv"),
dtype={"subject_id": str},
nrows=1000 if self.dev else None,
)
# read admissions table
admissions_df = pd.read_csv(
os.path.join(self.root, "admissions.csv"),
dtype={"subject_id": str, "hadm_id": str},
)
# merge patients and admissions tables
df = pd.merge(patients_df, admissions_df, on="subject_id", how="inner")
# sort by admission and discharge time
df = df.sort_values(["subject_id", "admittime", "dischtime"], ascending=True)
# group by patient
df_group = df.groupby("subject_id")
# parallel unit of basic information (per patient)
def basic_unit(p_id, p_info):
# no exact birth datetime in MIMIC-IV
# use anchor_year and anchor_age to approximate birth datetime
anchor_year = int(p_info["anchor_year"].values[0])
anchor_age = int(p_info["anchor_age"].values[0])
birth_year = anchor_year - anchor_age
patient = Patient(
patient_id=p_id,
# no exact month, day, and time, use Jan 1st, 00:00:00
birth_datetime=strptime(str(birth_year)),
# no exact time, use 00:00:00
death_datetime=strptime(p_info["dod"].values[0]),
gender=p_info["gender"].values[0],
ethnicity=p_info["race"].values[0],
anchor_year_group=p_info["anchor_year_group"].values[0],
)
# load visits
for v_id, v_info in p_info.groupby("hadm_id"):
visit = Visit(
visit_id=v_id,
patient_id=p_id,
encounter_time=strptime(v_info["admittime"].values[0]),
discharge_time=strptime(v_info["dischtime"].values[0]),
discharge_status=v_info["hospital_expire_flag"].values[0],
)
# add visit
patient.add_visit(visit)
return patient
# parallel apply
df_group = df_group.parallel_apply(
lambda x: basic_unit(x.subject_id.unique()[0], x)
)
# summarize the results
for pat_id, pat in df_group.items():
patients[pat_id] = pat
return patients
[docs] def parse_diagnoses_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper function which parses diagnosis_icd table.
Will be called in `self.parse_tables()`
Docs:
- diagnosis_icd: https://mimic.mit.edu/docs/iv/modules/hosp/diagnoses_icd/
Args:
patients: a dict of `Patient` objects indexed by patient_id.
Returns:
The updated patients dict.
Note:
MIMIC-IV does not provide specific timestamps in diagnoses_icd
table, so we set it to None.
"""
table = "diagnoses_icd"
# read table
df = pd.read_csv(
os.path.join(self.root, f"{table}.csv"),
dtype={"subject_id": str, "hadm_id": str, "icd_code": str},
)
# drop rows with missing values
df = df.dropna(subset=["subject_id", "hadm_id", "icd_code", "icd_version"])
# sort by sequence number (i.e., priority)
df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True)
# group by patient and visit
group_df = df.groupby("subject_id")
# parallel unit of diagnosis (per patient)
def diagnosis_unit(p_id, p_info):
events = []
# iterate over each patient and visit
for v_id, v_info in p_info.groupby("hadm_id"):
for code, version in zip(v_info["icd_code"], v_info["icd_version"]):
event = Event(
code=code,
table=table,
vocabulary=f"ICD{version}CM",
visit_id=v_id,
patient_id=p_id,
)
events.append(event)
return events
# parallel apply
group_df = group_df.parallel_apply(
lambda x: diagnosis_unit(x.subject_id.unique()[0], x)
)
# summarize the results
patients = self._add_events_to_patient_dict(patients, group_df)
return patients
[docs] def parse_procedures_icd(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper function which parses procedures_icd table.
Will be called in `self.parse_tables()`
Docs:
- procedures_icd: https://mimic.mit.edu/docs/iv/modules/hosp/procedures_icd/
Args:
patients: a dict of `Patient` objects indexed by patient_id.
Returns:
The updated patients dict.
Note:
MIMIC-IV does not provide specific timestamps in procedures_icd
table, so we set it to None.
"""
table = "procedures_icd"
# read table
df = pd.read_csv(
os.path.join(self.root, f"{table}.csv"),
dtype={"subject_id": str, "hadm_id": str, "icd_code": str},
)
# drop rows with missing values
df = df.dropna(subset=["subject_id", "hadm_id", "icd_code", "icd_version"])
# sort by sequence number (i.e., priority)
df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True)
# group by patient and visit
group_df = df.groupby("subject_id")
# parallel unit of procedure (per patient)
def procedure_unit(p_id, p_info):
events = []
for v_id, v_info in p_info.groupby("hadm_id"):
for code, version in zip(v_info["icd_code"], v_info["icd_version"]):
event = Event(
code=code,
table=table,
vocabulary=f"ICD{version}PROC",
visit_id=v_id,
patient_id=p_id,
)
# update patients
events.append(event)
return events
# parallel apply
group_df = group_df.parallel_apply(
lambda x: procedure_unit(x.subject_id.unique()[0], x)
)
# summarize the results
patients = self._add_events_to_patient_dict(patients, group_df)
return patients
[docs] def parse_prescriptions(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper function which parses prescriptions table.
Will be called in `self.parse_tables()`
Docs:
- prescriptions: https://mimic.mit.edu/docs/iv/modules/hosp/prescriptions/
Args:
patients: a dict of `Patient` objects indexed by patient_id.
Returns:
The updated patients dict.
"""
table = "prescriptions"
# read table
df = pd.read_csv(
os.path.join(self.root, f"{table}.csv"),
low_memory=False,
dtype={"subject_id": str, "hadm_id": str, "ndc": str},
)
# drop rows with missing values
df = df.dropna(subset=["subject_id", "hadm_id", "ndc"])
# sort by start date and end date
df = df.sort_values(
["subject_id", "hadm_id", "starttime", "stoptime"], ascending=True
)
# group by patient and visit
group_df = df.groupby("subject_id")
# parallel unit of prescription (per patient)
def prescription_unit(p_id, p_info):
events = []
for v_id, v_info in p_info.groupby("hadm_id"):
for timestamp, code in zip(v_info["starttime"], v_info["ndc"]):
event = Event(
code=code,
table=table,
vocabulary="NDC",
visit_id=v_id,
patient_id=p_id,
timestamp=strptime(timestamp),
)
# update patients
events.append(event)
return events
# parallel apply
group_df = group_df.parallel_apply(
lambda x: prescription_unit(x.subject_id.unique()[0], x)
)
# summarize the results
patients = self._add_events_to_patient_dict(patients, group_df)
return patients
[docs] def parse_labevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper function which parses labevents table.
Will be called in `self.parse_tables()`
Docs:
- labevents: https://mimic.mit.edu/docs/iv/modules/hosp/labevents/
Args:
patients: a dict of `Patient` objects indexed by patient_id.
Returns:
The updated patients dict.
"""
table = "labevents"
# read table
df = pd.read_csv(
os.path.join(self.root, f"{table}.csv"),
dtype={"subject_id": str, "hadm_id": str, "itemid": str},
)
# drop rows with missing values
df = df.dropna(subset=["subject_id", "hadm_id", "itemid"])
# sort by charttime
df = df.sort_values(["subject_id", "hadm_id", "charttime"], ascending=True)
# group by patient and visit
group_df = df.groupby("subject_id")
# parallel unit of labevent (per patient)
def lab_unit(p_id, p_info):
events = []
for v_id, v_info in p_info.groupby("hadm_id"):
for timestamp, code in zip(v_info["charttime"], v_info["itemid"]):
event = Event(
code=code,
table=table,
vocabulary="MIMIC4_ITEMID",
visit_id=v_id,
patient_id=p_id,
timestamp=strptime(timestamp),
)
events.append(event)
return events
# parallel apply
group_df = group_df.parallel_apply(
lambda x: lab_unit(x.subject_id.unique()[0], x)
)
# summarize the results
patients = self._add_events_to_patient_dict(patients, group_df)
return patients
[docs] def parse_hcpcsevents(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
"""Helper function which parses hcpcsevents table.
Will be called in `self.parse_tables()`
Docs:
- hcpcsevents: https://mimic.mit.edu/docs/iv/modules/hosp/hcpcsevents/
Args:
patients: a dict of `Patient` objects indexed by patient_id.
Returns:
The updated patients dict.
Note:
MIMIC-IV does not provide specific timestamps in hcpcsevents
table, so we set it to None.
"""
table = "hcpcsevents"
# read table
df = pd.read_csv(
os.path.join(self.root, f"{table}.csv"),
dtype={"subject_id": str, "hadm_id": str, "hcpcs_cd": str},
)
# drop rows with missing values
df = df.dropna(subset=["subject_id", "hadm_id", "hcpcs_cd"])
# sort by sequence number (i.e., priority)
df = df.sort_values(["subject_id", "hadm_id", "seq_num"], ascending=True)
# group by patient and visit
group_df = df.groupby("subject_id")
# parallel unit of hcpcsevents (per patient)
def hcpcsevents_unit(p_id, p_info):
events = []
for v_id, v_info in p_info.groupby("hadm_id"):
for code in v_info["hcpcs_cd"]:
event = Event(
code=code,
table=table,
vocabulary="MIMIC4_HCPCS_CD",
visit_id=v_id,
patient_id=p_id,
)
# update patients
events.append(event)
return events
# parallel apply
group_df = group_df.parallel_apply(
lambda x: hcpcsevents_unit(x.subject_id.unique()[0], x)
)
# summarize the results
patients = self._add_events_to_patient_dict(patients, group_df)
return patients
if __name__ == "__main__":
dataset = MIMIC4Dataset(
root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents", "hcpcsevents"],
code_mapping={"NDC": "ATC"},
refresh_cache=False,
)
dataset.stat()
dataset.info()