in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]
def _easyrec_metrics_2_tzrec_metrics(self, easyrec_metric):
    """Convert easyrec metric to tzrec metric.

    Args:
        easyrec_metric: an EasyRec eval-metric message whose ``metric`` oneof
            selects the metric type (``auc``, ``gauc``, ``recall_at_topk``,
            ``mean_absolute_error``, ``mean_squared_error`` or ``accuracy``).

    Returns:
        A ``metric_pb2.MetricConfig`` with the matching tzrec metric set.
        For an unsupported metric type, or when no metric in the oneof is
        set, an error is logged and an empty ``MetricConfig`` is returned.
    """
    metric = metric_pb2.MetricConfig()
    # WhichOneof returns the name of the populated field, or None when no
    # field of the oneof is set. The sub-message is read inside each branch
    # rather than up front, so a None here falls through to the
    # log-and-continue path instead of crashing in getattr().
    metric_type = easyrec_metric.WhichOneof("metric")
    if metric_type == "auc":
        metric.auc.CopyFrom(metric_pb2.AUC())
    elif metric_type == "gauc":
        tzrec_metric_ob = metric_pb2.GroupedAUC(
            grouping_key=easyrec_metric.gauc.uid_field
        )
        metric.grouped_auc.CopyFrom(tzrec_metric_ob)
    elif metric_type == "recall_at_topk":
        # Carry the configured k over; a bare RecallAtK() would silently
        # replace the user's EasyRec `topk` with tzrec's default.
        tzrec_metric_ob = metric_pb2.RecallAtK(
            top_k=easyrec_metric.recall_at_topk.topk
        )
        metric.recall_at_k.CopyFrom(tzrec_metric_ob)
    elif metric_type == "mean_absolute_error":
        metric.mean_absolute_error.CopyFrom(metric_pb2.MeanAbsoluteError())
    elif metric_type == "mean_squared_error":
        metric.mean_squared_error.CopyFrom(metric_pb2.MeanSquaredError())
    elif metric_type == "accuracy":
        metric.accuracy.CopyFrom(metric_pb2.Accuracy())
    else:
        # Unsupported metric type, or no metric set at all (metric_type is
        # None): best-effort conversion, so log and return an empty config.
        logger.error(
            f"{easyrec_metric} is not convert to tzrec metric, please adaptation"
        )
    return metric