"""Utility functions for Fair Forge."""fromcollections.abcimportGeneratorimportnumpyasnp__all__=["batched","reproducible_random_state"]
def reproducible_random_state(seed: int) -> np.random.Generator:
    """Build a seeded NumPy ``Generator`` whose stream is reproducible across Python versions and platforms.

    Args:
        seed: Seed for the underlying ``MT19937`` bit generator.

    Returns:
        A ``np.random.Generator`` wrapping a freshly seeded ``MT19937`` instance,
        so two calls with the same seed produce identical random streams.
    """
    # MT19937 is not the strongest generator available, but its output is
    # reproducible, which is the property this helper exists to provide.
    bit_generator = np.random.MT19937(seed)
    return np.random.Generator(bit_generator)
def batched(
    len_data: int, batch_size: int, *, drop_last: bool = False
) -> Generator[slice, None, None]:
    """Yield slices of indices for batching data.

    The yielded slices are consecutive and non-overlapping, and together they
    cover ``range(len_data)`` in order. Each slice spans ``batch_size`` indices,
    except possibly the final one, which is shorter when ``len_data`` is not a
    multiple of ``batch_size``.

    Args:
        len_data: The total number of data points. Values <= 0 yield nothing.
        batch_size: The size of each batch. Must be a positive integer.
        drop_last: If True, the last batch will be dropped if it is smaller
            than batch_size.

    Yields:
        A ``slice(start, end)`` selecting the indices of one batch.

    Raises:
        ValueError: If ``batch_size`` is not positive. As with any generator,
            this is raised when iteration begins, not when the function is called.
    """
    if batch_size <= 0:
        # range() would give a cryptic error for 0 and silently yield nothing
        # for negative sizes; reject both with an explicit message.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # When dropping the short tail, stop at the last multiple of batch_size so
    # that only complete batches are produced.
    stop = len_data - len_data % batch_size if drop_last else len_data
    for start in range(0, stop, batch_size):
        # Only the final batch can overshoot the data, so clamp its end.
        yield slice(start, min(start + batch_size, len_data))