# Source code for pyhealth.tasks.benchmark_ehrshot

from typing import Any, Dict, List, Optional

import polars as pl

from .base_task import BaseTask


class BenchmarkEHRShot(BaseTask):
    """Benchmark predictive tasks using EHRShot.

    Each instance is bound to one benchmark task name. Calling the instance
    on a patient yields one sample per label event of that task; the sample's
    feature is the list of ``ehrshot/code`` values returned by
    ``patient.get_events("ehrshot", end=label.timestamp, ...)``.

    Attributes:
        tasks: Mapping from task category to the task names it contains.
        task: The task name this instance was created for.
        omop_tables: Optional list of OMOP table names used by
            :meth:`pre_filter` to restrict ``ehrshot`` events.
        task_name: ``"BenchmarkEHRShot/<task>"``.
        input_schema: ``{"feature": "sequence"}``.
        output_schema: ``{"label": <mode>}`` where the mode depends on the
            category the task belongs to.

    Examples:
        >>> from pyhealth.datasets import EHRShotDataset
        >>> from pyhealth.tasks import BenchmarkEHRShot
        >>> dataset = EHRShotDataset(
        ...     root="/path/to/ehrshot/data",
        ...     tables=["ehrshot", "splits", "guo_icu"],
        ... )
        >>> task = BenchmarkEHRShot(task="guo_icu")
        >>> samples = dataset.set_task(task)
    """

    tasks = {
        "operational_outcomes": ["guo_los", "guo_readmission", "guo_icu"],
        "lab_values": [
            "lab_thrombocytopenia",
            "lab_hyperkalemia",
            "lab_hypoglycemia",
            "lab_hyponatremia",
            "lab_anemia",
        ],
        "new_diagnoses": [
            "new_hypertension",
            "new_hyperlipidemia",
            "new_pancan",
            "new_celiac",
            "new_lupus",
            "new_acutemi",
        ],
        "chexpert": ["chexpert"],
    }

    # Label mode used in ``output_schema`` for each category in ``tasks``.
    _OUTPUT_MODES = {
        "operational_outcomes": "binary",
        "lab_values": "multiclass",
        "new_diagnoses": "binary",
        "chexpert": "multilabel",
    }

    # Number of low-order bits read from a chexpert label value.
    _NUM_CHEXPERT_LABELS = 14

    def __init__(self, task: str, omop_tables: Optional[List[str]] = None) -> None:
        """Initialize the BenchmarkEHRShot task.

        Args:
            task (str): The specific task to benchmark. Must be one of the
                names listed in ``tasks``.
            omop_tables (Optional[List[str]]): List of OMOP tables to filter
                input events. ``None`` disables filtering.

        Raises:
            ValueError: If ``task`` is not one of the known task names.
        """
        # Resolve the label mode first so an unknown task fails here, with a
        # clear message, instead of leaving ``output_schema`` unassigned.
        for category, mode in self._OUTPUT_MODES.items():
            if task in self.tasks[category]:
                output_mode = mode
                break
        else:
            valid_tasks = sorted(
                name for names in self.tasks.values() for name in names
            )
            raise ValueError(
                f"Unknown EHRShot task {task!r}; expected one of {valid_tasks}"
            )

        self.task = task
        self.omop_tables = omop_tables
        self.task_name = f"BenchmarkEHRShot/{task}"
        self.input_schema = {"feature": "sequence"}
        self.output_schema = {"label": output_mode}

    def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Restrict ``ehrshot`` events to the configured OMOP tables.

        Args:
            df (pl.LazyFrame): Event frame with at least the columns
                ``event_type`` and ``ehrshot/omop_table``.

        Returns:
            pl.LazyFrame: ``df`` unchanged when ``omop_tables`` is ``None``;
            otherwise the rows whose ``event_type`` is not ``"ehrshot"`` or
            whose ``ehrshot/omop_table`` is in ``omop_tables``.
        """
        if self.omop_tables is None:
            return df

        # Rows of other event types (e.g. splits, labels) always pass through.
        is_other_event = pl.col("event_type") != "ehrshot"
        is_selected_table = pl.col("ehrshot/omop_table").is_in(self.omop_tables)
        return df.filter(is_other_event | is_selected_table)

    @classmethod
    def _decode_chexpert_label(cls, value: Any) -> List[int]:
        """Convert a chexpert bitmask into a list of positive label indices.

        ``value`` is converted with ``int()`` and its 14 lowest bits are read
        (so values 0..16383 are distinguishable). Index 0 corresponds to the
        most significant of those bits and index 13 to the least significant.

        Args:
            value: Anything accepted by ``int()``.

        Returns:
            List[int]: Indices of the set bits, in ascending order.
        """
        bits = int(value)
        top_bit = cls._NUM_CHEXPERT_LABELS - 1
        return [
            index
            for index in range(cls._NUM_CHEXPERT_LABELS)
            if (bits >> (top_bit - index)) & 1
        ]

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build the samples for one patient.

        Args:
            patient (Any): Object exposing ``get_events(event_type, ...)``.

        Returns:
            List[Dict[str, Any]]: One dict per label event of ``self.task``
            with keys ``"feature"`` (list of ``ehrshot/code`` values),
            ``"label"`` and ``"split"``.

        Raises:
            AssertionError: If the patient does not have exactly one
                ``splits`` event.
        """
        samples = []

        split_events = patient.get_events("splits")
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # kept for backward compatibility.
        if len(split_events) != 1:
            raise AssertionError("Only one split is allowed")
        split = split_events[0].split

        for label in patient.get_events(self.task):
            # Returning a dataframe of events is much faster than a list of events
            events_df = patient.get_events(
                "ehrshot", end=label.timestamp, return_df=True
            )
            codes = events_df["ehrshot/code"].to_list()

            label_value = label.value
            if self.task == "chexpert":
                label_value = self._decode_chexpert_label(label_value)

            samples.append({"feature": codes, "label": label_value, "split": split})

        return samples