Source code for pyhealth.tasks.in_hospital_mortality_mimic4

from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, List

import polars as pl

from .base_task import BaseTask


[docs]class InHospitalMortalityMIMIC4(BaseTask): """Task for predicting in-hospital mortality using MIMIC-IV dataset. This task leverages lab results to predict the likelihood of in-hospital mortality. Attributes: task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for input data, which includes: - labs: A timeseries of lab results. output_schema (Dict[str, str]): The schema for output data, which includes: - mortality: A binary indicator of mortality. Examples: >>> from pyhealth.datasets import MIMIC4EHRDataset >>> from pyhealth.tasks import InHospitalMortalityMIMIC4 >>> dataset = MIMIC4EHRDataset( ... root="/path/to/mimic-iv/2.2", ... tables=["labevents"], ... ) >>> task = InHospitalMortalityMIMIC4() >>> samples = dataset.set_task(task) """ task_name: str = "InHospitalMortalityMIMIC4" input_schema: Dict[str, str] = {"labs": "timeseries"} output_schema: Dict[str, str] = {"mortality": "binary"} # Organize lab items by category LAB_CATEGORIES: ClassVar[Dict[str, Dict[str, List[str]]]] = { "Electrolytes & Metabolic": { "Sodium": ["50824", "52455", "50983", "52623"], "Potassium": ["50822", "52452", "50971", "52610"], "Chloride": ["50806", "52434", "50902", "52535"], "Bicarbonate": ["50803", "50804"], "Glucose": ["50809", "52027", "50931", "52569"], "Calcium": ["50808", "51624"], "Magnesium": ["50960"], "Anion Gap": ["50868", "52500"], "Osmolality": ["52031", "50964", "51701"], "Phosphate": ["50970"], }, } # Create flat list of all lab items for use in the function LABITEMS: ClassVar[List[str]] = [ item for category in LAB_CATEGORIES.values() for subcategory in category.values() for item in subcategory ] def __call__(self, patient: Any) -> List[Dict[str, Any]]: input_window_hours = 48 samples = [] demographics = patient.get_events(event_type="patients") assert len(demographics) == 1 demographics = demographics[0] anchor_age = int(demographics.anchor_age) if anchor_age < 18: return [] admissions = patient.get_events(event_type="admissions") for admission in admissions: admission_dischtime = datetime.strptime( admission.dischtime, "%Y-%m-%d %H:%M:%S" ) duration_hour = ( admission_dischtime - admission.timestamp ).total_seconds() / 3600 if duration_hour <= input_window_hours: continue predict_time = admission.timestamp + timedelta(hours=input_window_hours) labevents_df = patient.get_events( event_type="labevents", start=admission.timestamp, end=predict_time, return_df=True, ) labevents_df = labevents_df.filter( pl.col("labevents/itemid").is_in(self.LABITEMS) ) labevents_df = labevents_df.with_columns( pl.col("labevents/storetime").str.strptime( pl.Datetime, "%Y-%m-%d %H:%M:%S" ) ) labevents_df = labevents_df.filter( (pl.col("labevents/storetime") <= predict_time) ) if labevents_df.height == 0: continue labevents_df = labevents_df.select( pl.col("timestamp"), pl.col("labevents/itemid"), pl.col("labevents/valuenum").cast(pl.Float64), ) labevents_df = labevents_df.pivot( index="timestamp", columns="labevents/itemid", values="labevents/valuenum", # in case of multiple values for the same timestamp aggregate_function="first", ) labevents_df = labevents_df.sort("timestamp") # Add missing columns with NaN values existing_cols = set(labevents_df.columns) - {"timestamp"} missing_cols = [item for item in self.LABITEMS if item not in existing_cols] for col in missing_cols: labevents_df = labevents_df.with_columns(pl.lit(None).alias(col)) # Reorder columns by LABITEMS labevents_df = labevents_df.select("timestamp", *self.LABITEMS) timestamps = labevents_df["timestamp"].to_list() lab_values = labevents_df.drop("timestamp").to_numpy() mortality = int(admission.hospital_expire_flag) samples.append( { "patient_id": patient.patient_id, "admission_id": admission.hadm_id, "labs": (timestamps, lab_values), "mortality": mortality, } ) return samples