class Split(Enum):
    """Enum for different split methods used in evaluation."""

    # Selects the `basic_split` strategy in `evaluate`.
    BASIC = "basic"
    # Selects the `proportional_split` strategy in `evaluate` (its default).
    PROPORTIONAL = "proportional"
[docs]defevaluate(dataset:GroupDataset,methods:Mapping[str,Method|GroupMethod],metrics:Sequence[Metric],group_metrics:Sequence[GroupMetric],*,preprocessor:Preprocessor|None=None,repeat:int=1,split:Split|SplitMethod=Split.PROPORTIONAL,seed:int=42,train_percentage:float=0.8,remove_score_suffix:bool=True,seed_methods:bool=True,)->pl.DataFrame:"""Evaluate methods on a dataset using specified metrics and group metrics."""result:list[dict[str,Any]]=[]forrepeat_indexinrange(repeat):split_seed=seed+repeat_indexsplit_method:SplitMethodmatchsplit:caseSplit.BASIC:split_method=basic_splitcaseSplit.PROPORTIONAL:split_method=proportional_splitcase_:split_method=splittrain_idx,test_idx=split_method(split_seed,train_percentage,target=dataset.target,groups=dataset.groups)train_x=dataset.data[train_idx]train_y=dataset.target[train_idx]train_groups=dataset.groups[train_idx]test_x=dataset.data[test_idx]test_y=dataset.target[test_idx]test_groups=dataset.groups[test_idx]ifpreprocessorisnotNone:train_x=preprocessor.fit(train_x,train_y).transform(train_x)test_x=preprocessor.transform(test_x)formethod_name,methodinmethods.items():row:dict[str,Any]={}row["method"]=method_namerow["repeat_index"]=repeat_indexrow["split_seed"]=split_seedifseed_methodsand"random_state"inmethod.get_params():# If the method has a `random_state` parameter, we set it.method.set_params(random_state=split_seed)# If a method requests `groups` in its metadata, we cast it to 
GroupMethod.if"groups"inmethod.get_metadata_routing().fit.requests:cast(GroupMethod,method).fit(train_x,train_y,groups=train_groups,)else:cast(Method,method).fit(train_x,train_y)predictions=method.predict(test_x)formetricinmetrics:metric_name=metric.__name__ifremove_score_suffixandmetric_name.endswith("_score"):metric_name=metric_name[:-6]score=metric(y_true=test_y,y_pred=predictions)row[metric_name]=scoreforgroup_metricingroup_metrics:group_metric_name=group_metric.__name__group_score=group_metric(y_true=test_y,y_pred=predictions,groups=test_groups)row[group_metric_name]=group_scoreresult.append(row)# Convert the result list to a Polars DataFrame.# We use `pl.Enum` to ensure the correct ordering of method names.method_names=pl.Enum(list(methods))returnpl.DataFrame(result,schema_overrides={"method":method_names,"repeat_index":pl.Int64,"split_seed":pl.Int64,},).sort("method","repeat_index")