# Description: DKA (Diabetic Ketoacidosis) prediction tasks for MIMIC-IV dataset
import math
from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple
import polars as pl
from .base_task import BaseTask
class DKAPredictionMIMIC4(BaseTask):
    """Task for predicting Diabetic Ketoacidosis (DKA) in the general patient population.

    This task creates PATIENT-LEVEL samples from ALL patients in the dataset,
    predicting whether they will develop DKA. Features are collected from
    admissions BEFORE the first DKA event to prevent data leakage.

    Target Population:
        - ALL patients in the dataset (no filtering)
        - Large pool of negative samples (patients without DKA)

    Label Definition:
        - Positive (1): A DKA diagnosis code (ICD-9 ``250.1x`` or ICD-10
          ``E10.1x``/``E11.1x``/``E13.1x``) is found in one of the scanned
          admissions.
        - Negative (0): No DKA diagnosis code is found.

    Data Leakage Prevention:
        - Admissions are sorted chronologically.
        - For DKA-positive patients: only data from admissions BEFORE the
          first DKA admission is included (nothing from the DKA admission or
          after).
        - For DKA-negative patients: all admissions are included.
        - Patients with no diagnosis/procedure codes before the first DKA
          admission (e.g. DKA on the first admission) produce no sample.
        - Admissions whose discharge time is present but cannot be parsed are
          skipped entirely (neither collected nor scanned for DKA).

    Features:
        - icd_codes: per-admission lists of diagnosis (``D_``-prefixed) and
          procedure (``P_``-prefixed) ICD codes, paired with the hours elapsed
          since the previous admission (0.0 for the first), stagenet format.
        - labs: 10-dimensional vectors holding the mean value of each lab
          category observed in an admission (NaN where unobserved).

    Args:
        padding: Additional padding for StageNet processor. Default: 0.

    Example:
        >>> from pyhealth.datasets import MIMIC4Dataset
        >>> from pyhealth.tasks import DKAPredictionMIMIC4
        >>>
        >>> dataset = MIMIC4Dataset(
        ...     root="/path/to/mimic4",
        ...     tables=["diagnoses_icd", "procedures_icd", "labevents", "admissions"],
        ... )
        >>> task = DKAPredictionMIMIC4()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "DKAPredictionMIMIC4"

    # ICD-9 codes for Diabetic Ketoacidosis (250.1x)
    DKA_ICD9_CODES: ClassVar[Set[str]] = {"25010", "25011", "25012", "25013"}
    # ICD-10 prefixes for DKA (E10.1x, E11.1x, E13.1x cover T1D, T2D, other DKA)
    DKA_ICD10_PREFIXES: ClassVar[List[str]] = ["E101", "E111", "E131"]

    # Lab categories from mortality_prediction_stagenet_mimic4.py (verified item IDs)
    LAB_CATEGORIES: ClassVar[Dict[str, List[str]]] = {
        "Sodium": ["50824", "52455", "50983", "52623"],
        "Potassium": ["50822", "52452", "50971", "52610"],
        "Chloride": ["50806", "52434", "50902", "52535"],
        "Bicarbonate": ["50803", "50804"],
        "Glucose": ["50809", "52027", "50931", "52569"],
        "Calcium": ["50808", "51624"],
        "Magnesium": ["50960"],
        "Anion Gap": ["50868", "52500"],
        "Osmolality": ["52031", "50964", "51701"],
        "Phosphate": ["50970"],
    }
    # Fixed order of the lab vector's dimensions
    LAB_CATEGORY_ORDER: ClassVar[List[str]] = [
        "Sodium", "Potassium", "Chloride", "Bicarbonate", "Glucose",
        "Calcium", "Magnesium", "Anion Gap", "Osmolality", "Phosphate",
    ]
    # Flat list of all lab item IDs for filtering
    LABITEMS: ClassVar[List[str]] = [
        item for items in LAB_CATEGORIES.values() for item in items
    ]

    def __init__(self, padding: int = 0):
        """Initialize task with optional padding.

        Args:
            padding: Additional padding for nested sequences. Default: 0.
        """
        self.padding = padding
        self.input_schema: Dict[str, Tuple[str, Dict[str, Any]]] = {  # type: ignore
            "icd_codes": ("stagenet", {"padding": padding}),
            "labs": ("stagenet_tensor", {}),
        }
        self.output_schema: Dict[str, str] = {"label": "binary"}  # type: ignore

    def _is_dka_code(self, code: str, version: Any) -> bool:
        """Return True if an ICD code represents Diabetic Ketoacidosis.

        Args:
            code: Raw ICD code; dots and surrounding whitespace are ignored.
            version: ICD version; compared as a string against "9" / "10".

        Returns:
            True for an ICD-10 code starting with a DKA prefix or an ICD-9 code
            in the DKA set; False for empty codes or any other version.
        """
        if not code:
            return False
        normalized = code.replace(".", "").strip().upper()
        version_str = "" if version is None else str(version)
        if version_str == "10":
            return normalized.startswith(tuple(self.DKA_ICD10_PREFIXES))
        if version_str == "9":
            return normalized in self.DKA_ICD9_CODES
        return False

    def _build_lab_vector(self, lab_df: pl.DataFrame) -> List[float]:
        """Build a 10D lab feature vector from a lab events DataFrame.

        Each entry is the mean ``labevents/valuenum`` of the category's item
        IDs, in ``LAB_CATEGORY_ORDER``; categories with no values are NaN.
        """
        if lab_df.height == 0:
            return [math.nan] * len(self.LAB_CATEGORY_ORDER)

        # Normalize dtypes (item IDs are compared as strings) and keep only
        # relevant, non-null measurements.
        filtered = (
            lab_df.with_columns([
                pl.col("labevents/itemid").cast(pl.Utf8),
                pl.col("labevents/valuenum").cast(pl.Float64),
            ])
            .filter(pl.col("labevents/itemid").is_in(self.LABITEMS))
            .filter(pl.col("labevents/valuenum").is_not_null())
        )
        if filtered.height == 0:
            return [math.nan] * len(self.LAB_CATEGORY_ORDER)

        vector: List[float] = []
        for category in self.LAB_CATEGORY_ORDER:
            itemids = self.LAB_CATEGORIES[category]
            cat_df = filtered.filter(pl.col("labevents/itemid").is_in(itemids))
            if cat_df.height > 0:
                values = cat_df["labevents/valuenum"].drop_nulls()
                vector.append(float(values.mean()) if len(values) > 0 else math.nan)  # type: ignore
            else:
                vector.append(math.nan)
        return vector

    @staticmethod
    def _has_valid_dischtime(admission: Any) -> bool:
        """Return False only when the admission's discharge time is malformed.

        A missing/empty discharge time or an already-parsed ``datetime`` is
        accepted; any other value must parse as ``"%Y-%m-%d %H:%M:%S"``.
        The parsed value itself is not needed by this task.
        """
        dischtime = getattr(admission, "dischtime", None)
        if not dischtime or isinstance(dischtime, datetime):
            return True
        try:
            datetime.strptime(str(dischtime), "%Y-%m-%d %H:%M:%S")
        except ValueError:
            return False
        return True

    def _diagnosis_codes_before_dka(
        self, patient: Any, hadm_id: Any
    ) -> Tuple[List[str], bool]:
        """Collect de-duplicated ``D_`` codes of one admission up to the first DKA code.

        Returns:
            ``(codes, found_dka)``; when ``found_dka`` is True the codes are
            incomplete and the caller discards them.
        """
        diagnoses = patient.get_events(
            event_type="diagnoses_icd",
            filters=[("hadm_id", "==", hadm_id)],
        )
        codes: List[str] = []
        seen: Set[str] = set()
        for diag in diagnoses:
            code = getattr(diag, "icd_code", None)
            if not code:
                continue
            if self._is_dka_code(code, getattr(diag, "icd_version", None)):
                return codes, True
            normalized = f"D_{code.replace('.', '').upper()}"
            if normalized not in seen:
                seen.add(normalized)
                codes.append(normalized)
        return codes, False

    @staticmethod
    def _procedure_codes(patient: Any, hadm_id: Any) -> List[str]:
        """Return the de-duplicated ``P_`` procedure codes of one admission, in order."""
        procedures = patient.get_events(
            event_type="procedures_icd",
            filters=[("hadm_id", "==", hadm_id)],
        )
        raw_codes = (getattr(proc, "icd_code", None) for proc in procedures)
        # dict.fromkeys keeps first-seen order while removing duplicates.
        return list(dict.fromkeys(
            f"P_{code.replace('.', '').upper()}" for code in raw_codes if code
        ))

    def _admission_lab_vector(
        self, patient: Any, hadm_id: Any
    ) -> Optional[List[float]]:
        """Return the lab vector of one admission, or None if it has no lab events."""
        lab_df = patient.get_events(
            event_type="labevents",
            filters=[("hadm_id", "==", hadm_id)],
            return_df=True,
        )
        if lab_df.height == 0:
            return None
        return self._build_lab_vector(lab_df)

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a patient to create a DKA prediction sample.

        Iterates through chronologically sorted admissions, collecting features
        until the first admission containing a DKA code. The label is whether
        such an admission was found.

        Args:
            patient: Patient object with a ``get_events`` method.

        Returns:
            List with a single sample, or an empty list if the patient has no
            admissions or no codes before the first DKA admission.
        """
        admissions = patient.get_events(event_type="admissions")
        if not admissions:
            return []
        admissions = sorted(admissions, key=lambda adm: adm.timestamp)

        all_icd_codes: List[List[str]] = []
        all_icd_times: List[float] = []
        all_lab_values: List[List[float]] = []
        all_lab_times: List[float] = []
        previous_admission_time: Optional[datetime] = None
        has_dka = False

        for admission in admissions:
            if not self._has_valid_dischtime(admission):
                continue
            admission_time = admission.timestamp

            visit_codes, found_dka = self._diagnosis_codes_before_dka(
                patient, admission.hadm_id
            )
            if found_dka:
                # Nothing from the DKA admission (or later) enters the features.
                has_dka = True
                break
            visit_codes.extend(self._procedure_codes(patient, admission.hadm_id))

            # Hours since the previous processed admission (0.0 for the first).
            if previous_admission_time is None:
                time_from_previous = 0.0
            else:
                time_from_previous = (
                    admission_time - previous_admission_time
                ).total_seconds() / 3600.0
            previous_admission_time = admission_time

            if visit_codes:
                all_icd_codes.append(visit_codes)
                all_icd_times.append(time_from_previous)

            lab_vector = self._admission_lab_vector(patient, admission.hadm_id)
            if lab_vector is not None:
                all_lab_values.append(lab_vector)
                all_lab_times.append(time_from_previous)

        # Skip if no pre-DKA data (DKA on first visit or no valid admissions)
        if not all_icd_codes:
            return []

        # The labs feature must be non-empty: fall back to a single NaN vector.
        if not all_lab_values:
            all_lab_values = [[math.nan] * len(self.LAB_CATEGORY_ORDER)]
            all_lab_times = [0.0]

        return [{
            "patient_id": patient.patient_id,
            "record_id": patient.patient_id,
            "icd_codes": (all_icd_times, all_icd_codes),
            "labs": (all_lab_times, all_lab_values),
            "label": int(has_dka),
        }]
class T1DDKAPredictionMIMIC4(BaseTask):
    """Task for predicting Diabetic Ketoacidosis (DKA) in Type 1 Diabetes patients.

    This task creates PATIENT-LEVEL samples by identifying patients with Type 1
    Diabetes Mellitus (T1DM) and predicting whether they will develop DKA within
    a specified time window. Features are collected from admissions BEFORE the
    first DKA event to prevent data leakage.

    Target Population:
        - Patients with Type 1 Diabetes (ICD-9 or ICD-10 codes)
        - Excludes patients without any T1DM diagnosis codes

    Window:
        When T1DM diagnoses carry timestamps, the window ends at the earliest
        T1DM diagnosis time + ``dka_window_days``. Without timestamps there is
        no window.

    Label Definition:
        - Positive (1): The first DKA code, found no later than the window
          end, lies within ``dka_window_days`` whole days (before or after) of
          some T1DM diagnosis time. If no T1DM timestamps exist, any DKA code
          counts.
        - Negative (0): Otherwise (including a first DKA code after the window
          end).

    Data Leakage Prevention:
        - Admissions are sorted chronologically.
        - For every patient, admissions starting after the window end are
          ignored.
        - For DKA-positive patients: only admissions BEFORE the first DKA
          admission are included (nothing from the DKA admission or after).
        - For DKA-negative patients: admissions up to the window end are
          included, stopping early at an admission whose DKA code falls after
          the window end.
        - Patients with no codes before the first DKA admission (e.g. DKA on
          the first admission) produce no sample.
        - Admissions whose discharge time is present but cannot be parsed are
          skipped entirely.

    Features:
        - icd_codes: per-admission lists of diagnosis (``D_``-prefixed) and
          procedure (``P_``-prefixed) ICD codes, paired with the hours elapsed
          since the previous admission (0.0 for the first), stagenet format.
        - labs: 10-dimensional vectors holding the mean value of each lab
          category observed in an admission (NaN where unobserved).

    Args:
        dka_window_days: Number of days to consider for DKA occurrence after
            T1DM diagnosis. Default: 90.
        padding: Additional padding for StageNet processor. Default: 0.

    Example:
        >>> from pyhealth.datasets import MIMIC4Dataset
        >>> from pyhealth.tasks import T1DDKAPredictionMIMIC4
        >>>
        >>> dataset = MIMIC4Dataset(
        ...     root="/path/to/mimic4",
        ...     tables=["diagnoses_icd", "procedures_icd", "labevents", "admissions"],
        ... )
        >>> task = T1DDKAPredictionMIMIC4(dka_window_days=90)
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "T1DDKAPredictionMIMIC4"

    # ICD-10 prefix for Type 1 Diabetes Mellitus
    T1DM_ICD10_PREFIX: ClassVar[str] = "E10"
    # ICD-9 codes for Type 1 Diabetes Mellitus (250.x1 / 250.x3)
    T1DM_ICD9_CODES: ClassVar[Set[str]] = {
        "25001", "25003", "25011", "25013", "25021", "25023",
        "25031", "25033", "25041", "25043", "25051", "25053",
        "25061", "25063", "25071", "25073", "25081", "25083",
        "25091", "25093",
    }
    # ICD-9 codes for Diabetic Ketoacidosis. Note 25011/25013 are also T1DM
    # codes, so a DKA admission can double as a T1DM "diagnosis" time.
    DKA_ICD9_CODES: ClassVar[Set[str]] = {"25010", "25011", "25012", "25013"}
    # ICD-10 prefix for DKA (E10.1x codes; these also match the T1DM prefix)
    DKA_ICD10_PREFIX: ClassVar[str] = "E101"

    # Lab categories from mortality_prediction_stagenet_mimic4.py
    LAB_CATEGORIES: ClassVar[Dict[str, List[str]]] = {
        "Sodium": ["50824", "52455", "50983", "52623"],
        "Potassium": ["50822", "52452", "50971", "52610"],
        "Chloride": ["50806", "52434", "50902", "52535"],
        "Bicarbonate": ["50803", "50804"],
        "Glucose": ["50809", "52027", "50931", "52569"],
        "Calcium": ["50808", "51624"],
        "Magnesium": ["50960"],
        "Anion Gap": ["50868", "52500"],
        "Osmolality": ["52031", "50964", "51701"],
        "Phosphate": ["50970"],
    }
    # Fixed order of the lab vector's dimensions
    LAB_CATEGORY_ORDER: ClassVar[List[str]] = [
        "Sodium", "Potassium", "Chloride", "Bicarbonate", "Glucose",
        "Calcium", "Magnesium", "Anion Gap", "Osmolality", "Phosphate",
    ]
    # Flat list of all lab item IDs for filtering
    LABITEMS: ClassVar[List[str]] = [
        item for items in LAB_CATEGORIES.values() for item in items
    ]

    def __init__(self, dka_window_days: int = 90, padding: int = 0):
        """Initialize task with configurable DKA window and padding.

        Args:
            dka_window_days: Window length in days (see class docstring).
            padding: Additional padding for nested sequences. Default: 0.
        """
        self.dka_window_days = dka_window_days
        self.padding = padding
        self.input_schema: Dict[str, Tuple[str, Dict[str, Any]]] = {  # type: ignore
            "icd_codes": ("stagenet", {"padding": padding}),
            "labs": ("stagenet_tensor", {}),
        }
        self.output_schema: Dict[str, str] = {"label": "binary"}  # type: ignore

    def _is_t1dm_code(self, code: Optional[str], version: Any) -> bool:
        """Return True if an ICD code represents Type 1 Diabetes Mellitus.

        Dots and surrounding whitespace are ignored; ``version`` is compared
        as a string against "9" / "10" and any other version returns False.
        """
        if not code:
            return False
        normalized = code.replace(".", "").strip().upper()
        version_str = "" if version is None else str(version)
        if version_str == "10":
            return normalized.startswith(self.T1DM_ICD10_PREFIX)
        if version_str == "9":
            return normalized in self.T1DM_ICD9_CODES
        return False

    def _is_dka_code(self, code: str, version: Any) -> bool:
        """Return True if an ICD code represents Diabetic Ketoacidosis.

        Same normalization and version handling as ``_is_t1dm_code``.
        """
        if not code:
            return False
        normalized = code.replace(".", "").strip().upper()
        version_str = "" if version is None else str(version)
        if version_str == "10":
            return normalized.startswith(self.DKA_ICD10_PREFIX)
        if version_str == "9":
            return normalized in self.DKA_ICD9_CODES
        return False

    def _build_lab_vector(self, lab_df: pl.DataFrame) -> List[float]:
        """Build a 10D lab feature vector from a lab events DataFrame.

        Each entry is the mean ``labevents/valuenum`` of the category's item
        IDs, in ``LAB_CATEGORY_ORDER``; categories with no values are NaN.
        """
        if lab_df.height == 0:
            return [math.nan] * len(self.LAB_CATEGORY_ORDER)

        # Normalize dtypes (item IDs are compared as strings) and keep only
        # relevant, non-null measurements.
        filtered = (
            lab_df.with_columns([
                pl.col("labevents/itemid").cast(pl.Utf8),
                pl.col("labevents/valuenum").cast(pl.Float64),
            ])
            .filter(pl.col("labevents/itemid").is_in(self.LABITEMS))
            .filter(pl.col("labevents/valuenum").is_not_null())
        )
        if filtered.height == 0:
            return [math.nan] * len(self.LAB_CATEGORY_ORDER)

        vector: List[float] = []
        for category in self.LAB_CATEGORY_ORDER:
            itemids = self.LAB_CATEGORIES[category]
            cat_df = filtered.filter(pl.col("labevents/itemid").is_in(itemids))
            if cat_df.height > 0:
                values = cat_df["labevents/valuenum"].drop_nulls()
                vector.append(float(values.mean()) if len(values) > 0 else math.nan)  # type: ignore
            else:
                vector.append(math.nan)
        return vector

    def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Keep only patients with at least one T1DM-looking diagnosis code.

        Codes are normalized like ``_is_t1dm_code`` (dots removed, whitespace
        stripped, upper-cased), so this filter does not drop patients that
        ``__call__`` would accept. The ICD version is not consulted here, so a
        few extra patients may be kept; ``__call__`` re-checks with the version.
        Whole patient histories are kept via a window over ``patient_id``
        (assumes the event frame has that column).
        """
        code_col = "diagnoses_icd/icd_code"
        if code_col not in df.collect_schema().names():
            return df

        normalized = (
            pl.col(code_col)
            .cast(pl.Utf8)
            .str.replace_all(".", "", literal=True)
            .str.strip_chars()
            .str.to_uppercase()
        )
        # Null for non-diagnosis rows -> treated as "not a T1DM code".
        is_t1dm_code = (
            normalized.str.starts_with(self.T1DM_ICD10_PREFIX)
            | normalized.is_in(list(self.T1DM_ICD9_CODES))
        ).fill_null(False)
        return (
            df.with_columns(is_t1dm_code.alias("__is_t1dm_code"))
            .with_columns(
                pl.col("__is_t1dm_code").any().over("patient_id").alias("__has_t1dm")
            )
            .filter(pl.col("__has_t1dm"))
            .drop(["__is_t1dm_code", "__has_t1dm"])
        )

    @staticmethod
    def _has_valid_dischtime(admission: Any) -> bool:
        """Return False only when the admission's discharge time is malformed.

        A missing/empty discharge time or an already-parsed ``datetime`` is
        accepted; any other value must parse as ``"%Y-%m-%d %H:%M:%S"``.
        The parsed value itself is not needed by this task.
        """
        dischtime = getattr(admission, "dischtime", None)
        if not dischtime or isinstance(dischtime, datetime):
            return True
        try:
            datetime.strptime(str(dischtime), "%Y-%m-%d %H:%M:%S")
        except ValueError:
            return False
        return True

    def _diagnosis_codes_before_dka(
        self, patient: Any, hadm_id: Any
    ) -> Tuple[List[str], Any]:
        """Collect de-duplicated ``D_`` codes of one admission up to the first DKA code.

        Returns:
            ``(codes, dka_diag)`` where ``dka_diag`` is the first DKA diagnosis
            event, or None if the admission has none.
        """
        diagnoses = patient.get_events(
            event_type="diagnoses_icd",
            filters=[("hadm_id", "==", hadm_id)],
        )
        codes: List[str] = []
        seen: Set[str] = set()
        for diag in diagnoses:
            code = getattr(diag, "icd_code", None)
            if not code:
                continue
            if self._is_dka_code(code, getattr(diag, "icd_version", None)):
                return codes, diag
            normalized = f"D_{code.replace('.', '').upper()}"
            if normalized not in seen:
                seen.add(normalized)
                codes.append(normalized)
        return codes, None

    @staticmethod
    def _procedure_codes(patient: Any, hadm_id: Any) -> List[str]:
        """Return the de-duplicated ``P_`` procedure codes of one admission, in order."""
        procedures = patient.get_events(
            event_type="procedures_icd",
            filters=[("hadm_id", "==", hadm_id)],
        )
        raw_codes = (getattr(proc, "icd_code", None) for proc in procedures)
        # dict.fromkeys keeps first-seen order while removing duplicates.
        return list(dict.fromkeys(
            f"P_{code.replace('.', '').upper()}" for code in raw_codes if code
        ))

    def _admission_lab_vector(
        self, patient: Any, hadm_id: Any
    ) -> Optional[List[float]]:
        """Return the lab vector of one admission, or None if it has no lab events."""
        lab_df = patient.get_events(
            event_type="labevents",
            filters=[("hadm_id", "==", hadm_id)],
            return_df=True,
        )
        if lab_df.height == 0:
            return None
        return self._build_lab_vector(lab_df)

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a patient to create a DKA prediction sample.

        First checks whether the patient has T1DM, then iterates through
        sorted admissions collecting features until the first DKA admission
        or the end of the window. The label is whether DKA occurs within the
        window of a T1DM diagnosis (see class docstring).

        Args:
            patient: Patient object with a ``get_events`` method.

        Returns:
            List with a single sample, or an empty list if the patient lacks
            T1DM, admissions, or pre-DKA codes.
        """
        # Quick T1DM scan before the per-admission queries.
        all_diagnoses = patient.get_events(event_type="diagnoses_icd")
        if not all_diagnoses:
            return []
        t1dm_diagnoses = [
            diag
            for diag in all_diagnoses
            if self._is_t1dm_code(
                getattr(diag, "icd_code", None), getattr(diag, "icd_version", None)
            )
        ]
        if not t1dm_diagnoses:
            return []
        t1dm_times: List[datetime] = [
            diag_time
            for diag_time in (getattr(diag, "timestamp", None) for diag in t1dm_diagnoses)
            if diag_time
        ]

        admissions = patient.get_events(event_type="admissions")
        if not admissions:
            return []
        admissions = sorted(admissions, key=lambda adm: adm.timestamp)

        # Window anchored on the earliest timestamped T1DM diagnosis.
        window_end: Optional[datetime] = (
            min(t1dm_times) + timedelta(days=self.dka_window_days)
            if t1dm_times
            else None
        )

        all_icd_codes: List[List[str]] = []
        all_icd_times: List[float] = []
        all_lab_values: List[List[float]] = []
        all_lab_times: List[float] = []
        previous_admission_time: Optional[datetime] = None
        has_dka = False
        dka_time: Optional[datetime] = None

        for admission in admissions:
            if not self._has_valid_dischtime(admission):
                continue
            admission_time = admission.timestamp

            # Stop once we are past the allowed window to avoid leakage from long histories
            if window_end is not None and admission_time > window_end:
                break

            visit_codes, dka_diag = self._diagnosis_codes_before_dka(
                patient, admission.hadm_id
            )
            if dka_diag is not None:
                # A missing/None diagnosis timestamp falls back to the
                # admission time instead of crashing the comparison below.
                candidate_dka_time = getattr(dka_diag, "timestamp", None) or admission_time
                # DKA after the window end stays negative; either way nothing
                # from this admission (or later) enters the features.
                if window_end is None or candidate_dka_time <= window_end:
                    has_dka = True
                    dka_time = candidate_dka_time
                break
            visit_codes.extend(self._procedure_codes(patient, admission.hadm_id))

            # Hours since the previous processed admission (0.0 for the first).
            if previous_admission_time is None:
                time_from_previous = 0.0
            else:
                time_from_previous = (
                    admission_time - previous_admission_time
                ).total_seconds() / 3600.0
            previous_admission_time = admission_time

            if visit_codes:
                all_icd_codes.append(visit_codes)
                all_icd_times.append(time_from_previous)

            lab_vector = self._admission_lab_vector(patient, admission.hadm_id)
            if lab_vector is not None:
                all_lab_values.append(lab_vector)
                all_lab_times.append(time_from_previous)

        # Skip if no pre-DKA data
        if not all_icd_codes:
            return []

        # Label: DKA within dka_window_days (whole days, either direction) of
        # any T1DM diagnosis; without T1DM timestamps any DKA counts.
        has_dka_within_window = False
        if has_dka and t1dm_times and dka_time is not None:
            has_dka_within_window = any(
                abs((dka_time - t1dm_time).days) <= self.dka_window_days
                for t1dm_time in t1dm_times
            )
        elif has_dka and not t1dm_times:
            has_dka_within_window = True

        # The labs feature must be non-empty: fall back to a single NaN vector.
        if not all_lab_values:
            all_lab_values = [[math.nan] * len(self.LAB_CATEGORY_ORDER)]
            all_lab_times = [0.0]

        return [{
            "patient_id": patient.patient_id,
            "record_id": patient.patient_id,
            "icd_codes": (all_icd_times, all_icd_codes),
            "labs": (all_lab_times, all_lab_values),
            "label": int(has_dka_within_window),
        }]