in rat-sql-gap/seq2struct/models/spider/spider_match_utils.py [0:0]
def compute_schema_linking(question, column, table):
    """Link question n-grams to schema columns and tables by string matching.

    Every n-gram of ``question`` (from 5 tokens down to 1) is joined with
    single spaces and compared against each column and table name, which are
    joined the same way.

    Args:
        question: Sequence of question tokens (strings).
        column: Sequence of column names, each a sequence of string tokens.
            The entry at index 0 is never linked (presumably the ``*``
            wildcard column; confirm against the caller).
        table: Sequence of table names, each a sequence of string tokens.

    Returns:
        A dict with two entries, ``"q_col_match"`` and ``"q_tab_match"``.
        Each maps the string key ``"<question index>,<column or table index>"``
        to a tag:

        * ``"CEM"`` / ``"TEM"``: the n-gram equals the whole column / table
          name.
        * ``"CPM"`` / ``"TPM"``: the n-gram occurs inside the column / table
          name on word boundaries. N-grams found in ``STOPWORDS`` or
          ``PUNKS`` never produce a partial match.

        A partial-match tag never replaces a tag that is already recorded
        for the same key; an exact-match tag is always written.
    """
    q_col_match = {}
    q_tab_match = {}

    # Join every schema name once instead of once per n-gram comparison.
    col_names = {
        col_id: " ".join(col_item)
        for col_id, col_item in enumerate(column)
        if col_id != 0
    }
    tab_names = {tab_id: " ".join(tab_item) for tab_id, tab_item in enumerate(table)}

    # Longest n-grams first, so that tags from longer spans are recorded
    # before shorter spans are considered.
    for n in range(5, 0, -1):
        for i in range(len(question) - n + 1):
            n_gram = " ".join(question[i:i + n])
            if not n_gram.strip():
                continue
            span = range(i, i + n)

            # Exact match case.
            for col_id, col_name in col_names.items():
                if n_gram == col_name:
                    for q_id in span:
                        q_col_match[f"{q_id},{col_id}"] = "CEM"
            for tab_id, tab_name in tab_names.items():
                if n_gram == tab_name:
                    for q_id in span:
                        q_tab_match[f"{q_id},{tab_id}"] = "TEM"

            # Partial match case: stop words and punctuation carry no
            # linking signal, so they are skipped.
            if n_gram in STOPWORDS or n_gram in PUNKS:
                continue
            # search(), not match(): the n-gram may sit anywhere inside the
            # schema name, not only at its start. match() would only link
            # n-grams that are a prefix of the name.
            pattern = re.compile(rf"\b{re.escape(n_gram)}\b")
            for col_id, col_name in col_names.items():
                if pattern.search(col_name):
                    for q_id in span:
                        q_col_match.setdefault(f"{q_id},{col_id}", "CPM")
            for tab_id, tab_name in tab_names.items():
                if pattern.search(tab_name):
                    for q_id in span:
                        q_tab_match.setdefault(f"{q_id},{tab_id}", "TPM")

    return {"q_col_match": q_col_match, "q_tab_match": q_tab_match}