def _easyrec_metrics_2_tzrec_metrics()

in tzrec/tools/convert_easyrec_config_to_tzrec_config.py [0:0]


    def _easyrec_metrics_2_tzrec_metrics(self, easyrec_metric):
        """Convert easyrec metric to tzrec metric.

        Args:
            easyrec_metric: easyrec metric message; its ``metric`` oneof selects
                which metric is converted.

        Returns:
            A ``metric_pb2.MetricConfig`` with the matching tzrec metric set.
            When the ``metric`` oneof is unset, or names a metric that is not
            handled here, an error is logged and the returned config is left
            without any metric set.
        """
        metric = metric_pb2.MetricConfig()
        # WhichOneof returns None when no member of the oneof is set, so the
        # name must not be passed to getattr() before it has been matched.
        metric_type = easyrec_metric.WhichOneof("metric")
        if metric_type == "auc":
            metric.auc.CopyFrom(metric_pb2.AUC())
        elif metric_type == "gauc":
            # gauc is the only metric whose easyrec settings are carried over:
            # uid_field becomes the tzrec grouping_key.
            easyrec_metric_ob = getattr(easyrec_metric, metric_type)
            tzrec_metric_ob = metric_pb2.GroupedAUC(
                grouping_key=easyrec_metric_ob.uid_field
            )
            metric.grouped_auc.CopyFrom(tzrec_metric_ob)
        elif metric_type == "recall_at_topk":
            # NOTE(review): no field of the easyrec message is copied here, so
            # the tzrec RecallAtK keeps its defaults -- confirm this is intended.
            metric.recall_at_k.CopyFrom(metric_pb2.RecallAtK())
        elif metric_type == "mean_absolute_error":
            metric.mean_absolute_error.CopyFrom(metric_pb2.MeanAbsoluteError())
        elif metric_type == "mean_squared_error":
            metric.mean_squared_error.CopyFrom(metric_pb2.MeanSquaredError())
        elif metric_type == "accuracy":
            metric.accuracy.CopyFrom(metric_pb2.Accuracy())
        else:
            # Reached for unsupported metric types and for an unset oneof.
            logger.error(
                f"{easyrec_metric} is not convert to tzrec metric, please adaptation"
            )
        return metric