Source code for pyhealth.datasets.ehrshot
import logging
from pathlib import Path
from typing import List, Optional
from .base_dataset import BaseDataset
logger = logging.getLogger(__name__)
[docs]class EHRShotDataset(BaseDataset):
"""
A dataset class for handling EHRShot data.
This class is responsible for loading and managing the EHRShot dataset.
Website: https://som-shahlab.github.io/ehrshot-website/
Attributes:
root (str): The root directory where the dataset is stored.
tables (List[str]): A list of tables to be included in the dataset.
dataset_name (Optional[str]): The name of the dataset.
config_path (Optional[str]): The path to the configuration file.
Examples:
>>> from pyhealth.datasets import EHRShotDataset
>>> # Load EHRShot dataset with benchmark tables
>>> dataset = EHRShotDataset(
... root="/path/to/ehrshot/data",
... tables=["ehrshot", "chexpert", "guo_icu", "lab_anemia"],
... )
>>> dataset.stats()
"""
def __init__(
self,
root: str,
tables: List[str],
dataset_name: Optional[str] = None,
config_path: Optional[str] = None,
**kwargs,
) -> None:
if config_path is None:
logger.info("No config path provided, using default config")
config_path = Path(__file__).parent / "configs" / "ehrshot.yaml"
super().__init__(
root=root,
tables=tables,
dataset_name=dataset_name or "ehrshot",
config_path=config_path,
**kwargs,
)
return