# Source code for fair_forge.preprocessing.definitions

from typing import Any, Protocol, Self

import numpy as np
from numpy.typing import NDArray
from sklearn.utils.metadata_routing import MetadataRequest

# Public names exported by this module; the underscore-prefixed
# `_PreprocessorBase` helper protocol is not part of this list.
__all__ = ["GroupBasedTransform", "GroupDatasetModifier", "Preprocessor"]


class _PreprocessorBase(Protocol):
    """Structural base shared by the preprocessing protocols of this module.

    It only declares the parameter accessors and the metadata-routing
    accessor; the fitting and transformation methods are declared by the
    protocols that derive from it.
    """

    def get_params(self, deep: bool = ...) -> dict[str, object]:
        """Return the parameters of this object as a name-to-value mapping."""
        ...

    def set_params(self, **kwargs: Any) -> Self:
        """Set parameters from keyword arguments; the return type is ``Self``."""
        ...

    def get_metadata_routing(self) -> MetadataRequest:
        """Return the ``MetadataRequest`` describing this object's routing."""
        ...


class Preprocessor(_PreprocessorBase, Protocol):
    """A preprocessing method which is fitted without group information."""

    def fit(self, X: NDArray[np.float32], y: NDArray[np.int32]) -> Self:
        """Fit the preprocessor to the data.

        Args:
            X: The data to fit on, as a float32 array.
            y: The int32 array passed alongside ``X`` for fitting.

        Returns:
            An object of the same type (``Self``).
        """
        ...

    def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
        """Transform the data using the fitted preprocessor.

        Args:
            X: The data to transform, as a float32 array.

        Returns:
            The transformed data as a float32 array.
        """
        ...


class GroupBasedTransform(_PreprocessorBase, Protocol):
    """A transformation which is fitted with group information."""

    def fit(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> Self:
        """Fit the transformation to the data with group information.

        Args:
            X: The data to fit on, as a float32 array.
            y: The int32 array passed alongside ``X`` for fitting.
            groups: The group information as an int32 array (keyword-only).

        Returns:
            An object of the same type (``Self``).
        """
        ...

    def transform(self, X: NDArray[np.float32]) -> NDArray[np.float32]:
        """Transform the data using the fitted transformation.

        Args:
            X: The data to transform, as a float32 array.

        Returns:
            The transformed data as a float32 array.
        """
        ...

    def fit_transform(
        self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
    ) -> NDArray[np.float32]:
        """Fit the transformation to the data with group information and transform the data.

        Args:
            X: The data to fit on and transform, as a float32 array.
            y: The int32 array passed alongside ``X`` for fitting.
            groups: The group information as an int32 array (keyword-only).

        Returns:
            The transformed data as a float32 array.
        """
        ...


[docs] class GroupDatasetModifier(_PreprocessorBase, Protocol): """A transformation which modifies both the dataset and the labels based on group information."""
[docs] def fit( self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32] ) -> Self: """Fit the preprocessing method to the data with group information.""" ...
[docs] def transform[S: np.generic]( self, X: NDArray[S], *, is_train: bool = False, is_x: bool = False ) -> NDArray[S]: """Transform the data using the fitted preprocessing method. Args: X: The data to transform. is_train: Whether the data is training data. This can be used to apply different transformations to training and test data. is_x: Whether the data is features. This can be used to apply different transformations to features and labels. Returns: The transformed data. """ ...