Source code for pyhealth.data.data

import operator
from dataclasses import dataclass, field
from datetime import datetime
from functools import reduce
from typing import Dict, List, Mapping, Optional, Union

import numpy as np
import polars as pl


[docs]@dataclass(frozen=True) class Event: """Event class representing a single clinical event. Attributes: event_type (str): Type of the clinical event (e.g., 'medication', 'diagnosis') timestamp (datetime): When the event occurred attr_dict (Mapping[str, any]): Dictionary containing event-specific attributes """ event_type: str timestamp: datetime attr_dict: Mapping[str, any] = field(default_factory=dict) def __init__(self, event_type: str, timestamp: datetime = None, **kwargs): """Initialize an Event instance. Args: event_type (str): Type of the clinical event timestamp (datetime, optional): When the event occurred. If not provided, current time will be used. **kwargs: Additional attributes to store in attr_dict """ # Create a mutable copy of kwargs to manipulate attr_dict = dict(kwargs) # Extract existing attr_dict if provided in kwargs if "attr_dict" in attr_dict: existing_attr_dict = attr_dict.pop("attr_dict") # Merge with remaining kwargs, with kwargs taking precedence attr_dict = {**existing_attr_dict, **attr_dict} # Set timestamp to current time if not provided if timestamp is None: timestamp = datetime.now() # Use object.__setattr__ since the dataclass is frozen object.__setattr__(self, "event_type", event_type) object.__setattr__(self, "timestamp", timestamp) object.__setattr__(self, "attr_dict", attr_dict)
[docs] @classmethod def from_dict(cls, d: Dict[str, any]) -> "Event": """Create an Event instance from a dictionary. Args: d (Dict[str, any]): Dictionary containing event data. Returns: Event: An instance of the Event class. """ timestamp: datetime = d["timestamp"] event_type: str = d["event_type"] attr_dict: Dict[str, any] = { k.split("/", 1)[1]: v for k, v in d.items() if k.split("/")[0] == event_type } return cls(event_type=event_type, timestamp=timestamp, attr_dict=attr_dict)
def __getitem__(self, key: str) -> any: """Get an attribute by key. Args: key (str): The key of the attribute to retrieve. Returns: any: The value of the attribute. """ if key == "timestamp": return self.timestamp elif key == "event_type": return self.event_type else: return self.attr_dict[key] def __contains__(self, key: str) -> bool: """Check if an attribute exists by key. Args: key (str): The key of the attribute to check. Returns: bool: True if the attribute exists, False otherwise. """ if key == "timestamp" or key == "event_type": return True return key in self.attr_dict def __getattr__(self, key: str) -> any: """Get an attribute using dot notation. Args: key (str): The key of the attribute to retrieve. Returns: any: The value of the attribute. Raises: AttributeError: If the attribute does not exist. """ if key == "timestamp" or key == "event_type": return getattr(self, key) if key in self.attr_dict: return self.attr_dict[key] raise AttributeError(f"'Event' object has no attribute '{key}'")
[docs]class Patient: """Patient class representing a sequence of events. Attributes: patient_id (str): Unique patient identifier. data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp. event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions. """ def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None: """ Initialize a Patient instance. Args: patient_id (str): Unique patient identifier. data_source (pl.DataFrame): DataFrame containing all events. """ self.patient_id = patient_id self.data_source = data_source.sort("timestamp") self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True) def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: """Regular filtering by time. Time complexity: O(n).""" if start is not None: df = df.filter(pl.col("timestamp") >= start) if end is not None: df = df.filter(pl.col("timestamp") <= end) return df def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame: """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n).""" if start is None and end is None: return df df = df.filter(pl.col("timestamp").is_not_null()) ts_col = df["timestamp"].to_numpy() start_idx = 0 end_idx = len(ts_col) if start is not None: start_idx = np.searchsorted(ts_col, np.datetime64(start, "ms"), side="left") if end is not None: end_idx = np.searchsorted(ts_col, np.datetime64(end, "ms"), side="right") return df.slice(start_idx, end_idx - start_idx) def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: """Regular filtering by event type. Time complexity: O(n).""" if event_type: df = df.filter(pl.col("event_type") == event_type) return df def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame: """Fast filtering by event type using pre-built event type index. Time complexity: O(1).""" if event_type: return self.event_type_partitions.get((event_type,), df[:0]) else: return df
[docs] def get_events( self, event_type: Optional[str] = None, start: Optional[datetime] = None, end: Optional[datetime] = None, filters: Optional[List[tuple]] = None, return_df: bool = False, ) -> Union[pl.DataFrame, List[Event]]: """Get events with optional type and time filters. Args: event_type (Optional[str]): Type of events to filter. start (Optional[datetime]): Start time for filtering events. end (Optional[datetime]): End time for filtering events. return_df (bool): Whether to return a DataFrame or a list of Event objects. filters (Optional[List[tuple]]): Additional filters as [(attr, op, value), ...], e.g.: [("attr1", "!=", "abnormal"), ("attr2", "!=", 1)]. Filters are applied after type and time filters. The logic is "AND" between different filters. Returns: Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame or a list of Event objects. """ # faster filtering (by default) df = self._filter_by_event_type_fast(self.data_source, event_type) df = self._filter_by_time_range_fast(df, start, end) # regular filtering (commented out by default) # df = self._filter_by_event_type_regular(self.data_source, event_type) # df = self._filter_by_time_range_regular(df, start, end) if filters: assert event_type is not None, "event_type must be provided if filters are provided" else: filters = [] exprs = [] for filt in filters: if not (isinstance(filt, tuple) and len(filt) == 3): raise ValueError( f"Invalid filter format: {filt} (must be tuple of (attr, op, value))" ) attr, op, val = filt col_expr = pl.col(f"{event_type}/{attr}") # Build operator expression if op == "==": exprs.append(col_expr == val) elif op == "!=": exprs.append(col_expr != val) elif op == "<": exprs.append(col_expr < val) elif op == "<=": exprs.append(col_expr <= val) elif op == ">": exprs.append(col_expr > val) elif op == ">=": exprs.append(col_expr >= val) else: raise ValueError(f"Unsupported operator: {op} in filter {filt}") if exprs: df = df.filter(reduce(operator.and_, exprs)) if return_df: return df return [Event.from_dict(d) for d in df.to_dicts()]