Source code for pyhealth.tasks.chestxray14_binary_classification

"""
PyHealth task for binary classification using the ChestX-ray14 dataset.

Dataset link:
    https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345

Dataset paper: (please cite if you use this dataset)
    Xiaosong Wang, Yifan Peng, Le Lu, et al. "ChestX-ray8: Hospital-scale Chest
    X-ray Database and Benchmarks on Weakly-Supervised Classification and
    Localization of Common Thorax Diseases." 2017 IEEE Conference on Computer
    Vision and Pattern Recognition (CVPR), pp. 3462-3471.

Dataset paper link:
    https://arxiv.org/abs/1705.02315

Author:
    Eric Schrock (ejs9@illinois.edu)
"""

import logging
from typing import Dict, List

from pyhealth.data import Event, Patient
from pyhealth.tasks import BaseTask

logger = logging.getLogger(__name__)


[docs]class ChestXray14BinaryClassification(BaseTask): """ A PyHealth task class for binary classification of a specific disease in the ChestXray14 dataset. Attributes: task_name (str): The name of the task. input_schema (Dict[str, str]): The schema for the task input. output_schema (Dict[str, str]): The schema for the task output. disease (str): The disease label to classify. Examples: >>> from pyhealth.datasets import ChestXray14Dataset >>> from pyhealth.tasks import ChestXray14BinaryClassification >>> dataset = ChestXray14Dataset(root="/path/to/chestxray14") >>> task = ChestXray14BinaryClassification(disease="pneumonia") >>> samples = dataset.set_task(task) """ task_name: str = "ChestXray14BinaryClassification" input_schema: Dict[str, str] = {"image": "image"} output_schema: Dict[str, str] = {"label": "binary"} def __init__(self, disease: str) -> None: """ Initializes the ChestXray14BinaryClassification task with a specified disease. Args: disease (str): The disease to classify in the binary task. Must be one of the predefined class labels in ChestXray14Dataset. Raises: ValueError: If the specified disease is not a valid class in the dataset. """ from pyhealth.datasets import ChestXray14Dataset # Avoid circular import if disease not in ChestXray14Dataset.classes: msg = f"Invalid disease: '{disease}'! Must be one of {ChestXray14Dataset.classes}." logger.error(msg) raise ValueError(msg) self.disease = disease def __call__(self, patient: Patient) -> List[Dict]: """ Generates binary classification data samples for a single patient. Args: patient (Patient): A patient object containing at least one 'chestxray14' event. Returns: List[Dict]: A list containing a dictionary for each patient visit with: - 'image': path to the chest X-ray image. - 'label': binary label for the specified disease. """ events: List[Event] = patient.get_events(event_type="chestxray14") samples = [] for event in events: samples.append({"image": event["path"], "label": int(event[self.disease])}) return samples