"""A method for fair representations."""
from collections.abc import Sequence
from dataclasses import dataclass, field
from itertools import islice
from typing import Self
from flax import nnx
from flax_typed import jit, value_and_grad
from jax import Array
import jax.numpy as jnp
import numpy as np
from numpy.typing import NDArray
import optax # type: ignore[import]
from sklearn.base import BaseEstimator
from fair_forge.methods import FairnessType
from fair_forge.nn.utils import grad_reverse, iterate_forever
from fair_forge.preprocessing.definitions import GroupBasedTransform
from fair_forge.utils import batched
__all__ = ["Beutel"]
class Block(nnx.Module):
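    """A linear layer followed by a sigmoid activation."""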
def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
def __call__(self, x: Array) -> Array:
out = self.linear(x)
out = nnx.sigmoid(out)
return out
class Encoder(nnx.Module):
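    """Encoder that maps the input features to the learned representation."""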
def __init__(self, enc_size: Sequence[int], init_size: int, *, rngs: nnx.Rngs):
layers: list[Block] = []
        if not enc_size:  # An empty enc_size yields a single block that keeps the input width.
layers.append(Block(init_size, init_size, rngs=rngs))
else:
layers.append(Block(init_size, enc_size[0], rngs=rngs))
for k in range(len(enc_size) - 1):
layers.append(Block(enc_size[k], enc_size[k + 1], rngs=rngs))
self.encoder = nnx.Sequential(*layers)
def __call__(self, x: Array) -> Array:
return self.encoder(x)
class Adversary(nnx.Module):
"""Adversary of the GAN."""
def __init__(
self,
fairness: FairnessType,
adv_size: Sequence[int],
init_size: int,
s_size: int,
adv_weight: float,
*,
rngs: nnx.Rngs,
):
self.fairness = fairness
self.init_size = init_size
self.adv_weight = adv_weight
layers: list[Block | nnx.Linear] = []
        if not adv_size:  # An empty adv_size yields a single linear layer straight to s_size.
layers.append(nnx.Linear(init_size, s_size, rngs=rngs))
else:
layers.append(Block(init_size, adv_size[0], rngs=rngs))
for k in range(len(adv_size) - 1):
layers.append(Block(adv_size[k], adv_size[k + 1], rngs=rngs))
layers.append(nnx.Linear(adv_size[-1], s_size, rngs=rngs))
self.adversary = nnx.Sequential(*layers)
def __call__(self, x: Array) -> Array:
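        # grad_reverse is the identity on the forward pass; on the backward
        # pass it negates the gradient (scaled by adv_weight), so the encoder
        # upstream is trained to make the sensitive attribute hard to predict.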
x = grad_reverse(x, lambda_=self.adv_weight)
x = self.adversary(x)
return x
class Predictor(nnx.Module):
"""Predictor of the GAN."""
def __init__(
self,
pred_size: Sequence[int],
init_size: int,
class_label_size: int,
*,
rngs: nnx.Rngs,
):
        layers: list[Block | nnx.Linear] = []
        if not pred_size:  # An empty pred_size yields a single layer straight to the labels.
            # A plain linear layer: the loss expects logits, which Block's
            # sigmoid would squash.
            layers.append(nnx.Linear(init_size, class_label_size, rngs=rngs))
else:
layers.append(Block(init_size, pred_size[0], rngs=rngs))
for k in range(len(pred_size) - 1):
layers.append(Block(pred_size[k], pred_size[k + 1], rngs=rngs))
layers.append(nnx.Linear(pred_size[-1], class_label_size, rngs=rngs))
self.predictor = nnx.Sequential(*layers)
def __call__(self, x: Array) -> Array:
return self.predictor(x)
class Model(nnx.Module):
"""Whole GAN model."""
def __init__(
self,
enc_size: Sequence[int],
adv_size: Sequence[int],
pred_size: Sequence[int],
adv_weight: float,
fairness: FairnessType,
x_size: int,
s_size: int,
y_size: int,
*,
rngs: nnx.Rngs,
) -> None:
self.enc = Encoder(
enc_size=enc_size,
init_size=x_size,
rngs=rngs,
)
self.adv = Adversary(
fairness=fairness,
adv_size=adv_size,
init_size=enc_size[-1] if enc_size else x_size,
s_size=s_size,
adv_weight=adv_weight,
rngs=rngs,
)
self.pred = Predictor(
pred_size=pred_size,
init_size=enc_size[-1] if enc_size else x_size,
class_label_size=y_size,
rngs=rngs,
)
def __call__(self, x: Array) -> tuple[Array, Array, Array]:
encoded = self.enc(x)
s_hat = self.adv(encoded)
y_hat = self.pred(encoded)
return encoded, s_hat, y_hat
@dataclass
class Beutel(BaseEstimator, GroupBasedTransform):
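    """Adversarially learned fair representations, after Beutel et al.

    An encoder is trained jointly with a label predictor and an adversary
    that tries to recover the group label from the encoding; a
    gradient-reversal layer pushes the encoder to make that recovery hard,
    so the representation stays predictive of ``y`` while carrying little
    information about ``groups``.

    A minimal usage sketch (assumes ``X``, ``y`` and ``groups`` are NumPy
    arrays with matching first dimension)::

        beutel = Beutel(fairness=FairnessType.DP, iters=500)
        beutel.fit(X, y, groups=groups)
    """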
enc_size: list[int] = field(default_factory=lambda: [40])
adv_size: list[int] = field(default_factory=lambda: [40])
pred_size: list[int] = field(default_factory=lambda: [40])
adv_weight: float = 1.0
fairness: FairnessType = FairnessType.DP
batch_size: int = 64
iters: int = 500
random_state: int = 42
learning_rate: float = 0.005
momentum: float = 0.9
def fit(
self, X: NDArray[np.float32], y: NDArray[np.int32], *, groups: NDArray[np.int32]
) -> Self:
x_size = X.shape[1]
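        # One output logit suffices for binary targets; multiclass targets
        # get one logit per class.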
y_size = classes if (classes := len(np.unique(y))) > 2 else 1
s_size = n_groups if (n_groups := len(np.unique(groups))) > 2 else 1
def loss_fn(model: Model, x: Array, y: Array, s: Array) -> Array:
_, s_hat, y_hat = model(x)
if y_size > 1:
predictor_loss = optax.softmax_cross_entropy_with_integer_labels(
logits=y_hat, labels=y
).mean()
else:
                predictor_loss = optax.sigmoid_binary_cross_entropy(
                    logits=y_hat.squeeze(-1), labels=y
                ).mean()
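            # The mask selects which examples contribute to the adversary
            # loss: all of them for demographic parity, only the positive
            # class for equality of opportunity.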
match self.fairness:
case FairnessType.EQ_OPP:
mask = y > 0.5
case FairnessType.EQ_ODDS:
                    raise NotImplementedError("Equalized odds is not implemented yet.")
case FairnessType.DP:
mask = jnp.ones(s.shape, dtype=jnp.bool)
            if s_size > 1:
                adversary_loss = optax.softmax_cross_entropy_with_integer_labels(
                    logits=s_hat, labels=s
                ).mean(where=mask)
            else:
                adversary_loss = optax.sigmoid_binary_cross_entropy(
                    logits=s_hat.squeeze(-1), labels=s
                ).mean(where=mask)
loss = predictor_loss + adversary_loss
return loss
@jit
def train_step(
model: Model,
optimizer: nnx.Optimizer,
metrics: nnx.MultiMetric,
x: Array,
y: Array,
s: Array,
) -> None:
"""Train for a single step."""
grad_fn = value_and_grad(loss_fn, has_aux=False)
loss, grads = grad_fn(model, x, y, s)
metrics.update(loss=loss)
optimizer.update(grads)
model = Model(
enc_size=self.enc_size,
adv_size=self.adv_size,
pred_size=self.pred_size,
adv_weight=self.adv_weight,
fairness=self.fairness,
x_size=x_size,
y_size=y_size,
s_size=s_size,
rngs=nnx.Rngs(self.random_state),
)
        optimizer = nnx.Optimizer(model, optax.adamw(self.learning_rate, b1=self.momentum))
metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))
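        # iterate_forever cycles over seeded batches indefinitely; islice
        # below caps training at `iters` steps.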
dataloader = iterate_forever(
(jnp.asarray(X), jnp.asarray(y), jnp.asarray(groups)),
batch_size=self.batch_size,
seed=self.random_state,
)
        for X_batch, y_batch, groups_batch in islice(dataloader, self.iters):
train_step(model, optimizer, metrics, X_batch, y_batch, groups_batch)
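        # Only the fitted encoder is kept; it is all that is needed to map
        # new data to the learned representation.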
self.enc_ = model.enc
return self