in src/engine/step4/model_dev/utils/metrics.py [0:0]
def execute_matching(tool, row, entities):
    """
    Execute the true and predicted query of a given row and compare their results.

    The reference query (``row["query"]``) and the predicted query
    (``row["preds_wiki"]``) are rendered with ``entities`` and executed through
    ``tool``. The two result sets are compared ignoring row order and, when both
    results expose the same set of unique column names, column order.

    Args:
        tool(nlq2SqlTool): Tool providing ``render_template_query`` and
            ``execute_sql_query``, used to render and run the queries.
        row: Row for a single query; must provide the ``"query"`` and
            ``"preds_wiki"`` fields.
        entities(dict): Dictionary of specific values for the query arguments.

    Returns:
        float: 1.0 if the predicted query is textually identical to the true
        query, or if it can be executed and returns the same rows as the true
        query; otherwise 0.0.

    Raises:
        Any exception raised while rendering or executing the *true* query is
        propagated unchanged; only failures of the predicted query are scored
        as a miss.
    """
    true_query = row["query"]
    pred_query = row["preds_wiki"]

    # Render & execute the target first, outside any try block, so a broken
    # reference query surfaces as an error instead of silently scoring 0.
    true_sql_query = tool.render_template_query(true_query, entities)
    true_output = tool.execute_sql_query(true_sql_query)

    # An identical template is a match by definition; no need to run it twice.
    if true_query == pred_query:
        return 1.0

    # Render & execute the prediction. Any failure (unrenderable template,
    # invalid SQL, ...) counts as a miss. ``Exception`` instead of a bare
    # ``except`` so KeyboardInterrupt/SystemExit still stop the evaluation.
    try:
        pred_sql_query = tool.render_template_query(pred_query, entities)
        pred_output = tool.execute_sql_query(pred_sql_query)
    except Exception:
        return 0.0
    if pred_output is None:
        return 0.0

    # Align the prediction's column order with the target's when both expose
    # the same set of *unique* column names. With duplicated labels, selecting
    # by label would duplicate columns, so those are compared positionally.
    true_columns = list(true_output.columns)
    pred_columns = list(pred_output.columns)
    if (
        len(true_columns) == len(pred_columns) == len(set(true_columns))
        and set(true_columns) == set(pred_columns)
    ):
        pred_output = pred_output[true_columns]

    true_rows = _canonical_rows(true_output)
    pred_rows = _canonical_rows(pred_output)
    # Shape mismatches (different row/column counts) make this False.
    return 1.0 if np.array_equal(true_rows, pred_rows) else 0.0


def _canonical_rows(frame):
    """Return ``frame``'s cells as an object array in a row-order-independent form.

    Rows are sorted lexicographically on every column, addressed by position so
    duplicate column labels are supported. Sorting on all columns (not just the
    first) makes rows that tie on the first column land in a deterministic order.
    Missing cells (NaN/None/NA) are replaced by ``None`` so that two missing
    values compare equal under ``np.array_equal``.
    """
    positional = frame.copy()
    positional.columns = range(positional.shape[1])
    if positional.shape[1]:
        # sort_values places missing values last, so both frames agree on them.
        positional = positional.sort_values(by=list(positional.columns))
    rows = positional.to_numpy(dtype=object)
    # NaN != NaN; a shared None placeholder makes missing cells compare equal.
    rows[positional.isna().to_numpy()] = None
    return rows