"""Generic FHIR ingestion using flattened resource tables.
Architecture
------------
1. Stream NDJSON/NDJSON.GZ FHIR resources from disk.
2. Normalize each resource type into a 2D table via a declarative
:class:`~pyhealth.datasets.fhir.utils.ResourceSpec` registry
(``self.resource_specs``) — see :mod:`~pyhealth.datasets.fhir.utils`.
3. Feed those tables through the standard YAML-driven
:class:`~pyhealth.datasets.BaseDataset` pipeline so downstream task
processing operates on :class:`~pyhealth.data.Patient` and
``global_event_df`` rows.
``FHIRDataset`` is generic: it owns the streaming/cache/validation machinery but
no specific resource specs or config. Use it directly by passing
``resource_specs=`` + ``config_path=``, or subclass it for a concrete source
(e.g. :class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR`) that bakes those in as
class attributes.
Authors:
John Wu and Evan Febrianto
"""
from __future__ import annotations
import functools
import hashlib
import logging
import operator
import shutil
import uuid
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence
import dask.dataframe as dd
import narwhals as nw
import orjson
import pandas as pd
import platformdirs
from yaml import safe_load
from ..base_dataset import BaseDataset
from .utils import (
FHIR_SCHEMA_VERSION,
SUPPORTED_OUTPUT_FORMATS,
ResourceSpec,
filter_flat_tables_by_patient_ids,
load_resource_specs_from_yaml,
sorted_patient_ids_from_flat_tables,
stream_fhir_ndjson_to_flat_tables,
table_file_name,
tables_from_specs,
)
logger = logging.getLogger(__name__)
def read_fhir_settings_yaml(path: str) -> Dict[str, Any]:
with open(path, encoding="utf-8") as stream:
data = safe_load(stream)
return data if isinstance(data, dict) else {}
def _strip_tz_to_naive_ms(part: pd.Series) -> pd.Series:
if getattr(part.dtype, "tz", None) is not None:
part = part.dt.tz_localize(None)
return part.astype("datetime64[ms]")
[docs]class FHIRDataset(BaseDataset):
"""FHIR resources flattened into per-type tables, then the standard pipeline.
Streams raw FHIR NDJSON/NDJSON.GZ exports into flattened tables (one per
configured resource type) and pipelines them through
:class:`~pyhealth.datasets.BaseDataset` for downstream task processing
(global event dataframe, patient iteration, task sampling).
The entire ingest is driven by a single YAML config with three top-level
sections — ``glob_patterns:`` (which NDJSON files to open),
``resource_specs:`` (how to project each FHIR resource type into a flat
row), and ``tables:`` (how those rows are exposed as events downstream).
See ``pyhealth/datasets/fhir/configs/mimic4fhir.yaml`` for a complete
worked example and the FHIRDataset rst page for a section-by-section guide.
Pass ``config_path=...`` directly, or subclass and set
``DEFAULT_CONFIG_PATH`` to bundle a default (see
:class:`~pyhealth.datasets.fhir.mimic4.MIMIC4FHIR`).
Args:
root: Path to the NDJSON/NDJSON.GZ export directory.
config_path: Path to the FHIR ingest YAML. Defaults to the class
attribute ``DEFAULT_CONFIG_PATH``. The YAML must contain a
``resource_specs:`` block; any ``glob_patterns:`` and ``tables:``
blocks are also read from here.
glob_pattern: Single glob for NDJSON files; overrides the YAML's
``glob_patterns``. Mutually exclusive with *glob_patterns*.
glob_patterns: Multiple glob patterns; overrides the YAML's
``glob_patterns``. Mutually exclusive with *glob_pattern*.
output_format: Flat-table format, one of ``parquet`` (default),
``csv``, ``tsv``. Defaults to the class attribute
``DEFAULT_OUTPUT_FORMAT``.
max_patients: Limit ingest to the first *N* unique patient IDs.
ingest_num_shards: Ignored; retained for API compatibility.
cache_dir: Cache directory root (UUID subdir appended per config).
num_workers: Worker processes for task sampling.
dev: Development mode; limits to 1000 patients if *max_patients* is
``None``.
Examples:
>>> # ad-hoc, no subclass
>>> ds = FHIRDataset(
... root="/data/fhir",
... config_path="my_fhir.yaml",
... )
>>> # or a preconfigured source subclass
>>> from pyhealth.datasets import MIMIC4FHIR
>>> ds = MIMIC4FHIR(root="/data/mimic-iv-fhir", max_patients=500)
"""
#: Default ingest YAML path; set by source subclasses to bundle a config.
DEFAULT_CONFIG_PATH: Optional[str] = None
#: Default flat-table output format.
DEFAULT_OUTPUT_FORMAT: str = "parquet"
#: Dataset name used for cache identity / logging.
DATASET_NAME: str = "fhir"
def __init__(
self,
root: str,
config_path: Optional[str] = None,
glob_pattern: Optional[str] = None,
glob_patterns: Optional[Sequence[str]] = None,
output_format: Optional[str] = None,
max_patients: Optional[int] = None,
ingest_num_shards: Optional[int] = None,
cache_dir: Optional[str | Path] = None,
num_workers: int = 1,
dev: bool = False,
) -> None:
del ingest_num_shards
resolved_config = config_path or type(self).DEFAULT_CONFIG_PATH
if resolved_config is None:
raise ValueError(
"FHIRDataset requires config_path: pass config_path=... or use a "
"subclass that defines DEFAULT_CONFIG_PATH."
)
self._fhir_config_path = str(Path(resolved_config).resolve())
self._fhir_settings = read_fhir_settings_yaml(self._fhir_config_path)
# Section 2 of the YAML: how each FHIR resource type projects into a row.
self.resource_specs: Mapping[str, ResourceSpec] = (
load_resource_specs_from_yaml(self._fhir_settings)
)
# Cross-validate: every table the specs declare must have a downstream
# `tables:` block (Section 3). Catches typos at startup.
spec_tables = set(tables_from_specs(self.resource_specs))
declared_tables = set((self._fhir_settings.get("tables") or {}).keys())
missing = spec_tables - declared_tables
if missing:
raise ValueError(
f"config {self._fhir_config_path}: resource_specs references "
f"table(s) {sorted(missing)} not declared in the 'tables:' "
f"block. Add a matching tables.<name> entry (patient_id, "
f"timestamp, attributes) for each."
)
self.output_format = output_format or type(self).DEFAULT_OUTPUT_FORMAT
if self.output_format not in SUPPORTED_OUTPUT_FORMATS:
raise ValueError(
f"Unsupported output_format {self.output_format!r}; "
f"expected one of {SUPPORTED_OUTPUT_FORMATS}."
)
if glob_pattern is not None and glob_patterns is not None:
raise ValueError("Pass at most one of glob_pattern and glob_patterns.")
if glob_patterns is not None:
self.glob_patterns: List[str] = list(glob_patterns)
elif glob_pattern is not None:
self.glob_patterns = [glob_pattern]
else:
raw_list = self._fhir_settings.get("glob_patterns")
if raw_list:
if not isinstance(raw_list, list):
raise TypeError("config glob_patterns must be a list of strings.")
self.glob_patterns = [str(x) for x in raw_list]
elif self._fhir_settings.get("glob_pattern") is not None:
self.glob_patterns = [str(self._fhir_settings["glob_pattern"])]
else:
self.glob_patterns = ["**/*.ndjson.gz"]
self.glob_pattern = (
self.glob_patterns[0]
if len(self.glob_patterns) == 1
else "; ".join(self.glob_patterns)
)
self.max_patients = 1000 if dev and max_patients is None else max_patients
self._fhir_tables = tables_from_specs(self.resource_specs)
resolved_root = str(Path(root).expanduser().resolve())
super().__init__(
root=resolved_root,
tables=list(self._fhir_tables),
dataset_name=type(self).DATASET_NAME,
config_path=self._fhir_config_path,
cache_dir=cache_dir,
num_workers=num_workers,
dev=dev,
)
# ------------------------------------------------------------------
# Cache identity
# ------------------------------------------------------------------
def _init_cache_dir(self, cache_dir: str | Path | None) -> Path:
try:
yaml_digest = hashlib.sha256(
Path(self._fhir_config_path).read_bytes()
).hexdigest()[:16]
except OSError:
yaml_digest = "missing"
identity = orjson.dumps(
{
"root": self.root,
"tables": sorted(self.tables),
"dataset_name": self.dataset_name,
"dev": self.dev,
"glob_patterns": self.glob_patterns,
"max_patients": self.max_patients,
"output_format": self.output_format,
"fhir_schema_version": FHIR_SCHEMA_VERSION,
"fhir_yaml_digest16": yaml_digest,
},
option=orjson.OPT_SORT_KEYS,
).decode("utf-8")
cache_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, identity))
out = (
Path(platformdirs.user_cache_dir(appname="pyhealth")) / cache_id
if cache_dir is None
else Path(cache_dir) / cache_id
)
out.mkdir(parents=True, exist_ok=True)
logger.info(f"Cache dir: {out}")
return out
# ------------------------------------------------------------------
# NDJSON -> flat tables ingest
# ------------------------------------------------------------------
@property
def prepared_tables_dir(self) -> Path:
return self.cache_dir / "flattened_tables"
def _ensure_prepared_tables(self) -> None:
root = Path(self.root)
if not root.is_dir():
raise FileNotFoundError(f"FHIR root not found: {root}")
expected = [
self.prepared_tables_dir / table_file_name(t, self.output_format)
for t in self._fhir_tables
]
if all(p.is_file() for p in expected):
return
if self.prepared_tables_dir.exists():
shutil.rmtree(self.prepared_tables_dir)
try:
staging_root = self.create_tmpdir()
staging = staging_root / "flattened_fhir_tables"
staging.mkdir(parents=True, exist_ok=True)
stream_fhir_ndjson_to_flat_tables(
root,
self.glob_patterns,
staging,
self.resource_specs,
self.output_format,
)
if self.max_patients is None:
shutil.move(str(staging), str(self.prepared_tables_dir))
return
filtered_root = self.create_tmpdir()
filtered = filtered_root / "filtered"
pids = sorted_patient_ids_from_flat_tables(
staging, self._fhir_tables, self.output_format
)
filter_flat_tables_by_patient_ids(
staging,
filtered,
pids[: self.max_patients],
self._fhir_tables,
self.output_format,
)
shutil.move(str(filtered), str(self.prepared_tables_dir))
finally:
self.clean_tmpdir()
def _event_transform(self, output_dir: Path) -> None:
self._ensure_prepared_tables()
super()._event_transform(output_dir)
# ------------------------------------------------------------------
# Table loading (flat tables instead of source CSVs)
# ------------------------------------------------------------------
def _read_flat_table(self, path: Path) -> dd.DataFrame:
if self.output_format == "parquet":
return dd.read_parquet(
str(path), split_row_groups=True, blocksize="64MB"
).replace("", pd.NA)
sep = "\t" if self.output_format == "tsv" else ","
return dd.read_csv(
str(path), sep=sep, dtype=str, blocksize="64MB"
).replace("", pd.NA)
[docs] def load_table(self, table_name: str) -> dd.DataFrame:
"""Load one flattened table into the standard event schema.
Deviations from ``BaseDataset.load_table`` (CSV via ``_scan_csv_tsv_gz``):
* Reads pre-built flat tables (parquet/csv/tsv) under
``prepared_tables_dir``.
* Timestamp parsing uses ``errors="coerce"`` + ``utc=True`` (FHIR ISO
strings include timezone suffix or partial dates).
* Strips tz-aware timestamps to naive UTC for Dask compat.
* Drops rows with null ``patient_id`` before returning.
"""
assert self.config is not None
if table_name not in self.config.tables:
raise ValueError(f"Table {table_name} not found in config")
table_cfg = self.config.tables[table_name]
path = self.prepared_tables_dir / table_file_name(
table_name, self.output_format
)
if not path.exists():
raise FileNotFoundError(f"Flattened table not found: {path}")
logger.info(f"Scanning FHIR flattened table: {table_name} from {path}")
df: dd.DataFrame = self._read_flat_table(path)
df = df.rename(columns=str.lower)
preprocess_func = getattr(self, f"preprocess_{table_name}", None)
if preprocess_func is not None:
logger.info(
f"Preprocessing FHIR table: {table_name} "
f"with {preprocess_func.__name__}"
)
df = preprocess_func(nw.from_native(df)).to_native() # type: ignore[union-attr]
for join_cfg in table_cfg.join:
join_path = self.prepared_tables_dir / Path(join_cfg.file_path).name
if not join_path.exists():
raise FileNotFoundError(f"FHIR join table not found: {join_path}")
logger.info(f"Joining FHIR table {table_name} with {join_path}")
join_df: dd.DataFrame = self._read_flat_table(join_path)
join_df = join_df.rename(columns=str.lower)
join_key = join_cfg.on.lower()
cols = [c.lower() for c in join_cfg.columns]
df = df.merge(join_df[[join_key] + cols], on=join_key, how=join_cfg.how)
ts_col = table_cfg.timestamp
if ts_col:
ts = (
functools.reduce(
operator.add,
(df[c].astype("string") for c in ts_col),
)
if isinstance(ts_col, list)
else df[ts_col].astype("string")
)
ts = dd.to_datetime(
ts, format=table_cfg.timestamp_format, errors="coerce", utc=True
)
df = df.assign(timestamp=ts.map_partitions(_strip_tz_to_naive_ms))
else:
df = df.assign(timestamp=pd.NaT)
if table_cfg.patient_id:
df = df.assign(patient_id=df[table_cfg.patient_id].astype("string"))
else:
df = df.reset_index(drop=True)
df = df.assign(patient_id=df.index.astype("string"))
df = df.dropna(subset=["patient_id"])
df = df.assign(event_type=table_name)
rename_attr = {
attr.lower(): f"{table_name}/{attr}" for attr in table_cfg.attributes
}
df = df.rename(columns=rename_attr)
return df[
["patient_id", "event_type", "timestamp"]
+ [rename_attr[a.lower()] for a in table_cfg.attributes]
]
# ------------------------------------------------------------------
# Patient IDs (deterministic sorted order)
# ------------------------------------------------------------------
@property
def unique_patient_ids(self) -> List[str]:
if self._unique_patient_ids is None:
self._unique_patient_ids = (
self.global_event_df.select("patient_id")
.unique()
.sort("patient_id")
.collect(engine="streaming")
.to_series()
.to_list()
)
logger.info(f"Found {len(self._unique_patient_ids)} unique patient IDs")
return self._unique_patient_ids