Source code for pyhealth.tasks.drug_recommendation

from typing import Any, Dict, Iterable, List, Optional

import polars as pl

from pyhealth.data import Patient, Visit
from pyhealth.medcode import CrossMap
from .base_task import BaseTask


# Lazily loaded NDC -> ATC cross-map; stays None until
# _get_ndc_to_atc3_mapper() is called for the first time.
_NDC_TO_ATC3_MAPPER = None
# Memo of mapper.map(...) results, keyed by the whitespace-stripped NDC string.
_NDC_TO_ATC3_CACHE: Dict[str, List[str]] = {}


def _get_ndc_to_atc3_mapper():
    """Return the module-wide NDC -> ATC cross-map, loading it on first use.

    The loaded mapper is stored in ``_NDC_TO_ATC3_MAPPER`` so later calls
    return the same object without calling ``CrossMap.load`` again.
    """
    global _NDC_TO_ATC3_MAPPER
    if _NDC_TO_ATC3_MAPPER is not None:
        return _NDC_TO_ATC3_MAPPER
    _NDC_TO_ATC3_MAPPER = CrossMap.load("NDC", "ATC")
    return _NDC_TO_ATC3_MAPPER


def _is_missing_ndc(code: Any) -> bool:
    """Tell whether an NDC value should be treated as absent.

    Args:
        code: Raw NDC value of any type; it is stringified before checking.

    Returns:
        True when ``code`` is None, or when its stripped string form is
        empty, exactly ``"0"``, or (case-insensitively) one of ``"nan"``,
        ``"none"`` or ``"<na>"``; False otherwise.
    """
    if code is None:
        return True
    text = str(code).strip()
    if text in ("", "0"):
        return True
    return text.lower() in {"nan", "none", "<na>"}


def _map_ndc_list_to_atc3(
    ndc_codes: Iterable[Any],
    mapper: Optional[Any] = None,
) -> List[str]:
    """Maps MIMIC prescription NDCs to stable, deduplicated ATC-3 labels.

    Args:
        ndc_codes: Raw NDC values. Entries that ``_is_missing_ndc`` flags
            are skipped; the rest are stringified and stripped.
        mapper: Optional object exposing ``map(code, target_kwargs=...)``.
            When omitted, the shared mapper from
            ``_get_ndc_to_atc3_mapper()`` is used.

    Returns:
        The mapped codes in first-seen order, with duplicates, ``None``
        entries and blank strings removed.
    """
    if mapper is None:
        mapper = _get_ndc_to_atc3_mapper()

    # The module-level cache only holds results produced by the shared default
    # mapper. A caller-supplied mapper gets a call-local memo so its results
    # can neither pollute nor be served from the shared cache.
    cache: Dict[str, List[str]] = (
        _NDC_TO_ATC3_CACHE if mapper is _NDC_TO_ATC3_MAPPER else {}
    )

    drugs: List[str] = []
    seen = set()

    for ndc in ndc_codes:
        if _is_missing_ndc(ndc):
            continue
        ndc = str(ndc).strip()
        if ndc not in cache:
            cache[ndc] = mapper.map(ndc, target_kwargs={"level": 3})
        for code in cache[ndc]:
            if code is None:
                continue
            code = str(code).strip()
            # Keep first-seen order while dropping blanks and repeats.
            if code and code not in seen:
                drugs.append(code)
                seen.add(code)

    return drugs


class DrugRecommendationMIMIC3(BaseTask):
    """Task for drug recommendation using MIMIC-III dataset.

    Drug recommendation aims at recommending a set of drugs given the patient
    health history (e.g., conditions and procedures). Every kept admission
    becomes one sample whose inputs accumulate all kept admissions up to and
    including the current one.

    Prescription NDC codes are passed through ``_map_ndc_list_to_atc3``
    before being used as labels or history.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for input data:
            - conditions: Nested list of diagnosis codes (history + current)
            - procedures: Nested list of procedure codes (history + current)
            - drugs_hist: Nested list of drug codes from history (the slot
              for the current visit is an empty list)
        output_schema (Dict[str, str]): The schema for output data:
            - drugs: List of drugs to predict for current visit
    """

    task_name: str = "DrugRecommendationMIMIC3"
    input_schema: Dict[str, str] = {
        "conditions": "nested_sequence",
        "procedures": "nested_sequence",
        "drugs_hist": "nested_sequence",
    }
    output_schema: Dict[str, str] = {"drugs": "multilabel"}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache_version = "ndc_to_atc3_v1"

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build drug recommendation samples for one patient.

        Admissions lacking any condition, procedure or mapped drug are
        dropped. One sample is returned per remaining admission (the first
        one included, with an empty drug history).

        Args:
            patient: Patient object with get_events method

        Returns:
            List of samples, each with patient_id, visit_id, conditions
            history, procedures history, drugs history, and target drugs.
            Empty when the patient has fewer than two admissions or fewer
            than two kept admissions.
        """
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) < 2:
            # Need at least 2 visits for history-based prediction
            return []

        samples: List[Dict[str, Any]] = []
        for admission in admissions:
            # Diagnosis codes recorded under this hadm_id
            diagnoses_df = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            conditions = (
                diagnoses_df.select(pl.col("diagnoses_icd/icd9_code"))
                .to_series()
                .to_list()
            )

            # Procedure codes recorded under this hadm_id
            procedures_df = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            procedures = (
                procedures_df.select(pl.col("procedures_icd/icd9_code"))
                .to_series()
                .to_list()
            )

            # Prescription NDCs recorded under this hadm_id, mapped to ATC-3
            prescriptions_df = patient.get_events(
                event_type="prescriptions",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            ndc_codes = (
                prescriptions_df.select(pl.col("prescriptions/ndc"))
                .to_series()
                .to_list()
            )
            drugs = _map_ndc_list_to_atc3(ndc_codes)

            # Exclude visits without condition, procedure, or drug code
            if 0 in (len(conditions), len(procedures), len(drugs)):
                continue

            samples.append(
                {
                    "visit_id": admission.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures,
                    "drugs": drugs,
                    "drugs_hist": drugs,
                }
            )

        # Exclude patients with less than 2 valid visits
        if len(samples) < 2:
            return []

        # Turn per-visit lists into cumulative histories. The current visit's
        # drugs are the prediction target, so its drugs_hist slot stays empty.
        condition_history: List[Any] = []
        procedure_history: List[Any] = []
        drug_history: List[Any] = []
        for sample in samples:
            condition_history.append(sample["conditions"])
            procedure_history.append(sample["procedures"])
            sample["conditions"] = list(condition_history)
            sample["procedures"] = list(procedure_history)
            sample["drugs_hist"] = drug_history + [[]]
            drug_history.append(sample["drugs"])

        return samples
class DrugRecommendationMIMIC4(BaseTask):
    """Task for drug recommendation using MIMIC-IV dataset.

    Drug recommendation aims at recommending a set of drugs given the patient
    health history (e.g., conditions and procedures). Every kept admission
    becomes one sample whose inputs accumulate all kept admissions up to and
    including the current one.

    Diagnosis and procedure codes are emitted as ``"<icd_version>_<icd_code>"``
    strings. Prescription NDC codes are passed through
    ``_map_ndc_list_to_atc3`` before being used as labels or history.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for input data:
            - conditions: Nested list of diagnosis codes (history + current)
            - procedures: Nested list of procedure codes (history + current)
            - drugs_hist: Nested list of drug codes from history (the slot
              for the current visit is an empty list)
        output_schema (Dict[str, str]): The schema for output data:
            - drugs: List of drugs to predict for current visit

    Examples:
        >>> from pyhealth.datasets import MIMIC4EHRDataset
        >>> from pyhealth.tasks import DrugRecommendationMIMIC4
        >>> dataset = MIMIC4EHRDataset(
        ...     root="/path/to/mimic-iv/2.2",
        ...     tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        ... )
        >>> task = DrugRecommendationMIMIC4()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "DrugRecommendationMIMIC4"
    input_schema: Dict[str, str] = {
        "conditions": "nested_sequence",
        "procedures": "nested_sequence",
        "drugs_hist": "nested_sequence",
    }
    output_schema: Dict[str, str] = {"drugs": "multilabel"}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache_version = "ndc_to_atc3_v1"

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build drug recommendation samples for one patient.

        Admissions lacking any condition, procedure or mapped drug are
        dropped. One sample is returned per remaining admission (the first
        one included, with an empty drug history).

        Args:
            patient: Patient object with get_events method

        Returns:
            List of samples, each with patient_id, visit_id, conditions
            history, procedures history, drugs history, and target drugs.
            Empty when the patient has fewer than two admissions or fewer
            than two kept admissions.
        """
        admissions = patient.get_events(event_type="admissions")
        if len(admissions) < 2:
            # Need at least 2 visits for history-based prediction
            return []

        samples: List[Dict[str, Any]] = []
        for admission in admissions:
            # Diagnosis codes recorded under this hadm_id
            diagnoses_df = patient.get_events(
                event_type="diagnoses_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            versioned_diagnosis = pl.concat_str(
                ["diagnoses_icd/icd_version", "diagnoses_icd/icd_code"],
                separator="_",
            )
            conditions = diagnoses_df.select(versioned_diagnosis).to_series().to_list()

            # Procedure codes recorded under this hadm_id
            procedures_df = patient.get_events(
                event_type="procedures_icd",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            versioned_procedure = pl.concat_str(
                ["procedures_icd/icd_version", "procedures_icd/icd_code"],
                separator="_",
            )
            procedures = procedures_df.select(versioned_procedure).to_series().to_list()

            # Prescription NDCs recorded under this hadm_id, mapped to ATC-3
            prescriptions_df = patient.get_events(
                event_type="prescriptions",
                filters=[("hadm_id", "==", admission.hadm_id)],
                return_df=True,
            )
            ndc_codes = (
                prescriptions_df.select(pl.col("prescriptions/ndc"))
                .to_series()
                .to_list()
            )
            drugs = _map_ndc_list_to_atc3(ndc_codes)

            # Exclude visits without condition, procedure, or drug code
            if 0 in (len(conditions), len(procedures), len(drugs)):
                continue

            samples.append(
                {
                    "visit_id": admission.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures,
                    "drugs": drugs,
                    "drugs_hist": drugs,
                }
            )

        # Exclude patients with less than 2 valid visits
        if len(samples) < 2:
            return []

        # Turn per-visit lists into cumulative histories. The current visit's
        # drugs are the prediction target, so its drugs_hist slot stays empty.
        condition_history: List[Any] = []
        procedure_history: List[Any] = []
        drug_history: List[Any] = []
        for sample in samples:
            condition_history.append(sample["conditions"])
            procedure_history.append(sample["procedures"])
            sample["conditions"] = list(condition_history)
            sample["procedures"] = list(procedure_history)
            sample["drugs_hist"] = drug_history + [[]]
            drug_history.append(sample["drugs"])

        return samples
def drug_recommendation_mimic3_fn(patient: Patient):
    """Processes a single patient for the drug recommendation task.

    Drug recommendation aims at recommending a set of drugs given the patient
    health history (e.g., conditions and procedures). Prescription codes are
    passed through ``_map_ndc_list_to_atc3`` before use.

    Args:
        patient: a Patient object

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, and other task-specific attributes as key, like this
            {
                "patient_id": xxx,
                "visit_id": xxx,
                "conditions": [list of diag in visit 1, ..., list of diag in visit N],
                "procedures": [list of prod in visit 1, ..., list of prod in visit N],
                "drugs_hist": [list of drug in visit 1, ..., list of drug in visit (N-1), []],
                "drugs": list of drug in visit N,  # this is the predicted target
            }
            An empty list is returned when fewer than two visits have
            conditions, procedures and drugs.

    Examples:
        >>> from pyhealth.datasets import MIMIC3Dataset
        >>> mimic3_base = MIMIC3Dataset(
        ...     root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        ...     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        ...     code_mapping={"ICD9CM": "CCSCM"},
        ... )
        >>> from pyhealth.tasks import drug_recommendation_mimic3_fn
        >>> mimic3_sample = mimic3_base.set_task(drug_recommendation_mimic3_fn)
        >>> mimic3_sample.samples[0]
    """
    samples = []
    for idx in range(len(patient)):
        visit: Visit = patient[idx]
        conditions = visit.get_code_list(table="DIAGNOSES_ICD")
        procedures = visit.get_code_list(table="PROCEDURES_ICD")
        raw_drugs = visit.get_code_list(table="PRESCRIPTIONS")
        drugs = _map_ndc_list_to_atc3(raw_drugs)

        # exclude: visits without condition, procedure, or drug code
        if 0 in (len(conditions), len(procedures), len(drugs)):
            continue

        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_hist": drugs,
            }
        )

    # exclude: patients with less than 2 visit
    if len(samples) < 2:
        return []

    # add history; the target visit's own drugs are replaced by an empty slot
    condition_history = []
    procedure_history = []
    drug_history = []
    for sample in samples:
        condition_history.append(sample["conditions"])
        procedure_history.append(sample["procedures"])
        sample["conditions"] = list(condition_history)
        sample["procedures"] = list(procedure_history)
        sample["drugs_hist"] = drug_history + [[]]
        drug_history.append(sample["drugs"])

    return samples
def drug_recommendation_mimic4_fn(patient: Patient):
    """Processes a single patient for the drug recommendation task.

    Drug recommendation aims at recommending a set of drugs given the patient
    health history (e.g., conditions and procedures). Prescription codes are
    passed through ``_map_ndc_list_to_atc3`` before use.

    Args:
        patient: a Patient object

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, and other task-specific attributes as key
            {
                "patient_id": xxx,
                "visit_id": xxx,
                "conditions": [list of diag in visit 1, ..., list of diag in visit N],
                "procedures": [list of prod in visit 1, ..., list of prod in visit N],
                "drugs_hist": [list of drug in visit 1, ..., list of drug in visit (N-1), []],
                "drugs": list of drug in visit N,  # this is the predicted target
            }
            An empty list is returned when fewer than two visits have
            conditions, procedures and drugs.

    Examples:
        >>> from pyhealth.datasets import MIMIC4Dataset
        >>> mimic4_base = MIMIC4Dataset(
        ...     root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
        ...     tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        ...     code_mapping={"ICD10PROC": "CCSPROC"},
        ... )
        >>> from pyhealth.tasks import drug_recommendation_mimic4_fn
        >>> mimic4_sample = mimic4_base.set_task(drug_recommendation_mimic4_fn)
        >>> mimic4_sample.samples[0]
    """
    samples = []
    for idx in range(len(patient)):
        visit: Visit = patient[idx]
        conditions = visit.get_code_list(table="diagnoses_icd")
        procedures = visit.get_code_list(table="procedures_icd")
        raw_drugs = visit.get_code_list(table="prescriptions")
        drugs = _map_ndc_list_to_atc3(raw_drugs)

        # exclude: visits without condition, procedure, or drug code
        if 0 in (len(conditions), len(procedures), len(drugs)):
            continue

        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_hist": drugs,
            }
        )

    # exclude: patients with less than 2 visit
    if len(samples) < 2:
        return []

    # add history; the target visit's own drugs are replaced by an empty slot
    condition_history = []
    procedure_history = []
    drug_history = []
    for sample in samples:
        condition_history.append(sample["conditions"])
        procedure_history.append(sample["procedures"])
        sample["conditions"] = list(condition_history)
        sample["procedures"] = list(procedure_history)
        sample["drugs_hist"] = drug_history + [[]]
        drug_history.append(sample["drugs"])

    return samples
class DrugRecommendationEICU(BaseTask):
    """Task for drug recommendation using eICU dataset.

    Drug recommendation aims at recommending a set of drugs given the patient
    health history (e.g., conditions and procedures). Every kept ICU stay
    becomes one sample whose inputs accumulate all kept stays up to and
    including the current one.

    Features key-value pairs:
        - using diagnosis table as condition codes
        - using physicalexam table as procedure codes
        - using medication table as drug codes

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for input data:
            - conditions: Nested list of diagnosis codes (history + current)
            - procedures: Nested list of procedure codes (history + current)
            - drugs_hist: Nested list of drug codes from history (the slot
              for the current visit is an empty list)
        output_schema (Dict[str, str]): The schema for output data:
            - drugs: List of drugs to predict for current visit

    Examples:
        >>> from pyhealth.datasets import eICUDataset
        >>> from pyhealth.tasks import DrugRecommendationEICU
        >>> dataset = eICUDataset(
        ...     root="/path/to/eicu-crd/2.0",
        ...     tables=["diagnosis", "medication", "physicalexam"],
        ... )
        >>> task = DrugRecommendationEICU()
        >>> sample_dataset = dataset.set_task(task)
    """

    task_name: str = "DrugRecommendationEICU"
    input_schema: Dict[str, str] = {
        "conditions": "nested_sequence",
        "procedures": "nested_sequence",
        "drugs_hist": "nested_sequence",
    }
    output_schema: Dict[str, str] = {"drugs": "multilabel"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build drug recommendation samples for one patient.

        Stays lacking any condition, procedure or drug are dropped. One
        sample is returned per remaining stay (the first one included, with
        an empty drug history).

        Args:
            patient: Patient object with get_events method

        Returns:
            List of samples, each with patient_id, visit_id, conditions
            history, procedures history, drugs history, and target drugs.
            Empty when the patient has fewer than two stays or fewer than
            two kept stays.
        """
        # Each row in the patient table is an ICU stay
        patient_stays = patient.get_events(event_type="patient")
        if len(patient_stays) < 2:
            # Need at least 2 visits for history-based prediction
            return []

        def stay_values(event_type: str, attribute: str, stay_id: str) -> List[Any]:
            # Truthy `attribute` values of the stay's events of this type.
            events = patient.get_events(
                event_type=event_type,
                filters=[("patientunitstayid", "==", stay_id)],
            )
            values = []
            for event in events:
                value = getattr(event, attribute, None)
                if value:
                    values.append(value)
            return values

        samples: List[Dict[str, Any]] = []
        for stay in patient_stays:
            # The patientunitstayid ties diagnosis/exam/medication rows to a stay
            stay_id = str(getattr(stay, "patientunitstayid", ""))

            conditions = stay_values("diagnosis", "icd9code", stay_id)
            procedures = stay_values("physicalexam", "physicalexampath", stay_id)
            drugs = stay_values("medication", "drugname", stay_id)

            # Exclude visits without condition, procedure, or drug code
            if 0 in (len(conditions), len(procedures), len(drugs)):
                continue

            samples.append(
                {
                    "visit_id": stay_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures,
                    "drugs": drugs,
                    "drugs_hist": drugs,
                }
            )

        # Exclude patients with less than 2 valid visits
        if len(samples) < 2:
            return []

        # Turn per-visit lists into cumulative histories. The current visit's
        # drugs are the prediction target, so its drugs_hist slot stays empty.
        condition_history: List[Any] = []
        procedure_history: List[Any] = []
        drug_history: List[Any] = []
        for sample in samples:
            condition_history.append(sample["conditions"])
            procedure_history.append(sample["procedures"])
            sample["conditions"] = list(condition_history)
            sample["procedures"] = list(procedure_history)
            sample["drugs_hist"] = drug_history + [[]]
            drug_history.append(sample["drugs"])

        return samples
def drug_recommendation_omop_fn(patient: Patient):
    """Processes a single patient for the drug recommendation task.

    Drug recommendation aims at recommending a set of drugs given the patient
    health history (e.g., conditions and procedures).

    Args:
        patient: a Patient object

    Returns:
        samples: a list of samples, each sample is a dict with patient_id,
            visit_id, and other task-specific attributes as key:
            cumulative "conditions", "procedures" and "drugs_all" histories
            plus the current visit's "drugs". An empty list is returned when
            fewer than two visits have conditions, procedures and drugs.

    Examples:
        >>> from pyhealth.datasets import OMOPDataset
        >>> omop_base = OMOPDataset(
        ...     root="https://storage.googleapis.com/pyhealth/synpuf1k_omop_cdm_5.2.2",
        ...     tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"],
        ...     code_mapping={},
        ... )
        >>> from pyhealth.tasks import drug_recommendation_omop_fn
        >>> omop_sample = omop_base.set_task(drug_recommendation_omop_fn)
        >>> omop_sample.samples[0]
    """
    samples = []
    for idx in range(len(patient)):
        visit: Visit = patient[idx]
        conditions = visit.get_code_list(table="condition_occurrence")
        procedures = visit.get_code_list(table="procedure_occurrence")
        drugs = visit.get_code_list(table="drug_exposure")

        # exclude: visits without condition, procedure, or drug code
        if 0 in (len(conditions), len(procedures), len(drugs)):
            continue

        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_all": drugs,
            }
        )

    # exclude: patients with less than 2 visit
    if len(samples) < 2:
        return []

    # add history
    # NOTE(review): unlike the MIMIC variants in this file, the current
    # visit's drugs are left inside drugs_all - confirm this is intended.
    condition_history = []
    procedure_history = []
    drug_history = []
    for sample in samples:
        condition_history.append(sample["conditions"])
        procedure_history.append(sample["procedures"])
        drug_history.append(sample["drugs_all"])
        sample["conditions"] = list(condition_history)
        sample["procedures"] = list(procedure_history)
        sample["drugs_all"] = list(drug_history)

    return samples
if __name__ == "__main__":
    # Manual smoke test. The active section runs the MIMIC-IV function-style
    # task; the commented sections are the equivalent runs for other datasets.

    # from pyhealth.datasets import MIMIC3Dataset
    # base_dataset = MIMIC3Dataset(
    #     root="/srv/local/data/physionet.org/files/mimiciii/1.4",
    #     tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    #     dev=True,
    #     code_mapping={"ICD9CM": "CCSCM"},
    #     refresh_cache=False,
    # )
    # sample_dataset = base_dataset.set_task(task_fn=drug_recommendation_mimic3_fn)
    # sample_dataset.stat()
    # print(sample_dataset.available_keys)
    # print(sample_dataset.samples[0])

    from pyhealth.datasets import MIMIC4Dataset

    # No dataset-level NDC -> ATC code_mapping here:
    # drug_recommendation_mimic4_fn already converts prescription NDCs to
    # ATC-3 itself, so mapping them in the dataset as well would feed ATC
    # codes into an NDC lookup.
    base_dataset = MIMIC4Dataset(
        root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
        tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        dev=True,
        refresh_cache=False,
    )
    sample_dataset = base_dataset.set_task(task_fn=drug_recommendation_mimic4_fn)
    sample_dataset.stat()
    print(sample_dataset.available_keys)
    print(sample_dataset.samples[0])

    # from pyhealth.datasets import eICUDataset
    # from pyhealth.tasks import DrugRecommendationEICU
    # base_dataset = eICUDataset(
    #     root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
    #     tables=["diagnosis", "medication", "physicalExam"],
    #     dev=True,
    #     refresh_cache=False,
    # )
    # task = DrugRecommendationEICU()
    # sample_dataset = base_dataset.set_task(task)
    # sample_dataset.stat()
    # print(sample_dataset.available_keys)

    # from pyhealth.datasets import OMOPDataset
    # base_dataset = OMOPDataset(
    #     root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2",
    #     tables=["condition_occurrence", "procedure_occurrence", "drug_exposure"],
    #     dev=True,
    #     refresh_cache=False,
    # )
    # sample_dataset = base_dataset.set_task(task_fn=drug_recommendation_omop_fn)
    # sample_dataset.stat()
    # print(sample_dataset.available_keys)