Tasks#

We support various real-world healthcare predictive tasks defined by function calls. The following example tasks are collected from top AI/Medical venues, such as:

  1. Drug Recommendation [Yang et al. IJCAI 2021a, Yang et al. IJCAI 2021b, Shang et al. AAAI 2020]

  2. Readmission Prediction [Choi et al. AAAI 2021]

  3. Mortality Prediction [Choi et al. AAAI 2021]

  4. Length of Stay Prediction

  5. Sleep Staging [Yang et al. ArXix 2021]

Getting Started#

New to PyHealth tasks? Start here:

These tutorials demonstrate how to load datasets, apply tasks, train models, and evaluate results.

Understanding Tasks and Processors#

Tasks define what data to extract (via input_schema and output_schema), while processors define how to transform that data into tensors for model training.

After you define a task:

  1. Task execution: The task function extracts relevant features from patient records and generates samples

  2. Processor application: Processors automatically transform these samples into model-ready tensors based on the schemas

Example workflow:

# 1. Define a task with input/output schemas
task = MortalityPredictionMIMIC4()
# input_schema = {"conditions": "sequence", "procedures": "sequence"}
# output_schema = {"mortality": "binary"}

# 2. Apply task to dataset
sample_dataset = base_dataset.set_task(task)

# 3. Processors automatically transform samples:
#    - "sequence" -> SequenceProcessor (converts codes to indices)
#    - "binary" -> BinaryLabelProcessor (converts labels to tensors)

# 4. Get model-ready tensors
sample = sample_dataset[0]
# sample["conditions"] is now a tensor of token indices
# sample["mortality"] is now a binary tensor [0] or [1]

Learn more about processors:

  • See the Processors documentation for details on all available processors

  • Learn about string keys ("sequence", "binary", etc.) that map to specific processors

  • Discover how to customize processor behavior with kwargs tuples

  • Understand processor types for different data modalities (text, images, signals, etc.)

Writing a Custom Task#

When a built-in task doesn’t match your cohort or prediction target, you can define your own by subclassing BaseTask. The class needs three things: a name, input and output schemas, and a __call__ method that processes one patient at a time.

from pyhealth.tasks import BaseTask
from pyhealth.data import Patient
from typing import List, Dict, Any

class MyMortalityTask(BaseTask):
    task_name: str = "MyMortalityTask"

    input_schema: Dict[str, str] = {
        "conditions": "sequence",   # maps to SequenceProcessor
        "procedures": "sequence",
    }
    output_schema: Dict[str, str] = {
        "label": "binary"           # maps to BinaryLabelProcessor
    }

    def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
        samples = []
        for adm in patient.get_events("admissions"):
            label = 1 if adm.hospital_expire_flag == "1" else 0

            # Fetch historical diagnoses up to this admission
            conditions = patient.get_events("diagnoses_icd", end=adm.timestamp)
            cond_codes = [e.icd_code for e in conditions]

            if not cond_codes:
                continue

            samples.append({
                "conditions": [cond_codes],  # wrapped in a list for the sequence processor
                "procedures": [[]],
                "label": label,
            })
        return samples

The __call__ method receives one Patient and returns a list of sample dictionaries. Each dictionary’s keys should match the schemas you declared. Returning an empty list is fine — PyHealth simply skips that patient. Note that event attribute names are always lowercase (e.g. e.icd_code rather than e.ICD_CODE) because PyHealth lowercases all column names at ingest time. Timestamps are accessed through event.timestamp rather than the original column name like charttime, since PyHealth normalises them into a single property.

Processor String Keys#

The string values in your schemas map to specific processor classes. Here is a quick reference:

String key

Processor

Typical use

"sequence"

SequenceProcessor

Diagnosis codes, procedure codes, drug names

"nested_sequence"

NestedSequenceProcessor

Cumulative visit history (drug recommendation, readmission)

"tensor"

TensorProcessor

Aggregated numeric values (e.g. last lab value per item)

"timeseries"

TimeseriesProcessor

Irregular time-series measurements

"multi_hot"

MultiHotProcessor

Demographics, comorbidity flags

"text"

TextProcessor

Clinical notes

"binary"

BinaryLabelProcessor

Binary classification label (0 / 1)

"multiclass"

MultiClassLabelProcessor

Multi-class label

"multilabel"

MultiLabelProcessor

Multi-label classification

"regression"

RegressionLabelProcessor

Continuous regression target

How set_task() Works#

Calling dataset.set_task(task) iterates over every patient in the dataset, runs your __call__ method on each one, fits all the processors on the collected samples, then serialises everything to disk as LitData .ld files. The result is a SampleDataset that supports len() and index access, ready for a DataLoader.

sample_ds = dataset.set_task(MyMortalityTask(), num_workers=4)
len(sample_ds)   # total ML samples across all patients
sample_ds[0]     # a single sample dict with tensor values

If you re-run set_task() with the same task and processor configuration, PyHealth detects the existing cache and skips reprocessing. During development it is useful to set dev=True on the dataset, which limits processing to 1 000 patients so iterations are fast.

Note

A note on multiprocessing. set_task() can spawn worker processes when num_workers > 1. On macOS and Linux this requires the standard Python multiprocessing guard around your top-level script:

if __name__ == '__main__':
    sample_ds = dataset.set_task(task, num_workers=4)

Without this guard, Python may try to re-import and re-run the script in each worker process, leading to infinite recursion. This is a general Python multiprocessing requirement, not specific to PyHealth.

Available Tasks#