Source code for pyhealth.tasks.bmd_hs_disease_classification

from typing import Any, Dict, List
from .base_task import BaseTask


class BMDHSDiseaseClassification(BaseTask):
    """Multi-label classification task for heart valve diseases.

    This task classifies heart sound recordings into multiple disease
    categories:

    - AS (Aortic Stenosis)
    - AR (Aortic Regurgitation)
    - MR (Mitral Regurgitation)
    - MS (Mitral Stenosis)

    Each patient can have multiple diseases simultaneously (multi-label).

    The task also provides access to patient metadata including:

    - Age
    - Gender
    - Smoker status (binary)
    - Living environment (urban/rural, binary)

    And 8 heart sound recording filenames.

    Attributes:
        task_name (str): The name of the task, set to
            "BMDHSDiseaseClassification".
        input_schema (Dict[str, str]): The input schema specifying the
            required input format. Contains:

            - "recording_1" through "recording_8": "audio"
            - "age": "regression"
            - "gender": "binary"
            - "smoker": "binary"
            - "lives": "binary"
        output_schema (Dict[str, str]): The output schema specifying the
            output format. Contains:

            - "diagnosis": "multilabel"

    Examples:
        >>> from pyhealth.datasets import BMDHSDataset
        >>> from pyhealth.tasks import BMDHSDiseaseClassification
        >>> dataset = BMDHSDataset(root="/path/to/bmd_hs")
        >>> task = BMDHSDiseaseClassification()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "BMDHSDiseaseClassification"
    input_schema: Dict[str, str] = {
        # 8 heart sound recordings
        **{f"recording_{i}": "audio" for i in range(1, 9)},
        # Patient metadata
        "age": "regression",
        "gender": "binary",
        "smoker": "binary",
        "lives": "binary",
    }
    output_schema: Dict[str, str] = {
        # Multi-label disease classification
        "diagnosis": "multilabel",
    }

    def __init__(self) -> None:
        """Check that the optional ``soundfile`` dependency is importable.

        Raises:
            ImportError: If ``soundfile`` cannot be imported. The original
                import failure is attached as the cause.
        """
        # The import is only an availability probe: fail at construction
        # time with an actionable message instead of later on.
        try:
            import soundfile  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "SoundFile library is required for BMDHSDiseaseClassification task to read .wav files. "
                "Install it with: pip install soundfile"
            ) from err
        super().__init__()

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a patient's data to extract features and labels.

        Only the first event of each type ("recordings", "metadata",
        "diagnoses") is used.

        Args:
            patient: A patient object containing BMD-HS data. It must expose
                ``patient_id`` and ``get_events(event_type=...)``.

        Returns:
            List[Dict[str, Any]]: A list containing a single dictionary with:

            - "patient_id": The patient identifier
            - "recording_1" through "recording_8": Values taken from the
              recordings event (paths/filenames of the recordings)
            - "age": Patient age, converted to ``float``
            - "gender": Patient gender
            - "smoker": Smoking status
            - "lives": Living environment
            - "diagnosis": A list with the disease codes (subset of
              "AS", "AR", "MR", "MS", in that order) whose flag in the
              diagnoses event converts to a non-zero integer

        Raises:
            IndexError: Presumably when the patient has no event of one of
                the three types, since the first element of each
                ``get_events`` result is taken unconditionally (depends on
                ``get_events`` returning an empty sequence in that case).
        """
        pid = patient.patient_id

        # Extract recording filenames
        rec_event = patient.get_events(event_type="recordings")[0]
        recordings = {
            f"recording_{i}": rec_event[f"recording_{i}"] for i in range(1, 9)
        }

        # Extract metadata
        meta_event = patient.get_events(event_type="metadata")[0]
        metadata = {
            "age": float(meta_event["Age"]),
            "gender": meta_event["Gender"],
            "smoker": meta_event["Smoker"],
            "lives": meta_event["Lives"],
        }

        # Extract disease labels (multi-label): keep the codes whose flag
        # is set, i.e. converts to a non-zero int.
        diag_event = patient.get_events(event_type="diagnoses")[0]
        labels = {
            "diagnosis": [
                dis for dis in ["AS", "AR", "MR", "MS"] if int(diag_event[dis])
            ]
        }

        # Combine all features and labels
        sample = {"patient_id": pid, **recordings, **metadata, **labels}
        return [sample]