class SplitMethod(Protocol):
    """Structural type for callables that split a dataset into train and test indices.

    Any callable with the signature of :meth:`__call__` satisfies this protocol.
    """

    def __call__(
        self,
        seed: int,
        train_percentage: float,
        *,
        target: NDArray[np.int32],
        groups: NDArray[np.int32],
    ) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Produce the index arrays of the train and test partitions.

        :param seed: Seed for the random state used by the split.
        :param train_percentage: Share of the samples that should end up in the train
            partition (the split functions in this module multiply it by a sample count).
        :param target: Integer array of target values, passed by keyword.
        :param groups: Integer array of group values, passed by keyword.
        :returns: A pair ``(train_indices, test_indices)`` of integer index arrays.
        """
        ...
def basic_split(
    seed: int,
    train_percentage: float,
    *,
    target: NDArray[np.int32],
    groups: NDArray[np.int32],
) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
    """Shuffle all sample positions once and cut the result into a train and a test part.

    :param seed: Passed to ``reproducible_random_state`` to create the shuffling generator.
    :param train_percentage: Multiplied by the number of samples and rounded to obtain the
        size of the train part.
    :param target: Only its length is used; it determines the number of samples.
    :param groups: Not used by this split; accepted so that the signature matches
        ``SplitMethod``.
    :returns: ``(train_indices, test_indices)``: the leading ``train_size`` entries of the
        shuffled positions, followed by all remaining entries.
    """
    num_samples = len(target)
    cut = round(num_samples * train_percentage)
    positions = np.arange(num_samples, dtype=np.int64)

    rng = reproducible_random_state(seed)
    rng.shuffle(positions)  # in-place permutation of 0..num_samples-1

    return positions[:cut], positions[cut:]
def proportional_split(
    seed: int,
    train_percentage: float,
    *,
    target: NDArray[np.int32],
    groups: NDArray[np.int32],
) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
    """Generate the indices of the train and test splits using a proportional sampling scheme.

    Every combination of a unique value in ``groups`` and a unique value in ``target`` is
    treated as its own stratum. The indices of each stratum are shuffled and the first
    ``round(len(stratum) * train_percentage)`` of them go to the train split, the rest to
    the test split. The per-stratum pieces are concatenated in ``itertools.product`` order,
    so the returned arrays are grouped by stratum rather than globally shuffled.

    :param seed: Passed to ``reproducible_random_state`` to create the shuffling generator.
    :param train_percentage: Share of each stratum that is assigned to the train split.
    :param target: Target value of every sample.
    :param groups: Group value of every sample; compared element-wise alongside ``target``.
    :returns: ``(train_indices, test_indices)``. For an empty ``target`` both arrays are
        empty ``int64`` arrays.
    :raises AssertionError: If the resulting train size deviates from
        ``round(len(target) * train_percentage)`` by more than the number of strata.
    """
    if len(target) == 0:
        # Nothing to split. Without this guard no stratum is produced and
        # ``np.concatenate`` below would raise a ValueError on the empty list.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    s_vals: list[np.int32] = list(np.unique(groups))
    y_vals: list[np.int32] = list(np.unique(target))

    train_parts: list[NDArray[np.int64]] = []
    test_parts: list[NDArray[np.int64]] = []

    # NOTE(review): per the original comment this is a local random state that does not
    # affect the global state; the helper is defined outside this block -- confirm there.
    generator = reproducible_random_state(seed)

    # iterate over all combinations of s and y
    for s, y in itertools.product(s_vals, y_vals):
        # find all indices for this stratum (may be empty if the combination never occurs)
        idx = np.nonzero((groups == s) & (target == y))[0]
        # shuffle in place and cut at the per-stratum train size
        generator.shuffle(idx)
        split_at: int = round(len(idx) * train_percentage)
        train_parts.append(idx[:split_at])
        test_parts.append(idx[split_at:])

    train_indices_ = np.concatenate(train_parts, axis=0)
    test_indices_ = np.concatenate(test_parts, axis=0)
    # the per-stratum pieces are no longer needed once concatenated
    del train_parts
    del test_parts

    num_groups = len(s_vals) * len(y_vals)
    expected_train_len = round(len(target) * train_percentage)
    # assert that we (at least approximately) achieved the specified `train_percentage`
    # the maximum error occurs when all the group splits favor train or all favor test
    assert (
        expected_train_len - num_groups
        <= len(train_indices_)
        <= expected_train_len + num_groups
    ), "train split size deviates from the requested train_percentage by more than expected"
    return train_indices_, test_indices_