Source code for pyhealth.tasks.covid19_cxr_classification
from typing import Any, Dict, List
from .base_task import BaseTask
[docs]class COVID19CXRClassification(BaseTask):
"""A task for classifying chest disease from chest X-ray images.
This task classifies chest X-ray images into different disease categories.
It expects a single chest X-ray image per patient and returns the
corresponding disease label.
Attributes:
task_name (str): The name of the task, set to
"COVID19CXRClassification".
input_schema (Dict[str, str]): The input schema specifying the required
input format. Contains a single key "image" with value "image".
output_schema (Dict[str, str]): The output schema specifying the output
format. Contains a single key "disease" with value "multiclass".
Examples:
>>> from pyhealth.datasets import COVID19CXRDataset
>>> from pyhealth.tasks import COVID19CXRClassification
>>> dataset = COVID19CXRDataset(root="/path/to/covid19_cxr")
>>> task = COVID19CXRClassification()
>>> samples = dataset.set_task(task)
"""
task_name: str = "COVID19CXRClassification"
input_schema: Dict = {"image": ("image", {"mode": "RGB"})}
output_schema: Dict[str, str] = {"disease": "multiclass"}
def __call__(self, patient: Any) -> List[Dict[str, Any]]:
"""Process a patient's chest X-ray data to classify COVID-19 status.
Args:
patient: A patient object containing chest X-ray data.
Returns:
List[Dict[str, Any]]: A list containing a single dictionary with:
- "image": Path to the chest X-ray image
- "disease": The disease classification label
Raises:
AssertionError: If the patient has more than one chest X-ray event.
"""
event = patient.get_events(event_type="covid19_cxr")
# There should be only one event
assert len(event) == 1
event = event[0]
image = event.path
disease = event.label
return [{"image": image, "disease": disease}]