Source code for pyhealth.metrics.fairness
from typing import Dict, List, Optional
import numpy as np
from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference
def fairness_metrics_fn(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    sensitive_attributes: np.ndarray,
    favorable_outcome: int = 1,
    metrics: Optional[List[str]] = None,
    threshold: float = 0.5,
) -> Dict[str, float]:
"""Computes metrics for binary classification.
User can specify which metrics to compute by passing a list of metric names.
The accepted metric names are:
- disparate_impact:
- statistical_parity_difference:
If no metrics are disparate_impact, and statistical_parity_difference are computed by default.
Args:
y_true: True target values of shape (n_samples,).
y_prob: Predicted probabilities of shape (n_samples,).
sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
favorable_outcome: Label value which is considered favorable (i.e. "positive").
metrics: List of metrics to compute. Default is ["disparate_impact", "statistical_parity_difference"].
threshold: Threshold for binary classification. Default is 0.5.
Returns:
Dictionary of metrics whose keys are the metric names and values are
the metric values.
"""
    if metrics is None:
        metrics = ["disparate_impact", "statistical_parity_difference"]

    # Binarize the predicted probabilities at the given threshold.
    y_pred = y_prob.copy()
    y_pred[y_pred >= threshold] = 1
    y_pred[y_pred < threshold] = 0

    output = {}
    for metric in metrics:
        if metric == "disparate_impact":
            output[metric] = disparate_impact(
                sensitive_attributes, y_pred, favorable_outcome
            )
        elif metric == "statistical_parity_difference":
            output[metric] = statistical_parity_difference(
                sensitive_attributes, y_pred, favorable_outcome
            )
        else:
            raise ValueError(f"Unknown metric for fairness: {metric}")
    return output
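
A minimal usage sketch on synthetic arrays (the data below are invented for illustration; in practice y_prob would come from a trained PyHealth model's predicted probabilities):

    import numpy as np

    from pyhealth.metrics.fairness import fairness_metrics_fn

    # Toy data: 6 samples with binary labels, predicted probabilities, and a
    # binary sensitive attribute (1 = protected group, 0 = unprotected group).
    y_true = np.array([1, 0, 1, 1, 0, 0])
    y_prob = np.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1])
    sensitive_attributes = np.array([1, 1, 1, 0, 0, 0])

    scores = fairness_metrics_fn(
        y_true,
        y_prob,
        sensitive_attributes,
        metrics=["disparate_impact", "statistical_parity_difference"],
        threshold=0.5,
    )
    # scores is a dict with keys "disparate_impact" and
    # "statistical_parity_difference".
    print(scores)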
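
The disparate_impact and statistical_parity_difference helpers are imported from pyhealth.metrics.fairness_utils and are not shown on this page. The sketch below illustrates the standard definitions they correspond to; the _sketch names, the epsilon guard against division by zero, and the convention of measuring the protected group relative to the unprotected group are assumptions for illustration, not the library's actual implementation:

    import numpy as np

    def disparate_impact_sketch(
        sensitive_attributes: np.ndarray,
        y_pred: np.ndarray,
        favorable_outcome: int = 1,
        epsilon: float = 1e-8,
    ) -> float:
        # Rate of favorable predictions in the protected group divided by the
        # rate in the unprotected group (a value of 1.0 means parity).
        rate_protected = np.mean(y_pred[sensitive_attributes == 1] == favorable_outcome)
        rate_unprotected = np.mean(y_pred[sensitive_attributes == 0] == favorable_outcome)
        return float(rate_protected / (rate_unprotected + epsilon))

    def statistical_parity_difference_sketch(
        sensitive_attributes: np.ndarray,
        y_pred: np.ndarray,
        favorable_outcome: int = 1,
    ) -> float:
        # Difference between the favorable-prediction rates of the protected
        # and unprotected groups (a value of 0.0 means parity).
        rate_protected = np.mean(y_pred[sensitive_attributes == 1] == favorable_outcome)
        rate_unprotected = np.mean(y_pred[sensitive_attributes == 0] == favorable_outcome)
        return float(rate_protected - rate_unprotected)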