from collections.abc import Callable, Sequence
from dataclasses import dataclass
from enum import Flag, auto
from typing import Literal, Protocol, override
import numpy as np
from numpy.typing import NDArray
from sklearn.metrics import confusion_matrix
from ._metrics_common import renyi_correlation
__all__ = [
"Float",
"GroupMetric",
"LabelType",
"Metric",
"MetricAgg",
"RenyiCorrelation",
"as_group_metric",
"cv",
"prob_neg",
"prob_pos",
"tnr",
"tpr",
]
type Float = float | np.float16 | np.float32 | np.float64
"""Union of common float types."""
class Metric(Protocol):
    """A metric computed from true and predicted labels."""

    @property
    def __name__(self) -> str:
        """The name of the metric."""
        ...

    def __call__(
        self,
        y_true: NDArray[np.int32],
        y_pred: NDArray[np.int32],
        *,
        sample_weight: NDArray[np.bool] | None = ...,
    ) -> Float: ...
class GroupMetric(Protocol):
    """A metric computed from true labels, predicted labels and group labels."""

    @property
    def __name__(self) -> str:
        """The name of the metric."""
        ...

    def __call__(
        self,
        y_true: NDArray[np.int32],
        y_pred: NDArray[np.int32],
        *,
        groups: NDArray[np.int32],
    ) -> Float: ...
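
# Note: these are structural protocols, so plain functions can satisfy them.
# For instance, ``prob_pos`` below matches ``Metric``, and ``cv`` matches ``GroupMetric``.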
type LabelType = Literal["group", "y"]
"""A type for specifying which labels to use (class or group labels)."""
@dataclass
class RenyiCorrelation(GroupMetric):
"""Renyi correlation. Measures how dependent two random variables are.
As defined in this paper: https://link.springer.com/content/pdf/10.1007/BF02024507.pdf ,
titled "On Measures of Dependence" by Alfréd Rényi.
"""
base: LabelType = "group"
"""Which label to use as base to compute the correlation against."""
@property
def __name__(self) -> str:
"""The name of the metric."""
return f"renyi_{self.base}"
@override
def __call__(
self,
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
groups: NDArray[np.int32],
) -> float:
return renyi_correlation(x=y_true if self.base == "y" else groups, y=y_pred)
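
# Usage sketch for RenyiCorrelation (illustrative; the exact value depends on the data):
#
#     metric = RenyiCorrelation(base="group")
#     metric.__name__                        # "renyi_group"
#     metric(y_true, y_pred, groups=groups)  # dependence score; 0.0 means independent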
def prob_pos(
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
sample_weight: NDArray[np.bool] | None = None,
) -> np.float64:
"""Probability of positive prediction.

    Example:

>>> import fair_forge as ff
>>> y_true = np.array([0, 0, 0, 1], dtype=np.int32)
>>> y_pred = np.array([0, 1, 0, 1], dtype=np.int32)
>>> ff.metrics.prob_pos(y_true, y_pred)
np.float64(0.5)
"""
_, f_pos, _, t_pos, total = _confusion_matrix(
y_pred=y_pred, y_true=y_true, sample_weight=sample_weight
)
return ((t_pos + f_pos) / total).astype(np.float64)
def prob_neg(
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
sample_weight: NDArray[np.bool] | None = None,
) -> np.float64:
"""Probability of negative prediction."""
t_neg, _, f_neg, _, total = _confusion_matrix(
y_pred=y_pred, y_true=y_true, sample_weight=sample_weight
)
return ((t_neg + f_neg) / total).astype(np.float64)
def tpr(
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
sample_weight: NDArray[np.bool] | None = None,
) -> np.float64:
"""True Positive Rate (TPR) or Sensitivity."""
_, _, f_neg, t_pos, _ = _confusion_matrix(
y_pred=y_pred, y_true=y_true, sample_weight=sample_weight
)
return (t_pos / (t_pos + f_neg)).astype(np.float64)
def tnr(
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
sample_weight: NDArray[np.bool] | None = None,
) -> np.float64:
"""True Negative Rate (TNR) or Specificity."""
t_neg, f_pos, _, _, _ = _confusion_matrix(
y_pred=y_pred, y_true=y_true, sample_weight=sample_weight
)
return (t_neg / (t_neg + f_pos)).astype(np.float64)
def _confusion_matrix(
*,
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
sample_weight: NDArray[np.bool] | None,
) -> tuple[np.int64, np.int64, np.int64, np.int64, np.int64]:
"""Apply sci-kit learn's confusion matrix.
We assume that the positive class is 1.
Returns the 4 entries of the confusion matrix, and the total, as a 5-tuple.
"""
conf_matr: NDArray[np.int64] = confusion_matrix(
y_true=y_true, y_pred=y_pred, normalize=None, sample_weight=sample_weight
)
labels = np.unique(y_true)
pos_class = np.int32(1)
if pos_class not in labels:
raise ValueError("Positive class specified must exist in the true labels.")
# Find the index of the positive class
tp_idx = np.nonzero(labels == pos_class)[0].item()
true_pos = conf_matr[tp_idx, tp_idx]
false_pos = conf_matr[:, tp_idx].sum() - true_pos
false_neg = conf_matr[tp_idx, :].sum() - true_pos
total = conf_matr.sum()
true_neg = total - true_pos - false_pos - false_neg
return true_neg, false_pos, false_neg, true_pos, total
@dataclass
class _AggMetricBase(GroupMetric):
metric: Metric
agg_name: str
remove_score_suffix: bool
@property
def __name__(self) -> str:
"""The name of the metric."""
name = self.metric.__name__
if self.remove_score_suffix and name.endswith("_score"):
name = name[:-6]
return f"{name}_{self.agg_name}"
def _group_scores(
self,
*,
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
groups: NDArray[np.int32],
unique_groups: NDArray[np.int32],
) -> NDArray[np.float64]:
return np.array(
[
self.metric(y_true[groups == group], y_pred[groups == group])
for group in unique_groups
],
dtype=np.float64,
)
@dataclass
class _BinaryAggMetric(_AggMetricBase):
aggregator: Callable[[np.float64, np.float64], np.float64]
@override
def __call__(
self,
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
groups: NDArray[np.int32],
) -> Float:
"""Compute the metric for the given predictions and actual values."""
unique_groups = np.unique(groups)
assert len(unique_groups) == 2, (
f"Aggregation metric with {self.agg_name} requires exactly two groups for aggregation"
)
group_scores = self._group_scores(
y_true=y_true, y_pred=y_pred, groups=groups, unique_groups=unique_groups
)
return self.aggregator(group_scores[0], group_scores[1])
@dataclass
class _MulticlassAggMetric(_AggMetricBase):
aggregator: Callable[[NDArray[np.float64]], Float]
@override
def __call__(
self,
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
groups: NDArray[np.int32],
) -> Float:
"""Compute the metric for the given predictions and actual values."""
unique_groups = np.unique(groups)
group_scores = self._group_scores(
y_true=y_true, y_pred=y_pred, groups=groups, unique_groups=unique_groups
)
return self.aggregator(group_scores)
class MetricAgg(Flag):
"""Aggregation methods for metrics that are computed per group."""
INDIVIDUAL = auto()
"""Individual per-group results."""
DIFF = auto()
"""Difference of the per-group results."""
MAX = auto()
"""Maximum of the per-group results."""
MIN = auto()
"""Minimum of the per-group results."""
MIN_MAX = MIN | MAX
"""Equivalent to ``MIN | MAX``."""
RATIO = auto()
"""Ratio of the per-group results."""
DIFF_RATIO = INDIVIDUAL | DIFF | RATIO
"""Equivalent to ``INDIVIDUAL | DIFF | RATIO``."""
ALL = DIFF_RATIO | MIN_MAX
"""All aggregations."""
def as_group_metric(
base_metrics: Sequence[Metric],
agg: MetricAgg = MetricAgg.DIFF_RATIO,
remove_score_suffix: bool = True,
) -> list[GroupMetric]:
"""Turn a sequence of metrics into a list of group metrics."""
metrics = []
for metric in base_metrics:
if MetricAgg.DIFF in agg:
metrics.append(
_BinaryAggMetric(
metric=metric,
agg_name="diff",
remove_score_suffix=remove_score_suffix,
aggregator=lambda i, j: j - i,
)
)
if MetricAgg.RATIO in agg:
metrics.append(
_BinaryAggMetric(
metric=metric,
agg_name="ratio",
remove_score_suffix=remove_score_suffix,
aggregator=lambda i, j: i / j if j != 0 else np.float64(np.nan),
)
)
if MetricAgg.MIN in agg:
metrics.append(
_MulticlassAggMetric(
metric=metric,
agg_name="min",
remove_score_suffix=remove_score_suffix,
aggregator=np.min,
)
)
if MetricAgg.MAX in agg:
metrics.append(
_MulticlassAggMetric(
metric=metric,
agg_name="max",
remove_score_suffix=remove_score_suffix,
aggregator=np.max,
)
)
if MetricAgg.INDIVIDUAL in agg:
metrics.append(
_BinaryAggMetric(
metric=metric,
agg_name="0",
remove_score_suffix=remove_score_suffix,
aggregator=lambda i, j: i,
)
)
metrics.append(
_BinaryAggMetric(
metric=metric,
agg_name="1",
remove_score_suffix=remove_score_suffix,
aggregator=lambda i, j: j,
)
)
return metrics
def cv(
y_true: NDArray[np.int32],
y_pred: NDArray[np.int32],
*,
groups: NDArray[np.int32],
) -> Float:
"""Calder-Verwer."""
unique_groups = np.unique(groups)
assert len(unique_groups) == 2, (
f"Calder-Verwer requires exactly two groups, got {len(unique_groups)}"
)
group_scores = np.array(
[
prob_pos(y_true[groups == group], y_pred[groups == group])
for group in unique_groups
],
dtype=np.float64,
)
return 1 - (group_scores[1] - group_scores[0])