"""Source code for ``fair_forge.preprocessing.definitions``: typing protocols for preprocessing methods."""

from typing import Any, Protocol, Self

import numpy as np
from numpy.typing import NDArray
from sklearn.utils.metadata_routing import MetadataRequest

# The public protocols this module exports; `_PreprocessorBase` stays private.
__all__ = [
    "GroupBasedTransform",
    "GroupDatasetModifier",
    "Preprocessor",
]


class _PreprocessorBase(Protocol):
    """Shared structural interface that every preprocessing protocol extends.

    The members mirror scikit-learn's estimator API (parameter access and
    metadata routing), so any conforming object can be handled like an
    sklearn estimator by tooling that only relies on these methods.
    """

    def get_params(self, deep: bool = ...) -> dict[str, object]:
        """Return the method's parameters as a name -> value mapping."""

    def set_params(self, **kwargs: Any) -> Self:
        """Update parameters from keyword arguments and return the same object."""

    def get_metadata_routing(self) -> MetadataRequest:
        """Return the scikit-learn metadata-routing request for this method."""


class Preprocessor(_PreprocessorBase, Protocol):
    """A preprocessing method fitted on features and labels, without group information."""

    def fit(self, X: NDArray[np.float32], y: NDArray[np.int32]) -> Self:
        """Fit the preprocessor to the data.

        Args:
            X: Feature matrix.
            y: Integer labels.

        Returns:
            The fitted preprocessor itself.
        """
        ...

    def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
        """Transform the data using the fitted preprocessor.

        Args:
            X: Feature matrix to transform.

        Returns:
            The transformed feature matrix.
        """
        ...


class GroupBasedTransform(_PreprocessorBase, Protocol):
    """A transformation which is fitted with group information.

    Groups are needed only while fitting; ``transform`` operates on features alone.
    """

    def fit(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> Self:
        """Fit the transformation using features, labels and group memberships.

        Args:
            X: Feature matrix.
            y: Integer labels.
            groups: Integer group identifiers (keyword-only).

        Returns:
            The fitted transformation itself.
        """
        ...

    def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
        """Apply the fitted transformation to ``X``; no group information is required."""
        ...

    def fit_transform(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> NDArray[np.float32]:
        """Fit with group information and return the transformed ``X``.

        Presumably equivalent to ``fit(X, y, groups=groups)`` followed by
        ``transform(X)``, following the scikit-learn convention; implementations
        may fuse the two steps.
        """
        ...


[docs] class GroupDatasetModifier(_PreprocessorBase, Protocol): """A transformation which modifies both the dataset and the labels based on group information."""
[docs] def fit( self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32] ) -> Self: """Fit the preprocessing method to the data with group information.""" ...
[docs] def transform[S: np.generic](
self, X: NDArray[S], *, is_train: bool = False, is_x: bool = False ) -> NDArray[S]: ...