from dataclasses import dataclass
from enum import Enum
import itertools
from typing import Any, Self
import numpy as np
from numpy.typing import NDArray
from sklearn.base import BaseEstimator
from fair_forge.methods import GroupMethod, Method
from fair_forge.utils import reproducible_random_state
from .definitions import GroupDatasetModifier
__all__ = ["GroupPipeline", "UpsampleStrategy", "Upsampler"]
@dataclass
class GroupPipeline(BaseEstimator, GroupMethod):
    """Pipeline chaining a group-aware dataset modifier with an estimator.

    The modifier is fitted on the training data (with group labels) and used
    to transform both the features and the labels before the estimator is
    fitted on the transformed data.

    Attributes:
        group_data_modifier: Pre-processing step fitted with group information;
            transforms X and y.
        estimator: Downstream method fitted on the transformed data.
        random_state: If not None, propagated to both sub-components via
            ``set_params``.
    """

    group_data_modifier: GroupDatasetModifier
    estimator: Method
    random_state: int | None = None

    def __post_init__(self) -> None:
        # Propagate the pipeline-level seed to the sub-components on creation.
        self.update_random_state()

    def update_random_state(self) -> None:
        """Push ``random_state`` down to the modifier and the estimator."""
        if self.random_state is not None:
            self.group_data_modifier.set_params(random_state=self.random_state)
            self.estimator.set_params(random_state=self.random_state)

    def fit(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> Self:
        """Fit the modifier on (X, y, groups), then the estimator on the transformed data."""
        # Fit the group pre-processing method and transform the data.
        self.group_data_modifier.fit(X, y=y, groups=groups)
        X_transformed = self.group_data_modifier.transform(X, is_train=True, is_x=True)
        # Transform the labels (is_x is not passed for y).
        y_transformed = self.group_data_modifier.transform(y, is_train=True)
        # Fit the estimator with the transformed data.
        self.estimator.fit(X_transformed, y=y_transformed)
        self.fitted_ = True
        return self

    def predict(self, X: NDArray[np.float32]) -> NDArray[np.int32]:
        """Transform X with the (non-training) modifier, then predict with the estimator."""
        # NOTE(review): unlike fit(), is_x is not passed here — presumably the
        # modifier treats test-time features differently; confirm against the
        # GroupDatasetModifier API.
        X_transformed = self.group_data_modifier.transform(X, is_train=False)
        return self.estimator.predict(X_transformed)

    def set_params(self, **params: Any) -> Self:
        """Set parameters, then re-propagate the random state to sub-components."""
        ret = super().set_params(**params)
        self.update_random_state()
        return ret
class UpsampleStrategy(Enum):
    """Strategy for upsampling.

    UNIFORM weights each (group, label) segment by
    ``count(y) * count(group) / (segment_size * n_samples)``; NAIVE scales
    every segment up to the size of the largest one (see ``Upsampler.fit``).
    """
    UNIFORM = "uniform"
    # PREFERENTIAL = "preferential"  # TODO: not implemented yet
    NAIVE = "naive"
@dataclass
class Upsampler(BaseEstimator, GroupDatasetModifier):
    """Dataset modifier that computes per-(group, label) upsampling weights."""

    # How per-segment weights are computed; see UpsampleStrategy.
    strategy: UpsampleStrategy = UpsampleStrategy.UNIFORM
    # Seed pushed in via set_params (e.g. by GroupPipeline); not used in the
    # visible fit() — presumably consumed by transform(). TODO confirm.
    random_state: int = 0
def fit(
    self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
) -> Self:
    """Compute per-(group, label) masks and upsampling weights.

    For every combination of a group value and a label value ("segment"),
    a boolean mask over the samples is stored together with a weight:

    * ``NAIVE``: size of the largest segment divided by this segment's size,
      i.e. every segment is scaled up to match the biggest one.
    * ``UNIFORM`` (default): ``count(y == y_val) * count(groups == s_val) /
      (segment_size * n_samples)`` — the ratio between the segment's expected
      size under group/label independence and its actual size.

    The result is stored in ``self.percentages_`` as a list of
    ``(mask, weight)`` pairs. X itself is not used here.
    """
    s_vals: NDArray[np.int32] = np.unique(groups)
    y_vals: NDArray[np.int32] = np.unique(y)
    # One segment per (group value, label value) combination.
    segments: list[tuple[np.int32, np.int32]] = list(
        itertools.product(s_vals, y_vals)
    )
    # (mask, segment size, count of this label, count of this group)
    data: list[tuple[NDArray[np.bool], int, int, int]] = []
    for s_val, y_val in segments:
        s_y_mask: NDArray[np.bool] = (groups == s_val) & (y == y_val)
        data.append(
            (
                s_y_mask,
                np.count_nonzero(s_y_mask),
                np.count_nonzero(y == y_val),
                np.count_nonzero(groups == s_val),
            )
        )
    # Loop invariants hoisted: the largest segment size and the sample count.
    max_segment = max(d[1] for d in data)
    num_samples = len(y)
    percentages: list[tuple[NDArray[np.bool], np.float64]] = []
    for mask, length, y_eq_y, s_eq_s in data:
        # NOTE(review): an empty segment (length == 0) divides by zero here,
        # as in the original code — confirm all segments are non-empty.
        if self.strategy is UpsampleStrategy.NAIVE:
            # np.count_nonzero returns a plain Python int, so the division
            # yields a float without an .astype method; wrap in np.float64
            # instead (value-identical to the intended conversion).
            weight = np.float64(max_segment / length)
        else:
            weight = np.float64(y_eq_y * s_eq_s / (length * num_samples))
        percentages.append((mask, weight))
    self.percentages_ = percentages
    return self