Source code for pyhealth.tasks.mortality_prediction

from datetime import datetime
from typing import Any, Dict, List, Optional

from .base_task import BaseTask


class MortalityPredictionMIMIC3(BaseTask):
    """Next-visit mortality prediction on MIMIC-III using coded EHR data.

    Every admission except the last one yields at most one sample. The
    features are the diagnosis, procedure and prescription codes recorded
    under that admission's ``hadm_id``; the label is the
    ``hospital_expire_flag`` of the admission that follows it.

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> from pyhealth.tasks import MortalityPredictionMIMIC3
        >>> dataset = MIMIC3Dataset(
        ...     root="/path/to/mimic-iii/1.4",
        ...     tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        ... )
        >>> task = MortalityPredictionMIMIC3()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MortalityPredictionMIMIC3"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build the mortality samples for one patient.

        Args:
            patient: Object exposing ``get_events(event_type=..., filters=...)``
                and ``patient_id``.

        Returns:
            One dict per retained admission with the keys ``hadm_id``,
            ``patient_id``, ``conditions``, ``procedures``, ``drugs`` and
            ``mortality``. Empty when the patient has fewer than two
            admissions.
        """
        visits = patient.get_events(event_type="admissions")
        # The final admission has no successor to take a label from.
        if len(visits) <= 1:
            return []

        def events_of(table: str, hadm_id: Any) -> Any:
            # hadm_id filtering is more precise than timestamp windows.
            return patient.get_events(
                event_type=table, filters=[("hadm_id", "==", hadm_id)]
            )

        samples = []
        for position in range(1, len(visits)):
            visit = visits[position - 1]
            next_visit = visits[position]

            # Anything that is not a recognisable 0/1 flag counts as survival.
            if next_visit.hospital_expire_flag in [0, 1, "0", "1"]:
                mortality_label = int(next_visit.hospital_expire_flag)
            else:
                mortality_label = 0

            diagnoses = events_of("diagnoses_icd", visit.hadm_id)
            procedures = events_of("procedures_icd", visit.hadm_id)
            prescriptions = events_of("prescriptions", visit.hadm_id)

            conditions = [event.icd9_code for event in diagnoses]
            procedures_list = [event.icd9_code for event in procedures]
            drugs = [event.ndc for event in prescriptions if event.ndc]

            # A visit needs at least one code of every kind to be usable.
            if not (conditions and procedures_list and drugs):
                continue

            samples.append(
                {
                    "hadm_id": visit.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures_list,
                    "drugs": drugs,
                    "mortality": mortality_label,
                }
            )
        return samples
class MultimodalMortalityPredictionMIMIC3(BaseTask):
    """Task for predicting mortality using MIMIC-III codes and clinical notes.

    This task aims to predict whether the patient will decease in the next
    hospital visit based on clinical information from the current visit.
    Unlike the code-only task, visits are kept even when some modality is
    empty (no codes or no notes).

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> from pyhealth.tasks import MultimodalMortalityPredictionMIMIC3
        >>> dataset = MIMIC3Dataset(
        ...     root="/path/to/mimic-iii/1.4",
        ...     tables=["diagnoses_icd", "procedures_icd", "prescriptions",
        ...             "noteevents"],
        ... )
        >>> task = MultimodalMortalityPredictionMIMIC3()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MultimodalMortalityPredictionMIMIC3"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
        "clinical_notes": "text",  # Added support for clinical notes
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Processes a single patient for the mortality prediction task.

        Args:
            patient: Object exposing ``get_events(event_type=..., filters=...)``
                and ``patient_id``.

        Returns:
            One dict per admission except the last, with the keys
            ``hadm_id``, ``patient_id``, ``conditions``, ``procedures``,
            ``drugs``, ``clinical_notes`` and ``mortality``. Empty when the
            patient has fewer than two admissions.
        """
        samples = []

        # We will drop the last visit
        visits = patient.get_events(event_type="admissions")
        if len(visits) <= 1:
            return []

        for i in range(len(visits) - 1):
            visit = visits[i]
            next_visit = visits[i + 1]

            # Check discharge status for mortality label - more robust handling
            if next_visit.hospital_expire_flag not in [0, 1, "0", "1"]:
                mortality_label = 0
            else:
                mortality_label = int(next_visit.hospital_expire_flag)

            # Get clinical codes using hadm_id-based filtering
            # (more precise than timestamp filtering)
            diagnoses = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", visit.hadm_id)],
            )
            procedures = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", visit.hadm_id)],
            )
            prescriptions = patient.get_events(
                event_type="prescriptions",
                filters=[("hadm_id", "==", visit.hadm_id)],
            )
            # Get clinical notes
            notes = patient.get_events(
                event_type="noteevents",
                filters=[("hadm_id", "==", visit.hadm_id)],
            )

            conditions = [event.icd9_code for event in diagnoses]
            procedures_list = [event.icd9_code for event in procedures]
            drugs = [event.ndc for event in prescriptions if event.ndc]

            # Concatenate all notes of the visit. Missing (None) or
            # non-string texts are skipped instead of raising TypeError, and
            # join avoids repeated string reallocation.
            note_texts = [note.text for note in notes]
            text = "".join(t for t in note_texts if isinstance(t, str))

            # Heterogeneous problem. Some events may not have notes,
            # procedures, prescriptions, or diagnoses.
            samples.append(
                {
                    "hadm_id": visit.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures_list,
                    "drugs": drugs,
                    "clinical_notes": text,
                    "mortality": mortality_label,
                }
            )

        return samples
class MortalityPredictionMIMIC4(BaseTask):
    """Task for predicting mortality using MIMIC-IV EHR data only.

    Each admission except the last yields at most one sample, labelled with
    the ``hospital_expire_flag`` of the following admission. Patients whose
    ``anchor_age`` parses to a value below 18 are skipped.

    Examples:
        >>> from pyhealth.datasets import MIMIC4EHRDataset
        >>> from pyhealth.tasks import MortalityPredictionMIMIC4
        >>> dataset = MIMIC4EHRDataset(
        ...     root="/path/to/mimic-iv/2.2",
        ...     tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        ... )
        >>> task = MortalityPredictionMIMIC4()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MortalityPredictionMIMIC4"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def _clean_sequence(self, sequence: Optional[List[Any]]) -> List[str]:
        """Clean a sequence of codes.

        Removes ``None`` values, converts the rest to stripped strings and
        drops entries that are empty after stripping.

        Args:
            sequence: Raw values, or ``None``.

        Returns:
            The cleaned list of strings (empty for ``None`` input).
        """
        if sequence is None:
            return []

        # Remove None, convert to strings, remove empty strings
        cleaned = [
            str(item).strip()
            for item in sequence
            if item is not None and str(item).strip()
        ]
        return cleaned

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Processes a single patient for the mortality prediction task.

        Args:
            patient: Object exposing ``get_events(...)`` and ``patient_id``.

        Returns:
            One dict per retained admission with the keys ``visit_id``,
            ``patient_id``, ``conditions``, ``procedures``, ``drugs`` and
            ``mortality``. Admissions whose discharge time cannot be parsed
            or that lack any of the three code types are skipped.
        """
        samples = []

        # Get demographic info to filter by age
        demographics = patient.get_events(event_type="patients")
        if not demographics:
            return []

        demographics = demographics[0]
        anchor_age = getattr(demographics, "anchor_age", None)

        # Safely check age - ages may be stored as non-numeric strings
        try:
            if anchor_age is not None and int(float(anchor_age)) < 18:
                return []  # Skip patients under 18
        except (ValueError, TypeError):
            # If age can't be determined, we'll include the patient
            pass

        # Get visits
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) <= 1:
            return []

        for i in range(len(admissions) - 1):
            admission = admissions[i]
            next_admission = admissions[i + 1]

            # Check discharge status for mortality label - more robust handling
            if next_admission.hospital_expire_flag not in [0, 1, "0", "1"]:
                mortality_label = 0
            else:
                mortality_label = int(next_admission.hospital_expire_flag)

            # Parse admission timestamps. The raw value is read once with a
            # default so that a missing attribute cannot raise again inside
            # the error handler, and TypeError is caught because strptime
            # rejects None / non-string input with TypeError, not ValueError.
            raw_dischtime = getattr(admission, "dischtime", None)
            try:
                admission_dischtime = datetime.strptime(
                    raw_dischtime, "%Y-%m-%d %H:%M:%S"
                )
            except (ValueError, TypeError):
                # If date parsing fails, skip this admission
                print("Error parsing admission discharge time:", raw_dischtime)
                continue

            # Get clinical codes
            diagnoses_icd = patient.get_events(
                event_type="diagnoses_icd",
                start=admission.timestamp,
                end=admission_dischtime,
            )
            procedures_icd = patient.get_events(
                event_type="procedures_icd",
                start=admission.timestamp,
                end=admission_dischtime,
            )
            prescriptions = patient.get_events(
                event_type="prescriptions",
                start=admission.timestamp,
                end=admission_dischtime,
            )

            # Extract relevant data
            conditions = self._clean_sequence(
                [getattr(event, "icd_code", None) for event in diagnoses_icd]
            )
            procedures_list = self._clean_sequence(
                [getattr(event, "icd_code", None) for event in procedures_icd]
            )
            drugs = self._clean_sequence(
                [getattr(event, "ndc", None) for event in prescriptions]
            )

            # Exclude visits without condition, procedure, or drug code
            if not (conditions and procedures_list and drugs):
                continue

            samples.append(
                {
                    "visit_id": admission.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures_list,
                    "drugs": drugs,
                    "mortality": mortality_label,
                }
            )

        return samples
class MultimodalMortalityPredictionMIMIC4(BaseTask):
    """Task for predicting patient-level mortality using MIMIC-IV multimodal data.

    This task combines multiple modalities for mortality prediction at the
    PATIENT LEVEL (not visit level). All core modalities are required for
    each sample; a patient missing any of them produces no sample.

    Required Modalities:
        - EHR codes: ICD diagnoses, procedures, AND prescriptions (at least
          one code of each kind across the processed admissions)
        - Clinical notes: Discharge summaries OR radiology reports (at least
          one)
        - Lab events: 10-dimensional lab value vectors (at least one time
          point)
        - Chest X-rays: Must have an image path available

    Examples:
        >>> from pyhealth.datasets import MIMIC4Dataset
        >>> from pyhealth.tasks import MultimodalMortalityPredictionMIMIC4
        >>> dataset = MIMIC4Dataset(
        ...     ehr_root="/path/to/mimic-iv/2.2",
        ...     note_root="/path/to/mimic-iv-note/2.2",
        ...     cxr_root="/path/to/mimic-cxr/2.0.0",
        ...     ehr_tables=["diagnoses_icd", "procedures_icd",
        ...                 "prescriptions", "labevents"],
        ...     note_tables=["discharge", "radiology"],
        ...     cxr_tables=["metadata", "negbio"],
        ... )
        >>> task = MultimodalMortalityPredictionMIMIC4()
        >>> samples = dataset.set_task(task)

    Patient-Level Aggregation:
        - Admissions are walked in order. An admission carrying the death
          flag sets the label to 1 and is itself excluded; an admission
          whose NEXT admission carries the death flag is included and ends
          the walk with label 1.
        - For surviving patients: aggregate all events across all admissions
        - Returns ONE sample per patient with aggregated multimodal data

    Output Format:
        - Discharge and radiology notes are returned as lists (raw processor)
        - lab_values uses nested_sequence_floats processor for 10D vectors
        - lab_times is a separate list of time offsets (raw processor)

    Lab Processing:
        - 10-dimensional vectors (one per lab category)
        - Categories: Sodium, Potassium, Chloride, Bicarbonate, Glucose,
          Calcium, Magnesium, Anion Gap, Osmolality, Phosphate
        - Multiple itemids per category -> the first itemid (in the order
          listed in ``LAB_CATEGORIES``) with a row at that timestamp wins;
          categories without a row are ``None``
        - Time offsets are in hours from the first processed admission

    Image Processing:
        - Uses image_path from MIMIC-CXR metadata directly
        - Returns first available X-ray image path across all X-rays
    """

    task_name: str = "MultimodalMortalityPredictionMIMIC4"

    # Lab categories matching MortalityPredictionStageNetMIMIC4
    LAB_CATEGORIES: Dict[str, List[str]] = {
        "Sodium": ["50824", "52455", "50983", "52623"],
        "Potassium": ["50822", "52452", "50971", "52610"],
        "Chloride": ["50806", "52434", "50902", "52535"],
        "Bicarbonate": ["50803", "50804"],
        "Glucose": ["50809", "52027", "50931", "52569"],
        "Calcium": ["50808", "51624"],
        "Magnesium": ["50960"],
        "Anion Gap": ["50868", "52500"],
        "Osmolality": ["52031", "50964", "51701"],
        "Phosphate": ["50970"],
    }

    # Fixed ordering of the 10 dimensions of every lab vector.
    LAB_CATEGORY_NAMES: List[str] = [
        "Sodium",
        "Potassium",
        "Chloride",
        "Bicarbonate",
        "Glucose",
        "Calcium",
        "Magnesium",
        "Anion Gap",
        "Osmolality",
        "Phosphate",
    ]

    # Flat list of every itemid of interest, used to pre-filter lab rows.
    LABITEMS: List[str] = [
        item for itemids in LAB_CATEGORIES.values() for item in itemids
    ]

    def __init__(self, **kwargs):
        """Initialize the multimodal mortality prediction task.

        Args:
            **kwargs: Passed to :class:`~pyhealth.tasks.BaseTask`, e.g.
                ``code_mapping``.
        """
        self.input_schema: Dict[str, str] = {
            "conditions": "nested_sequence",  # Nested by visit
            "procedures": "nested_sequence",  # Nested by visit
            "drugs": "nested_sequence",  # Nested by visit
            "discharge": "raw",  # List of discharge notes
            "radiology": "raw",  # List of radiology notes
            "lab_values": "nested_sequence_floats",  # 10D lab vectors per timestamp
            "lab_times": "raw",  # Lab measurement times (hours from first admission)
            "negbio_findings": "sequence",  # NegBio X-ray findings
            "image_path": "text",  # Image path as text string
        }
        self.output_schema: Dict[str, str] = {"mortality": "binary"}
        super().__init__(**kwargs)

    def _clean_sequence(self, sequence: Optional[List[Any]]) -> List[str]:
        """Clean a sequence by removing None values and converting to strings.

        Args:
            sequence: Raw values, or ``None``.

        Returns:
            Stripped, non-empty strings (empty list for ``None`` input).
        """
        if sequence is None:
            return []

        # Remove None, convert to strings, remove empty strings
        cleaned = [
            str(item).strip()
            for item in sequence
            if item is not None and str(item).strip()
        ]
        return cleaned

    def _clean_text(self, text: Optional[str]) -> Optional[str]:
        """Return text if non-empty, otherwise None."""
        return text if text else None

    def _process_lab_events(
        self,
        patient: Any,
        admission_time: datetime,
        admission_dischtime: datetime,
        reference_time: Optional[datetime] = None,
    ) -> Optional[tuple]:
        """Process lab events into 10-dimensional vectors with timestamps.

        Args:
            patient: Patient object
            admission_time: Admission start time
            admission_dischtime: Admission discharge time
            reference_time: Reference time for calculating time offsets
                (default: admission_time)

        Returns:
            Tuple of (times_list, values_list) or None if no lab events
            (also None when polars cannot be imported).
        """
        try:
            import polars as pl
        except ImportError:
            return None

        if reference_time is None:
            reference_time = admission_time

        # Use timestamp filtering for lab events (hadm_id not reliable)
        labevents_df = patient.get_events(
            event_type="labevents",
            start=admission_time,
            end=admission_dischtime,
            return_df=True,
        )

        if labevents_df is None or labevents_df.height == 0:
            return None

        # Filter to relevant lab items
        labevents_df = labevents_df.filter(
            pl.col("labevents/itemid").is_in(self.LABITEMS)
        )

        if labevents_df.height == 0:
            return None

        # Parse storetime and filter (matching stagenet implementation)
        labevents_df = labevents_df.with_columns(
            pl.col("labevents/storetime").str.strptime(
                pl.Datetime, "%Y-%m-%d %H:%M:%S"
            )
        )
        labevents_df = labevents_df.filter(
            pl.col("labevents/storetime") <= admission_dischtime
        )

        if labevents_df.height == 0:
            return None

        # Select relevant columns
        labevents_df = labevents_df.select(
            pl.col("timestamp"),
            pl.col("labevents/itemid"),
            pl.col("labevents/valuenum").cast(pl.Float64),
        )

        # Group by timestamp and aggregate into 10D vectors
        unique_timestamps = sorted(labevents_df["timestamp"].unique().to_list())

        lab_times = []
        lab_values = []

        for lab_ts in unique_timestamps:
            ts_labs = labevents_df.filter(pl.col("timestamp") == lab_ts)

            # Create 10-dimensional vector
            lab_vector = []
            for category_name in self.LAB_CATEGORY_NAMES:
                category_itemids = self.LAB_CATEGORIES[category_name]

                # Find first matching value for this category
                category_value = None
                for itemid in category_itemids:
                    matching = ts_labs.filter(pl.col("labevents/itemid") == itemid)
                    if matching.height > 0:
                        category_value = matching["labevents/valuenum"][0]
                        break

                lab_vector.append(category_value)

            # Calculate time from reference time (hours)
            time_from_reference = (lab_ts - reference_time).total_seconds() / 3600.0

            lab_times.append(time_from_reference)
            lab_values.append(lab_vector)

        if len(lab_values) == 0:
            return None

        return (lab_times, lab_values)

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Processes a single patient for patient-level multimodal mortality prediction.

        Aggregates every modality across the processed admissions into a
        single sample. All modalities are required: a patient without
        condition, procedure and drug codes, without any note, without any
        lab vector, or without an image path yields no sample.

        Mortality is determined iteratively by checking if the NEXT
        admission has the death flag. Admissions are included up to (but not
        including) any admission where the patient dies.

        Args:
            patient: Object exposing ``get_events(...)`` and ``patient_id``.

        Returns:
            A list holding ONE sample dict, or an empty list when the
            patient is skipped.
        """
        # Patients without a row in the patients table are skipped.
        if not patient.get_events(event_type="patients"):
            return []

        # Get visits
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) == 0:
            return []

        # Determine which admissions to process iteratively
        admissions_to_process = []
        mortality_label = 0

        for i, admission in enumerate(admissions):
            # Check if THIS admission has the death flag
            if admission.hospital_expire_flag in [1, "1"]:
                # Patient died in this admission - set mortality label
                # but don't include this admission's data
                mortality_label = 1
                break

            # Check if there's a next admission with death flag
            if i + 1 < len(admissions):
                next_admission = admissions[i + 1]
                if next_admission.hospital_expire_flag in [1, "1"]:
                    # Next admission has death - include current, set mortality
                    admissions_to_process.append(admission)
                    mortality_label = 1
                    break

            # No death in current or next - include this admission
            admissions_to_process.append(admission)

        if len(admissions_to_process) == 0:
            return []

        # Get first admission time as reference for lab time calculations
        first_admission_time = admissions_to_process[0].timestamp

        # Aggregated data across all admissions
        all_conditions = []
        all_procedures = []
        all_drugs = []
        all_discharge_notes = []  # List of individual discharge notes
        all_radiology_notes = []  # List of individual radiology notes
        all_lab_times = []
        all_lab_values = []
        all_negbio_findings = []
        image_path = ""  # Empty string instead of None for serialization

        # Get X-ray data (patient-level, not admission-specific)
        # Note: event types match table names in mimic4_cxr.yaml (negbio, metadata)
        negbio_events = patient.get_events(event_type="negbio")
        metadata_events = patient.get_events(event_type="metadata")

        # Process X-ray findings (aggregate across all X-rays)
        # NegBio findings attributes (from mimic4_cxr.yaml negbio table)
        negbio_finding_names = [
            "no finding",
            "enlarged cardiomediastinum",
            "cardiomegaly",
            "lung opacity",
            "lung lesion",
            "edema",
            "consolidation",
            "pneumonia",
            "atelectasis",
            "pneumothorax",
            "pleural effusion",
            "pleural other",
            "fracture",
            "support devices",
        ]

        for xray in negbio_events:
            # Best-effort: a malformed X-ray row must not drop the patient.
            try:
                for finding_name in negbio_finding_names:
                    try:
                        value = getattr(xray, finding_name, None)
                        if value is not None and float(value) > 0:
                            all_negbio_findings.append(finding_name)
                    except (ValueError, TypeError, AttributeError):
                        pass
            except Exception:
                pass

        # Get first available image path from metadata
        for event in metadata_events:
            try:
                if event.image_path:
                    image_path = event.image_path
                    break  # Use first valid image
            except AttributeError:
                pass

        # Process each admission and aggregate data
        for admission in admissions_to_process:
            # Parse admission discharge time for lab events filtering.
            # getattr covers a missing attribute; TypeError covers None or
            # other non-string values, which strptime rejects with TypeError.
            try:
                admission_dischtime = datetime.strptime(
                    getattr(admission, "dischtime", None), "%Y-%m-%d %H:%M:%S"
                )
            except (ValueError, TypeError):
                # If we can't parse discharge time, skip this admission
                continue

            # Skip if discharge is before admission (data quality issue)
            if admission_dischtime < admission.timestamp:
                continue

            # Get clinical codes using hadm_id filtering
            diagnoses_icd = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            procedures_icd = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            prescriptions = patient.get_events(
                event_type="prescriptions",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )

            # Get notes using hadm_id filtering
            discharge_notes = patient.get_events(
                event_type="discharge",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )
            radiology_notes = patient.get_events(
                event_type="radiology",
                filters=[("hadm_id", "==", admission.hadm_id)],
            )

            # Extract clinical codes per visit (nested structure)
            conditions = self._clean_sequence(
                [event.icd_code for event in diagnoses_icd]
            )
            procedures_list = self._clean_sequence(
                [event.icd_code for event in procedures_icd]
            )
            drugs = self._clean_sequence([event.ndc for event in prescriptions])

            # Append as nested lists (one list per visit) for nested_sequence
            all_conditions.append(conditions)
            all_procedures.append(procedures_list)
            all_drugs.append(drugs)

            # Extract and aggregate notes as individual items in lists
            # Note: attribute is "text" (from mimic4_note.yaml),
            # not "discharge"/"radiology"
            for note in discharge_notes:
                try:
                    note_text = self._clean_text(note.text)
                    if note_text:
                        all_discharge_notes.append(note_text)
                except AttributeError:
                    pass

            for note in radiology_notes:
                try:
                    note_text = self._clean_text(note.text)
                    if note_text:
                        all_radiology_notes.append(note_text)
                except AttributeError:
                    pass

            # Process lab events with reference to first admission time
            labs_data = self._process_lab_events(
                patient,
                admission.timestamp,
                admission_dischtime,
                reference_time=first_admission_time,
            )

            if labs_data is not None:
                lab_times, lab_values = labs_data
                all_lab_times.extend(lab_times)
                all_lab_values.extend(lab_values)

        # ===== MODALITY REQUIREMENTS =====
        # Required: EHR codes (conditions, procedures, drugs),
        # notes (discharge OR radiology), labs, and image_path

        # Check EHR codes - need at least one code in each category
        # across all visits
        has_conditions = any(len(codes) > 0 for codes in all_conditions)
        has_procedures = any(len(codes) > 0 for codes in all_procedures)
        has_drugs = any(len(codes) > 0 for codes in all_drugs)

        # Check notes - need at least one discharge OR radiology note
        has_notes = len(all_discharge_notes) > 0 or len(all_radiology_notes) > 0

        # Check labs - need at least one lab measurement
        has_labs = len(all_lab_times) > 0

        # Check image - need a valid image path
        has_image = bool(image_path)

        # Return empty list if any required modality is missing
        if not (
            has_conditions
            and has_procedures
            and has_drugs
            and has_notes
            and has_labs
            and has_image
        ):
            return []

        # Sort lab events by time. has_labs guarantees at least one entry
        # here, so no empty-list fallback is needed.
        sorted_indices = sorted(
            range(len(all_lab_times)), key=lambda k: all_lab_times[k]
        )
        sorted_lab_times = [all_lab_times[i] for i in sorted_indices]
        sorted_lab_values = [all_lab_values[i] for i in sorted_indices]

        # Deduplicate negbio findings (flat sequence), keeping first-seen order
        unique_negbio = list(dict.fromkeys(all_negbio_findings))

        # Return single patient-level sample
        # Note: conditions/procedures/drugs are nested lists (one list per visit)
        # Note: discharge and radiology are lists (passed through by raw processor)
        # Note: lab_values uses nested_sequence_floats processor (handles None values)
        return [
            {
                "patient_id": patient.patient_id,
                "conditions": all_conditions,  # Nested: [[visit1_codes], ...]
                "procedures": all_procedures,  # Nested: [[visit1_codes], ...]
                "drugs": all_drugs,  # Nested: [[visit1_codes], ...]
                "discharge": all_discharge_notes,  # List of discharge notes
                "radiology": all_radiology_notes,  # List of radiology notes
                "lab_values": sorted_lab_values,  # Nested floats: [[10D vector], ...]
                "lab_times": sorted_lab_times,  # Hours from first admission
                "negbio_findings": unique_negbio,  # NegBio X-ray findings
                "image_path": image_path,  # Image path as string
                "mortality": mortality_label,
            }
        ]
class MortalityPredictionEICU(BaseTask):
    """Next-stay mortality prediction on the eICU dataset.

    Each ICU stay except the last is turned into a sample whose label is
    derived from the ``hospitaldischargestatus`` of the following stay.

    Features key-value pairs:
        - diagnosis table (``icd9code``) as condition codes
        - physicalexam table (``physicalexampath``) as procedure codes
        - medication table (``drugname``) as drug codes

    Examples:
        >>> from pyhealth.datasets import eICUDataset
        >>> from pyhealth.tasks import MortalityPredictionEICU
        >>> dataset = eICUDataset(
        ...     root="/path/to/eicu-crd/2.0",
        ...     tables=["diagnosis", "medication", "physicalExam"],
        ... )
        >>> task = MortalityPredictionEICU()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MortalityPredictionEICU"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build the mortality samples for one patient.

        Args:
            patient (Any): A Patient object containing patient data.

        Returns:
            List[Dict[str, Any]]: One dict per retained stay with the keys
            ``visit_id``, ``patient_id``, ``conditions``, ``procedures``,
            ``drugs`` and ``mortality``.
        """
        # Each row of the patient table is one ICU stay.
        patient_stays = patient.get_events(event_type="patient")
        if len(patient_stays) <= 1:
            return []

        def stay_events(table: str, unit_stay_id: str) -> Any:
            # Events are tied to a stay through patientunitstayid.
            return patient.get_events(
                event_type=table,
                filters=[("patientunitstayid", "==", unit_stay_id)],
            )

        def truthy_values(events: Any, attribute: str) -> List[Any]:
            # Keep only events that carry a non-empty value for `attribute`.
            values = []
            for event in events:
                value = getattr(event, attribute, None)
                if value:
                    values.append(value)
            return values

        samples = []
        for position in range(1, len(patient_stays)):
            stay = patient_stays[position - 1]
            next_stay = patient_stays[position]

            # "Expired" means death; "Alive" and any unrecognised status
            # are both labelled 0.
            discharge_status = getattr(next_stay, "hospitaldischargestatus", None)
            mortality_label = 1 if discharge_status == "Expired" else 0

            stay_id = str(getattr(stay, "patientunitstayid", ""))

            diagnoses = stay_events("diagnosis", stay_id)
            physical_exams = stay_events("physicalexam", stay_id)
            medications = stay_events("medication", stay_id)

            conditions = truthy_values(diagnoses, "icd9code")
            procedures_list = truthy_values(physical_exams, "physicalexampath")
            drugs = truthy_values(medications, "drugname")

            # A stay needs at least one code of every kind to be usable.
            if not (conditions and procedures_list and drugs):
                continue

            # TODO: Exclude visits with age < 18

            samples.append(
                {
                    "visit_id": stay_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures_list,
                    "drugs": drugs,
                    "mortality": mortality_label,
                }
            )
        return samples
class MortalityPredictionEICU2(BaseTask):
    """Task for predicting mortality using eICU dataset with alternative coding.

    This task aims to predict whether the patient will decease in the next
    hospital visit based on clinical information from the current visit.

    Similar to MortalityPredictionEICU, but with different code mapping:
        - using admissiondx table and diagnosisstring under diagnosis table
          as condition codes
        - using treatment table as procedure codes

    Examples:
        >>> from pyhealth.datasets import eICUDataset
        >>> from pyhealth.tasks import MortalityPredictionEICU2
        >>> dataset = eICUDataset(
        ...     root="/path/to/eicu-crd/2.0",
        ...     tables=["diagnosis", "treatment", "admissionDx"],
        ... )
        >>> task = MortalityPredictionEICU2()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MortalityPredictionEICU2"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Processes a single patient for the mortality prediction task.

        Args:
            patient (Any): A Patient object containing patient data.

        Returns:
            List[Dict[str, Any]]: A list of samples, each sample is a dict
            with patient_id, visit_id, conditions, procedures and mortality.
        """
        samples = []

        # Get patient stays (each row in patient table is an ICU stay)
        patient_stays = patient.get_events(event_type="patient")
        if len(patient_stays) <= 1:
            return []

        for i in range(len(patient_stays) - 1):
            stay = patient_stays[i]
            next_stay = patient_stays[i + 1]

            # Check discharge status for mortality label
            discharge_status = getattr(next_stay, "hospitaldischargestatus", None)
            if discharge_status not in ["Alive", "Expired"]:
                mortality_label = 0
            else:
                mortality_label = 0 if discharge_status == "Alive" else 1

            # Get the patientunitstayid for filtering
            stay_id = str(getattr(stay, "patientunitstayid", ""))

            # Get clinical codes using patientunitstayid-based filtering
            admission_dx = patient.get_events(
                event_type="admissiondx",
                filters=[("patientunitstayid", "==", stay_id)],
            )
            diagnosis_events = patient.get_events(
                event_type="diagnosis",
                filters=[("patientunitstayid", "==", stay_id)],
            )
            treatments = patient.get_events(
                event_type="treatment",
                filters=[("patientunitstayid", "==", stay_id)],
            )

            # Deduplicate diagnosis strings while keeping first-seen order.
            # dict.fromkeys is used instead of set() because set iteration
            # order for strings varies between interpreter runs, which would
            # make the emitted "sequence" feature non-reproducible.
            diagnosis_strings = list(
                dict.fromkeys(
                    getattr(event, "diagnosisstring", "")
                    for event in diagnosis_events
                    if getattr(event, "diagnosisstring", None)
                )
            )

            # Get admission diagnosis codes
            admission_dx_codes = [
                getattr(event, "admitdxpath", "")
                for event in admission_dx
                if getattr(event, "admitdxpath", None)
            ]

            # Get treatment codes
            treatment_codes = [
                getattr(event, "treatmentstring", "")
                for event in treatments
                if getattr(event, "treatmentstring", None)
            ]

            # Combine admission diagnoses and diagnosis strings
            conditions = admission_dx_codes + diagnosis_strings

            # Exclude visits without sufficient codes
            if not (conditions and treatment_codes):
                continue

            # TODO: Exclude visits with age < 18

            samples.append(
                {
                    "visit_id": stay_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": treatment_codes,
                    "mortality": mortality_label,
                }
            )

        return samples
class MortalityPredictionOMOP(BaseTask):
    """Task for predicting mortality using OMOP CDM dataset.

    This task predicts whether a patient has a death record (binary
    mortality prediction) based on clinical information from each visit.
    Unlike visit-specific mortality tasks, this provides a patient-level
    mortality indicator (whether the patient died at any point).

    The task processes visits sequentially and extracts clinical codes
    (conditions, procedures, drugs) for each visit. Clinical events are
    linked to visits via visit_occurrence_id, following OMOP CDM conventions.

    Features:
        - Uses OMOP CDM standard tables (condition_occurrence,
          procedure_occurrence, drug_exposure)
        - Links clinical events to visits via visit_occurrence_id
        - Uses OMOP concept_ids as medical codes
        - Binary mortality label (1 if patient has death record, 0 otherwise)

    Task Schema:
        Input:
            - conditions: sequence of condition_concept_id codes
            - procedures: sequence of procedure_concept_id codes
            - drugs: sequence of drug_concept_id codes
        Output:
            - mortality: binary label (0: no death record, 1: death record)

    Args:
        patient (Patient): A Patient object containing OMOP CDM data.

    Returns:
        List[Dict[str, Any]]: A list of samples, where each sample contains:
            - visit_id: The visit_occurrence_id
            - patient_id: The person_id
            - conditions: List of condition_concept_id codes
            - procedures: List of procedure_concept_id codes
            - drugs: List of drug_concept_id codes
            - mortality: Binary label (0 or 1)

    Examples:
        >>> from pyhealth.datasets import OMOPDataset
        >>> from pyhealth.tasks import MortalityPredictionOMOP
        >>>
        >>> # Load OMOP dataset
        >>> dataset = OMOPDataset(
        ...     root="/path/to/omop/data",
        ...     tables=["condition_occurrence", "procedure_occurrence",
        ...             "drug_exposure"],
        ... )
        >>>
        >>> # Create mortality prediction task
        >>> task = MortalityPredictionOMOP()
        >>> sample_dataset = dataset.set_task(task=task)
        >>>
        >>> # Access samples
        >>> print(f"Generated {len(sample_dataset)} samples")
        >>> sample = sample_dataset.samples[0]
        >>> print(f"Conditions: {sample['conditions']}")
        >>> print(f"Mortality: {sample['mortality']}")

    Note:
        - Visits without any clinical codes (conditions, procedures, or
          drugs) are excluded
        - Visits without a visit_occurrence_id are excluded
        - The last visit is excluded
        - Clinical events are filtered by visit_occurrence_id, not by
          timestamp ranges, following OMOP best practices
    """

    task_name: str = "MortalityPredictionOMOP"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Processes a single patient for the mortality prediction task.

        Args:
            patient (Any): A Patient object containing patient data.

        Returns:
            List[Dict[str, Any]]: A list of samples, each sample is a dict
            with patient_id, visit_id, conditions, procedures, drugs and
            mortality.
        """
        samples = []

        # Get visits and death events
        visits = patient.get_events(event_type="visit_occurrence")
        death_events = patient.get_events(event_type="death")

        if len(visits) <= 1:
            return []

        # Get death datetime if exists
        death_datetime = None
        if len(death_events) > 0:
            death_datetime = death_events[0].timestamp

        # The label is patient-level (died at some point), not time-specific,
        # so it is identical for every visit and computed once.
        mortality_label = 1 if death_datetime is not None else 0

        # The last visit is excluded.
        for i in range(len(visits) - 1):
            visit = visits[i]

            # A missing id must be detected BEFORE str(): str(None) is the
            # truthy string "None", which would silently query for a bogus id.
            raw_visit_id = getattr(visit, "visit_occurrence_id", None)
            if raw_visit_id is None:
                continue
            visit_occurrence_id = str(raw_visit_id)
            if not visit_occurrence_id:
                continue

            # In OMOP, clinical events are linked to visits by
            # visit_occurrence_id
            conditions = patient.get_events(
                event_type="condition_occurrence",
                filters=[("visit_occurrence_id", "==", visit_occurrence_id)],
            )
            procedures = patient.get_events(
                event_type="procedure_occurrence",
                filters=[("visit_occurrence_id", "==", visit_occurrence_id)],
            )
            drugs = patient.get_events(
                event_type="drug_exposure",
                filters=[("visit_occurrence_id", "==", visit_occurrence_id)],
            )

            # Extract concept IDs as codes
            condition_codes = [
                str(getattr(event, "condition_concept_id", ""))
                for event in conditions
                if getattr(event, "condition_concept_id", None) is not None
            ]
            procedure_codes = [
                str(getattr(event, "procedure_concept_id", ""))
                for event in procedures
                if getattr(event, "procedure_concept_id", None) is not None
            ]
            drug_codes = [
                str(getattr(event, "drug_concept_id", ""))
                for event in drugs
                if getattr(event, "drug_concept_id", None) is not None
            ]

            # Exclude visits without any clinical codes
            if not (condition_codes or procedure_codes or drug_codes):
                continue

            samples.append(
                {
                    "visit_id": visit_occurrence_id,
                    "patient_id": patient.patient_id,
                    "conditions": condition_codes,
                    "procedures": procedure_codes,
                    "drugs": drug_codes,
                    "mortality": mortality_label,
                }
            )

        return samples