Source code for pyhealth.tasks.variant_classification

"""Variant classification tasks for PyHealth.

This module provides tasks for classifying genetic variants based on
their clinical significance using ClinVar and COSMIC datasets.
"""

from typing import Any, Dict, List, Optional

from .base_task import BaseTask


class VariantClassificationClinVar(BaseTask):
    """Task for classifying variant clinical significance using ClinVar data.

    This task predicts the clinical significance of genetic variants
    (e.g., Pathogenic, Benign, Uncertain significance) based on variant
    features from the ClinVar database.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The input schema specifying required
            inputs.
        output_schema (Dict[str, str]): The output schema specifying outputs.
        CLINICAL_SIGNIFICANCE_CATEGORIES (Dict[str, str]): Mapping of raw
            values to standardized clinical significance labels.

    Note:
        Variants with conflicting interpretations or non-standard clinical
        significance values are excluded from the output samples.

    Examples:
        >>> from pyhealth.datasets import ClinVarDataset
        >>> from pyhealth.tasks import VariantClassificationClinVar
        >>> dataset = ClinVarDataset(root="/path/to/clinvar")
        >>> task = VariantClassificationClinVar()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "VariantClassificationClinVar"
    input_schema: Dict[str, str] = {
        "gene_symbol": "text",
        "variant_type": "text",
        "chromosome": "text",
    }
    output_schema: Dict[str, str] = {"clinical_significance": "multiclass"}

    # Standard clinical significance categories (ACMG/AMP guidelines)
    CLINICAL_SIGNIFICANCE_CATEGORIES: Dict[str, str] = {
        "pathogenic": "Pathogenic",
        "likely pathogenic": "Likely pathogenic",
        "benign": "Benign",
        "likely benign": "Likely benign",
        "uncertain significance": "Uncertain significance",
        "vus": "Uncertain significance",
    }

    def _normalize_clinical_significance(
        self, raw_value: Optional[str]
    ) -> Optional[str]:
        """Normalize clinical significance to standard ACMG/AMP categories.

        Matching is case-insensitive and substring based, so compound values
        such as "Pathogenic, risk factor" still resolve to a category.

        Args:
            raw_value: Raw clinical significance string from ClinVar.

        Returns:
            Normalized category string, or None if the value is missing,
            describes a conflicting interpretation, or does not map to a
            standard category.
        """
        if raw_value is None or raw_value == "" or str(raw_value) == "nan":
            return None

        value_lower = str(raw_value).lower()

        # Conflicting records must be rejected before the substring checks:
        # "conflicting interpretations of pathogenicity" contains the
        # substring "pathogenic" (and no "likely"), so it would otherwise be
        # mislabelled as "Pathogenic".
        if "conflicting" in value_lower:
            return None

        # Check for exact or partial matches; the "likely ..." variants are
        # tested first because they contain the plain category as a substring.
        if "likely pathogenic" in value_lower:
            return self.CLINICAL_SIGNIFICANCE_CATEGORIES["likely pathogenic"]
        elif "pathogenic" in value_lower and "likely" not in value_lower:
            return self.CLINICAL_SIGNIFICANCE_CATEGORIES["pathogenic"]
        elif "likely benign" in value_lower:
            return self.CLINICAL_SIGNIFICANCE_CATEGORIES["likely benign"]
        elif "benign" in value_lower and "likely" not in value_lower:
            return self.CLINICAL_SIGNIFICANCE_CATEGORIES["benign"]
        elif "uncertain" in value_lower or "vus" in value_lower:
            return self.CLINICAL_SIGNIFICANCE_CATEGORIES["uncertain significance"]

        # Unrecognized categories
        return None

    def _safe_str(self, value: Any, default: str = "") -> str:
        """Safely convert value to string, handling None and NaN.

        Args:
            value: Value to convert.
            default: Value returned when ``value`` is None or its string
                form is "nan".

        Returns:
            String representation of value or default.
        """
        if value is None or str(value) == "nan":
            return default
        return str(value)

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Process a variant record to extract features and label.

        Only the first "variants" event of the patient is used.

        Args:
            patient: A patient object containing variant data. In ClinVar,
                each "patient" represents a single variant record.

        Returns:
            List[Dict[str, Any]]: A list containing a single dictionary with
            variant features and clinical significance label. Returns an
            empty list if the variant has no events, invalid clinical
            significance, or conflicting interpretations.

        Note:
            Returns empty list for variants with:
            - No variant events
            - Missing or empty clinical significance
            - Conflicting interpretations
            - Non-standard clinical significance values
        """
        events = patient.get_events(event_type="variants")

        if len(events) == 0:
            return []

        event = events[0]

        # Normalize clinical significance; None means the record is unusable.
        raw_clinical_sig = getattr(event, "clinical_significance", None)
        label = self._normalize_clinical_significance(raw_clinical_sig)

        if label is None:
            return []

        return [
            {
                "patient_id": patient.patient_id,
                "gene_symbol": self._safe_str(getattr(event, "gene_symbol", "")),
                "variant_type": self._safe_str(getattr(event, "variant_type", "")),
                "chromosome": self._safe_str(getattr(event, "chromosome", "")),
                "clinical_significance": label,
            }
        ]
class MutationPathogenicityPrediction(BaseTask):
    """Binary pathogenicity prediction task for COSMIC somatic mutations.

    Each mutation is labelled as pathogenic or neutral according to its
    FATHMM prediction, and is described by a handful of text features taken
    from the COSMIC record.

    Attributes:
        task_name (str): The name of the task.
        input_schema (Dict[str, str]): The input schema specifying required
            inputs.
        output_schema (Dict[str, str]): The output schema specifying outputs.
        VALID_FATHMM_PREDICTIONS (tuple): Valid FATHMM prediction values.

    Note:
        Mutations whose FATHMM prediction is not PATHOGENIC or NEUTRAL are
        left out of the output samples.

    Examples:
        >>> from pyhealth.datasets import COSMICDataset
        >>> from pyhealth.tasks import MutationPathogenicityPrediction
        >>> dataset = COSMICDataset(root="/path/to/cosmic")
        >>> task = MutationPathogenicityPrediction()
        >>> samples = dataset.set_task(task)
    """

    task_name: str = "MutationPathogenicityPrediction"
    input_schema: Dict[str, str] = {
        "gene_name": "text",
        "mutation_description": "text",
        "primary_site": "text",
    }
    output_schema: Dict[str, str] = {"fathmm_prediction": "binary"}

    # Valid FATHMM prediction categories
    VALID_FATHMM_PREDICTIONS: tuple = ("PATHOGENIC", "NEUTRAL")

    def _safe_str(self, value: Any, default: str = "") -> str:
        """Return ``str(value)``, or ``default`` for None / "nan" values.

        Args:
            value: Value to convert.
            default: Fallback returned when ``value`` is None or its string
                form is "nan".

        Returns:
            The string form of ``value``, or ``default``.
        """
        if value is None:
            return default
        text = str(value)
        return default if text == "nan" else text

    def _build_sample(
        self, patient: Any, event: Any, prediction: Any
    ) -> Dict[str, Any]:
        """Assemble the sample dictionary for one accepted mutation event."""
        sample: Dict[str, Any] = {"patient_id": patient.patient_id}
        # Text features are inserted in the same order as the schema above.
        for field in ("gene_name", "mutation_description", "primary_site"):
            sample[field] = self._safe_str(getattr(event, field, ""))
        sample["fathmm_prediction"] = int(prediction == "PATHOGENIC")
        return sample

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Turn a patient's mutation events into labelled samples.

        Args:
            patient: A patient object containing mutation data. In COSMIC,
                each "patient" represents a sample with mutations.

        Returns:
            List[Dict[str, Any]]: One dictionary per mutation with a valid
            FATHMM prediction, holding the mutation features and a binary
            label (1 for "PATHOGENIC", 0 for "NEUTRAL"). The list is empty
            when no mutation qualifies.

        Note:
            Mutations with a missing prediction, or any value other than
            "PATHOGENIC" / "NEUTRAL", are skipped.
        """
        accepted = self.VALID_FATHMM_PREDICTIONS
        collected: List[Dict[str, Any]] = []

        for mutation in patient.get_events(event_type="mutations"):
            prediction = getattr(mutation, "fathmm_prediction", None)
            if prediction in accepted:
                collected.append(self._build_sample(patient, mutation, prediction))

        return collected