Source code for pyhealth.tasks.chestxray14_multilabel_classification

"""
PyHealth task for multilabel classification using the ChestX-ray14 dataset.

Dataset link:
    https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345

Dataset paper: (please cite if you use this dataset)
    Xiaosong Wang, Yifan Peng, Le Lu, et al. "ChestX-ray8: Hospital-scale Chest
    X-ray Database and Benchmarks on Weakly-Supervised Classification and
    Localization of Common Thorax Diseases." 2017 IEEE Conference on Computer
    Vision and Pattern Recognition (CVPR), pp. 3462-3471.

Dataset paper link:
    https://arxiv.org/abs/1705.02315

Author:
    Eric Schrock (ejs9@illinois.edu)
"""

import logging
from typing import Dict, List

from pyhealth.data import Event, Patient
from pyhealth.tasks import BaseTask

logger = logging.getLogger(__name__)


class ChestXray14MultilabelClassification(BaseTask):
    """
    Multilabel classification task covering all fourteen diseases in the
    ChestXray14 dataset.

    Calling an instance on a patient produces one sample per 'chestxray14'
    event, pairing the event's image path with the names of the diseases
    flagged on that event.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The schema for the task input.
        output_schema (Dict[str, str]): The schema for the task output.

    Examples:
        >>> from pyhealth.datasets import ChestXray14Dataset
        >>> from pyhealth.tasks import ChestXray14MultilabelClassification
        >>> dataset = ChestXray14Dataset(root="/path/to/chestxray14")
        >>> task = ChestXray14MultilabelClassification()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "ChestXray14MultilabelClassification"
    input_schema: Dict[str, str] = {"image": "image"}
    output_schema: Dict[str, str] = {"labels": "multilabel"}

    def __call__(self, patient: Patient) -> List[Dict]:
        """
        Builds the multilabel classification samples for one patient.

        Args:
            patient (Patient): A patient object whose 'chestxray14' events
                are converted into samples.

        Returns:
            List[Dict]: One dictionary per 'chestxray14' event, in the order
                returned by ``patient.get_events``, with:
                - 'image': the value stored under the event's "path" key.
                - 'labels': the entries of ChestXray14Dataset.classes whose
                  field on the event is non-zero after conversion with
                  ``int``, kept in the order of ChestXray14Dataset.classes.
        """
        xray_events: List[Event] = patient.get_events(event_type="chestxray14")

        # Imported at call time rather than module level to avoid a circular
        # import between pyhealth.tasks and pyhealth.datasets.
        from pyhealth.datasets import ChestXray14Dataset

        return [
            {
                "image": xray["path"],
                "labels": [
                    name
                    for name in ChestXray14Dataset.classes
                    if int(xray[name])
                ],
            }
            for xray in xray_events
        ]