fair_forge.nn.utils

Utility functions for neural networks in JAX.

Functions

iterate_forever(data, *, batch_size[, seed])

Yield batches of the data tuple forever.

fair_forge.nn.utils.iterate_forever(data: tuple[T, Unpack[S]], *, batch_size: int, seed: int = 0) -> Generator[tuple[T, Unpack[S]], None, None]

Yield batches of the data tuple forever.

The generator never terminates on its own; wrap it in itertools.islice to limit the number of batches.
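
The islice pattern can be illustrated with a minimal sketch. To keep the example self-contained, it uses a hypothetical stand-in, toy_iterate_forever, which only mirrors the signature shown above. It is not fair_forge's implementation. Reshuffling on every pass and dropping the incomplete final batch are assumptions made for illustration, and plain Python lists stand in for JAX arrays.

```python
import itertools
import random
from collections.abc import Iterator, Sequence


def toy_iterate_forever(
    data: tuple[Sequence, ...], *, batch_size: int, seed: int = 0
) -> Iterator[tuple[list, ...]]:
    """Hypothetical stand-in with the same call shape; NOT the library's code."""
    n = len(data[0])
    rng = random.Random(seed)
    while True:  # never terminates: the caller decides when to stop
        # Assumption: reshuffle the row order on every pass over the data.
        indices = list(range(n))
        rng.shuffle(indices)
        # Assumption: skip the trailing batch if it would be incomplete.
        for start in range(0, n - batch_size + 1, batch_size):
            batch_idx = indices[start:start + batch_size]
            # Apply the same row selection to every element of the tuple.
            yield tuple([column[i] for i in batch_idx] for column in data)


features = list(range(10))
labels = [x % 2 for x in features]

# islice stops the otherwise endless generator after 5 batches.
batches = list(
    itertools.islice(toy_iterate_forever((features, labels), batch_size=4), 5)
)
```

With the real function, the equivalent call would presumably be itertools.islice(iterate_forever(data, batch_size=..., seed=...), n), where n is the number of batches wanted, for example the number of training steps.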