Source code for pyhealth.tasks.medical_transcriptions_classification
from typing import Any, Dict, List
from ..data import Patient
from .base_task import BaseTask
class MedicalTranscriptionsClassification(BaseTask):
    """Task for classifying medical transcriptions into medical specialties.

    This task takes medical transcription text as input and predicts the
    corresponding medical specialty. It processes patient records containing
    ``mtsamples`` events and extracts the transcription and medical specialty
    from each record.

    Attributes:
        task_name (str): Name of the task.
        input_schema (Dict[str, str]): Schema defining input features; maps
            ``"transcription"`` to ``"text"``.
        output_schema (Dict[str, str]): Schema defining output features; maps
            ``"medical_specialty"`` to ``"multiclass"``.

    Examples:
        >>> from pyhealth.datasets import MedicalTranscriptionsDataset
        >>> from pyhealth.tasks import MedicalTranscriptionsClassification
        >>> dataset = MedicalTranscriptionsDataset(
        ...     root="/path/to/medical_transcriptions",
        ... )
        >>> task = MedicalTranscriptionsClassification()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MedicalTranscriptionsClassification"
    input_schema: Dict[str, str] = {"transcription": "text"}
    output_schema: Dict[str, str] = {"medical_specialty": "multiclass"}

    def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
        """Process a patient record to extract medical transcription samples.

        Args:
            patient (Patient): Patient record containing medical
                transcription events.

        Returns:
            List[Dict[str, Any]]: A one-element list holding a sample with the
                keys ``"id"``, ``"transcription"`` and ``"medical_specialty"``
                when both fields of the event are strings; an empty list
                otherwise.

        Raises:
            AssertionError: If the patient does not have exactly one
                ``mtsamples`` event.
        """
        events = patient.get_events(event_type="mtsamples")

        # There should be only one event per patient. This is checked with an
        # explicit raise rather than a bare ``assert`` so the validation is not
        # stripped when Python runs with ``-O``; AssertionError is kept so the
        # exception type seen by existing callers does not change.
        if len(events) != 1:
            raise AssertionError(
                f"Expected exactly one 'mtsamples' event for patient "
                f"{patient.patient_id!r}, found {len(events)}"
            )
        event = events[0]

        transcription = event.transcription
        specialty = event.medical_specialty

        # Records whose text or label is not a string yield no sample
        # (presumably missing values in the raw data -- confirm against the
        # dataset loader).
        if not (isinstance(transcription, str) and isinstance(specialty, str)):
            return []

        return [
            {
                "id": patient.patient_id,
                "transcription": transcription,
                "medical_specialty": specialty,
            }
        ]