Source code for fair_forge.datasets

from pathlib import Path
from typing import Literal, NamedTuple

import numpy as np
from numpy.typing import NDArray
import polars as pl
import polars.selectors as cs

from fair_forge.utils import reproducible_random_state

__all__ = [
    "AdultGroup",
    "GroupDataset",
    "grouping_by_prefix",
    "load_adult",
    "load_dummy_dataset",
    "load_ethicml_toy",
]


class GroupDataset(NamedTuple):
    """A dataset containing features, labels, and groups."""

    data: NDArray[np.float32]
    """Features of the dataset."""
    target: NDArray[np.int32]
    """Labels of the dataset."""
    groups: NDArray[np.int32]
    """Groups of the dataset."""
    name: str
    """Name of the dataset."""
    feature_grouping: list[slice]
    """Slices indicating groups of features."""
    feature_names: list[str]
    """Names of the features in the dataset."""
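

# Usage sketch (hypothetical values, not taken from the library): a
# `GroupDataset` is a plain NamedTuple, so fields are read by name, and each
# `feature_grouping` slice selects the one-hot columns derived from a single
# categorical feature.
#
#   >>> ds = GroupDataset(
#   ...     data=np.zeros((4, 3), dtype=np.float32),
#   ...     target=np.zeros(4, dtype=np.int32),
#   ...     groups=np.zeros(4, dtype=np.int32),
#   ...     name="example",
#   ...     feature_grouping=[slice(0, 2)],
#   ...     feature_names=["colour:red", "colour:blue", "age"],
#   ... )
#   >>> ds.data[:, ds.feature_grouping[0]].shape  # one-hot columns of "colour"
#   (4, 2)
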
type AdultGroup = Literal["Sex", "Race"]


def load_adult(
    group: AdultGroup,
    *,
    group_in_features: bool = False,
    binarize_nationality: bool = False,
    binarize_race: bool = False,
) -> GroupDataset:
    """Load the Adult dataset with the specified group information.

    Args:
        group: The group to use for the dataset.
        group_in_features: Whether to keep the group column among the features.
        binarize_nationality: Whether to reduce `native-country` to
            "United-States" vs "Other".
        binarize_race: Whether to reduce `race` to "White" vs "Other".

    Returns:
        A GroupDataset containing the Adult dataset.
    """
    name = f"Adult {group}"
    if binarize_nationality:
        name += ", binary nationality"
    if binarize_race:
        name += ", binary race"
    if group_in_features:
        name += ", group in features"

    base_path = Path(__file__).parent
    df = pl.read_parquet(base_path / "data" / "adult.parquet")
    y = df.get_column("salary").cat.starts_with(">50K").cast(pl.Int32).to_numpy()
    df = df.drop("salary")
    df = df.drop("fnlwgt")

    column_grouping_prefixes = [
        "workclass",
        "education",
        "marital-status",
        "occupation",
        "relationship",
        "race",
        "sex",
        "native-country",
    ]

    to_dummies = cs.categorical()
    if binarize_race:
        df = df.with_columns(
            pl.col("race").replace_strict(
                {"White": "White"},
                default="Other",
                return_dtype=pl.Enum(["White", "Other"]),
            )
        )
        to_dummies = to_dummies | cs.by_name("race")
    if binarize_nationality:
        df = df.with_columns(
            pl.col("native-country").replace_strict(
                {"United-States": "United-States"},
                default="Other",
                return_dtype=pl.Enum(["United-States", "Other"]),
            )
        )
        to_dummies = to_dummies | cs.by_name("native-country")

    groups: NDArray[np.int32]
    to_drop: str
    match group:
        case "Sex":
            groups = (
                df.get_column("sex").cat.starts_with("Male").cast(pl.Int32).to_numpy()
            )
            to_drop = "sex"
        case "Race":
            # `.to_physical()` converts the categorical column to its physical
            # representation, which is UInt32 by default in Polars.
            groups = df.get_column("race").to_physical().cast(pl.Int32).to_numpy()
            to_drop = "race"
        case _:
            raise ValueError(f"Invalid group: {group}")

    if not group_in_features:
        df = df.drop(to_drop)
        column_grouping_prefixes.remove(to_drop)

    # Convert categorical columns to one-hot encoded features.
    df = df.to_dummies(to_dummies, separator=":")
    columns = df.columns
    feature_grouping = grouping_by_prefix(
        columns=columns, prefixes=[f"{col}:" for col in column_grouping_prefixes]
    )
    features = df.cast(pl.Float32).to_numpy()
    return GroupDataset(
        data=features,
        target=y,
        groups=groups,
        name=name,
        feature_grouping=feature_grouping,
        feature_names=columns,
    )
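

# Usage sketch (assumes the bundled data/adult.parquet is present next to this
# module, as the loader above expects):
#
#   >>> ds = load_adult("Sex", binarize_nationality=True)
#   >>> ds.name
#   'Adult Sex, binary nationality'
#   >>> ds.data.shape[0] == ds.target.shape[0] == ds.groups.shape[0]
#   True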


def grouping_by_prefix(*, columns: list[str], prefixes: list[str]) -> list[slice]:
    """Create slices for feature grouping based on column prefixes."""
    feature_grouping: list[slice] = []
    for prefix in prefixes:
        # Find the indices of columns that start with the prefix.
        indices = [i for i, col in enumerate(columns) if col.startswith(prefix)]
        if not indices:
            raise ValueError(f"No columns found with prefix '{prefix}'.")
        start = min(indices)
        end = max(indices) + 1
        assert all(i in indices for i in range(start, end)), (
            f"The columns corresponding to prefix '{prefix}' are not contiguous."
        )
        feature_grouping.append(slice(start, end))
    return feature_grouping
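

# Example (values follow directly from the function): contiguous one-hot
# columns that share a prefix map to one slice over the column axis.
#
#   >>> grouping_by_prefix(
#   ...     columns=["colour:red", "colour:blue", "size:S", "size:M", "age"],
#   ...     prefixes=["colour:", "size:"],
#   ... )
#   [slice(0, 2, None), slice(2, 4, None)]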


def load_dummy_dataset(seed: int) -> GroupDataset:
    """Load a dummy dataset for testing purposes, based on a mixture of two 2D Gaussians.

    The groups are assigned randomly.

    Args:
        seed: Random seed for reproducibility.
    """
    generator = reproducible_random_state(seed)
    n_samples = 100
    n_features = 2
    n_groups = 2

    # Identity matrix as the (diagonal) covariance of both Gaussians.
    cov = np.eye(n_features)

    # First, generate samples for class 0 (n = n_samples // 2).
    x1 = generator.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=n_samples // 2)
    y1 = np.zeros(n_samples // 2, dtype=np.int32)
    groups1 = generator.integers(0, n_groups, size=n_samples // 2, dtype=np.int32)

    # Then, generate samples for class 1 (n = n_samples // 2).
    x2 = generator.multivariate_normal(mean=[1.5, 1.5], cov=cov, size=n_samples // 2)
    y2 = np.ones(n_samples // 2, dtype=np.int32)
    groups2 = generator.integers(0, n_groups, size=n_samples // 2, dtype=np.int32)

    # Concatenate the samples.
    x = np.concatenate((x1, x2), axis=0).astype(np.float32)
    y = np.concatenate((y1, y2), axis=0)
    groups = np.concatenate((groups1, groups2), axis=0)

    feature_names = [f"feature_{i}" for i in range(n_features)]
    # No feature groupings: both features are continuous.
    feature_grouping: list[slice] = []
    name = "Dummy Dataset"
    return GroupDataset(
        data=x,
        target=y,
        groups=groups,
        name=name,
        feature_grouping=feature_grouping,
        feature_names=feature_names,
    )
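

# Usage sketch (assumes `reproducible_random_state` returns a generator fully
# determined by the seed, as its name suggests):
#
#   >>> a = load_dummy_dataset(seed=0)
#   >>> b = load_dummy_dataset(seed=0)
#   >>> a.data.shape
#   (100, 2)
#   >>> bool(np.array_equal(a.data, b.data))
#   True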


def load_ethicml_toy(group_in_features: bool = False) -> GroupDataset:
    """Load the EthicML toy dataset.

    Args:
        group_in_features: Whether to keep the sensitive attribute among the
            features.
    """
    base_path = Path(__file__).parent
    df = pl.read_parquet(base_path / "data" / "toy.parquet")
    y = df.get_column("decision").cast(pl.Int32).to_numpy()
    df = df.drop("decision")
    groups = df.get_column("sensitive-attr").cast(pl.Int32).to_numpy()
    if not group_in_features:
        # If the group is not supposed to be in the features, drop it.
        df = df.drop("sensitive-attr")

    discrete_columns = ["disc_1", "disc_2"]
    df = df.to_dummies(discrete_columns, separator=":")
    features = df.cast(pl.Float32).to_numpy()
    feature_names = df.columns
    feature_grouping = grouping_by_prefix(
        columns=feature_names, prefixes=discrete_columns
    )
    return GroupDataset(
        data=features,
        target=y,
        groups=groups,
        name="Toy",
        feature_grouping=feature_grouping,
        feature_names=feature_names,
    )
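

# Usage sketch (assumes the bundled data/toy.parquet is present next to this
# module):
#
#   >>> ds = load_ethicml_toy()
#   >>> ds.name
#   'Toy'
#   >>> "sensitive-attr" in ds.feature_names
#   False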