"""Utility functions for neural networks in JAX."""fromcollections.abcimportGeneratorimportjaxfromjaximportArrayfromfair_forge.utilsimportbatched__all__=["grad_reverse","iterate_forever"]@jax.custom_vjpdefgrad_reverse(x:Array,lambda_:float)->Array:"""Gradient reversal layer for JAX."""returnxdef_grad_reverse_fwd(x:Array,lambda_:float)->tuple[Array,tuple[float]]:# Forward pass: just return x and save lambda_ for backwardreturnx,(lambda_,)def_grad_reverse_bwd(res:tuple[float],g:Array)->tuple[Array,None]:# Backward pass: reverse and scale the gradient(lambda_,)=resreturn(-g*lambda_,None)# None for lambda_ grad# Register the custom VJPgrad_reverse.defvjp(_grad_reverse_fwd,_grad_reverse_bwd)


def iterate_forever[T: Array, *S](
    data: tuple[T, *S],
    *,
    batch_size: int,
    seed: int = 0,
) -> Generator[tuple[T, *S], None, None]:
    """Yield batches of the data tuple forever.

    Use `itertools.islice` to limit the number of batches.
    """
    elem = data[0]
    assert all(d.shape[0] == elem.shape[0] for d in data), (  # type: ignore
        "All elements of data must have the same first dimension."
    )
    assert batch_size > 0, "Batch size must be greater than 0."
    assert batch_size <= elem.shape[0], (
        "Batch size must be less than or equal to the number of samples."
    )
    len_data = elem.shape[0]
    key = jax.random.key(seed)
    while True:
        # First generate shuffled indices.
        key, subkey = jax.random.split(key)
        shuffled_indices = jax.random.permutation(subkey, len_data)
        # Then yield the data in batches.
        for slice_ in batched(len_data, batch_size, drop_last=True):
            batch_indices = shuffled_indices[slice_]
            yield tuple(d[batch_indices] for d in data)  # type: ignore
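
# Usage sketch (the toy arrays are illustrative): take a bounded number of
# shuffled batches from the otherwise infinite generator with itertools.islice.
#
#     from itertools import islice
#
#     import jax.numpy as jnp
#
#     xs = jnp.arange(10.0).reshape(5, 2)
#     ys = jnp.arange(5)
#     for x_batch, y_batch in islice(iterate_forever((xs, ys), batch_size=2), 4):
#         ...  # x_batch: shape (2, 2); y_batch: shape (2,)
#
# Since batches are drawn with drop_last=True, each reshuffled epoch of 5
# samples yields two batches of 2 and drops the remainder, so the four batches
# above span two epochs.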