in tabular/src/autogluon/tabular/models/catboost/catboost_utils.py [0:0]
def construct_custom_catboost_metric(metric, is_higher_better, needs_pred_proba, problem_type):
    """Pick the ``eval_metric`` handed to CatBoost for an AutoGluon metric.

    Resolution order:

    1. ``SOFTCLASS`` problems always get a ``SoftclassCustomMetric`` built with
       fixed settings; if the requested metric is not ``soft_log_loss`` a warning
       is logged and the request is ignored.
    2. If the metric name, problem type and prediction mode match one of the
       built-in rules below, the corresponding CatBoost metric string is returned.
    3. Otherwise ``metric_classes_dict[problem_type]`` is instantiated as a
       custom metric wrapping ``metric``.

    Args:
        metric: AutoGluon scorer. Its ``name`` selects the mapping; in the
            fallback case the object itself is passed to the custom metric class.
        is_higher_better: Forwarded to the custom metric class (fallback only).
        needs_pred_proba: Whether the metric is computed on probabilities. Some
            built-in mappings apply only when this is truthy, others only when
            it is falsy, and the regression ones regardless of it.
        problem_type: Problem-type constant such as BINARY, MULTICLASS,
            REGRESSION or SOFTCLASS.

    Returns:
        Either a CatBoost metric name (``str``) or a custom metric instance.
    """
    if problem_type == SOFTCLASS:
        from .catboost_softclass_utils import SoftclassCustomMetric
        requested = metric.name
        if requested != 'soft_log_loss':
            logger.warning("Setting metric=soft_log_loss, the only metric supported for softclass problem_type")
        # The caller's metric object is not forwarded; the softclass metric uses fixed settings.
        return SoftclassCustomMetric(metric=None, is_higher_better=True, needs_pred_proba=True)

    # Built-in CatBoost equivalents, checked in order (first match wins).
    # Each rule: (accepted AutoGluon metric names,
    #             required problem_type, or None for "any",
    #             required truthiness of needs_pred_proba, or None for "either",
    #             CatBoost eval_metric string).
    # NOTE(review): the macro/micro/weighted variants share the plain binary
    # CatBoost metric, so they are presumably only a proxy for the AutoGluon score.
    # The 'hints=skip_train~true' suffix on F1 is, per CatBoost docs, meant to skip
    # computing it on the training set — confirm against the CatBoost version in use.
    builtin_rules = (
        (('log_loss',), MULTICLASS, True, 'MultiClass'),
        (('accuracy',), None, None, 'Accuracy'),
        (('log_loss',), BINARY, True, 'Logloss'),
        (('roc_auc',), BINARY, True, 'AUC'),
        (('roc_auc_ovo_macro',), MULTICLASS, True, 'AUC:type=Mu'),
        (('f1', 'f1_macro', 'f1_micro', 'f1_weighted'), BINARY, False, 'F1:hints=skip_train~true'),
        (('balanced_accuracy',), BINARY, False, 'BalancedAccuracy'),
        (('recall', 'recall_macro', 'recall_micro', 'recall_weighted'), BINARY, False, 'Recall'),
        (('precision', 'precision_macro', 'precision_micro', 'precision_weighted'), BINARY, False, 'Precision'),
        (('mean_absolute_error',), REGRESSION, None, 'MAE'),
        (('mean_squared_error', 'root_mean_squared_error'), REGRESSION, None, 'RMSE'),
        (('median_absolute_error',), REGRESSION, None, 'MedianAbsoluteError'),
        (('r2',), REGRESSION, None, 'R2'),
    )

    metric_name = metric.name
    for names, required_type, required_proba, catboost_name in builtin_rules:
        if metric_name not in names:
            continue
        if required_type is not None and problem_type != required_type:
            continue
        if required_proba is not None and bool(needs_pred_proba) != required_proba:
            continue
        if metric_name == 'roc_auc_ovo_macro':
            # Multiclass OvO macro AUC has no exact CatBoost counterpart; AUC:type=Mu stands in.
            logger.warning(f'Metric {metric.name} is not supported by this model - using AUC:type=Mu instead')
        return catboost_name

    # No built-in equivalent: wrap the AutoGluon metric in the problem type's custom metric class.
    fallback_cls = metric_classes_dict[problem_type]
    return fallback_cls(metric=metric, is_higher_better=is_higher_better, needs_pred_proba=needs_pred_proba)