# Source code for pyhealth.metrics.fairness_utils.utils
from typing import List
import numpy as np
from pyhealth.datasets import BaseEHRDataset
def sensitive_attributes_from_patient_ids(dataset: BaseEHRDataset,
                                          patient_ids: List[str],
                                          sensitive_attribute: str,
                                          protected_group: str) -> np.ndarray:
    """
    Returns the desired sensitive attribute array from patient_ids.

    Each patient is looked up in ``dataset.patients`` and the attribute named
    by ``sensitive_attribute`` is compared (with ``==``) against
    ``protected_group``.

    Args:
        dataset: Dataset object exposing a ``patients`` mapping that is
            indexed by patient ID.
        patient_ids: Patient IDs to encode; the output follows this order.
        sensitive_attribute: Name of the attribute to read from each patient
            object (e.g. passed straight to ``getattr``).
        protected_group: Value of the protected group.

    Returns:
        Float array of shape (n_samples,) holding 1.0 where the patient's
        attribute equals ``protected_group`` and 0.0 otherwise. An empty
        ``patient_ids`` yields an empty array of shape (0,).

    Raises:
        AttributeError: If a patient object has no attribute named
            ``sensitive_attribute``.
        Exception: Whatever ``dataset.patients[patient_id]`` raises for an
            unknown ID propagates unchanged.
    """
    # bool() mirrors the truthiness test of a plain ``if a == b`` so that
    # objects with a custom ``__eq__`` are encoded consistently.
    is_protected = [
        bool(
            getattr(dataset.patients[patient_id], sensitive_attribute)
            == protected_group
        )
        for patient_id in patient_ids
    ]
    # Explicit float dtype keeps the 0.0/1.0 encoding (and the empty case)
    # as a floating-point array rather than a boolean one.
    return np.array(is_protected, dtype=float)