in rat-sql-gap/seq2struct/models/spider/spider_enc_modules.py [0:0]
def compute_relations(self, desc, enc_length, q_enc_length, c_enc_length, c_boundaries, t_boundaries):
    """Build the matrix of relation ids between every pair of encoder positions.

    Positions are laid out as ``[question | columns | tables]``: the first
    ``q_enc_length`` positions are question positions, column spans start at
    offset ``q_enc_length`` and table spans start at offset
    ``q_enc_length + c_enc_length``.

    Args:
        desc: Mapping describing the example. Optional keys ``'sc_link'``
            (sub-keys ``'q_col_match'``, ``'q_tab_match'``) and ``'cv_link'``
            (sub-keys ``'cell_match'``, ``'num_date_match'``) hold match
            tables keyed by the string ``"<question pos>,<pos inside the
            column/table section>"``; each value is a string suffix that is
            appended to the ``qc``/``cq``/``qt``/``tq`` prefix to form the
            relation name. ``'foreign_keys'``, ``'foreign_keys_tables'``,
            ``'column_to_table'`` and ``'primary_keys'`` are read only when
            the corresponding ``self.*_foreign_key`` / ``self.*_table_match``
            option is truthy.
        enc_length: Total number of positions; the result is square with
            this side length.
        q_enc_length: Number of question positions.
        c_enc_length: Length of the column section (used as the offset of
            the table section).
        c_boundaries: Span boundaries of the columns, relative to the start
            of the column section; consecutive pairs delimit column ``c_id``.
        t_boundaries: Same as ``c_boundaries`` but for tables.

    Returns:
        ``np.int64`` array of shape ``(enc_length, enc_length)`` where entry
        ``[i, j]`` is ``self.relation_ids[name]`` for the last relation name
        selected for the pair ``(i, j)``.

    Raises:
        KeyError: if a position below ``enc_length`` is not covered by the
            question range or by any column/table span.
    """
    sc_link = desc.get('sc_link', {'q_col_match': {}, 'q_tab_match': {}})
    cv_link = desc.get('cv_link', {'num_date_match': {}, 'cell_match': {}})

    # Catalogue which things are where: position -> ('question',),
    # ('column', column id) or ('table', table id).
    loc_types = {}
    for pos in range(q_enc_length):
        loc_types[pos] = ('question',)

    c_base = q_enc_length
    for c_id, (c_start, c_end) in enumerate(zip(c_boundaries, c_boundaries[1:])):
        for pos in range(c_start + c_base, c_end + c_base):
            loc_types[pos] = ('column', c_id)

    t_base = q_enc_length + c_enc_length
    for t_id, (t_start, t_end) in enumerate(zip(t_boundaries, t_boundaries[1:])):
        for pos in range(t_start + t_base, t_end + t_base):
            loc_types[pos] = ('table', t_id)

    # np.empty leaves the buffer uninitialised; every cell is written below
    # because each (i, j) pair falls into exactly one of the nine type pairs.
    relations = np.empty((enc_length, enc_length), dtype=np.int64)

    # Defined once (not per iteration) and given the cell explicitly, so it
    # does not depend on late-bound loop variables.
    def set_relation(row, col, name):
        relations[row, col] = self.relation_ids[name]

    # Within a branch, later set_relation calls deliberately overwrite
    # earlier ones: the most specific relation that applies wins.
    for i, j in itertools.product(range(enc_length), repeat=2):
        i_type, j_type = loc_types[i], loc_types[j]

        if i_type[0] == 'question':
            if j_type[0] == 'question':
                set_relation(i, j, ('qq_dist', clamp(j - i, self.qq_max_dist)))
            elif j_type[0] == 'column':
                # Link tables are keyed by "<question pos>,<column-section pos>".
                link_key = f"{i},{j - c_base}"
                if link_key in sc_link["q_col_match"]:
                    set_relation(i, j, "qc" + sc_link["q_col_match"][link_key])
                elif link_key in cv_link["cell_match"]:
                    set_relation(i, j, "qc" + cv_link["cell_match"][link_key])
                elif link_key in cv_link["num_date_match"]:
                    set_relation(i, j, "qc" + cv_link["num_date_match"][link_key])
                else:
                    set_relation(i, j, 'qc_default')
            elif j_type[0] == 'table':
                link_key = f"{i},{j - t_base}"
                if link_key in sc_link["q_tab_match"]:
                    set_relation(i, j, "qt" + sc_link["q_tab_match"][link_key])
                else:
                    set_relation(i, j, 'qt_default')

        elif i_type[0] == 'column':
            if j_type[0] == 'question':
                # Same key layout as the question->column case: the question
                # position always comes first.
                link_key = f"{j},{i - c_base}"
                if link_key in sc_link["q_col_match"]:
                    set_relation(i, j, "cq" + sc_link["q_col_match"][link_key])
                elif link_key in cv_link["cell_match"]:
                    set_relation(i, j, "cq" + cv_link["cell_match"][link_key])
                elif link_key in cv_link["num_date_match"]:
                    set_relation(i, j, "cq" + cv_link["num_date_match"][link_key])
                else:
                    set_relation(i, j, 'cq_default')
            elif j_type[0] == 'column':
                col1, col2 = i_type[1], j_type[1]
                if col1 == col2:
                    # Two positions inside the same column span.
                    set_relation(i, j, ('cc_dist', clamp(j - i, self.cc_max_dist)))
                else:
                    set_relation(i, j, 'cc_default')
                    if self.cc_foreign_key:
                        if desc['foreign_keys'].get(str(col1)) == col2:
                            set_relation(i, j, 'cc_foreign_key_forward')
                        if desc['foreign_keys'].get(str(col2)) == col1:
                            set_relation(i, j, 'cc_foreign_key_backward')
                    if (self.cc_table_match and
                            desc['column_to_table'][str(col1)] == desc['column_to_table'][str(col2)]):
                        set_relation(i, j, 'cc_table_match')
            elif j_type[0] == 'table':
                col, table = i_type[1], j_type[1]
                set_relation(i, j, 'ct_default')
                if self.ct_foreign_key and self.match_foreign_key(desc, col, table):
                    set_relation(i, j, 'ct_foreign_key')
                if self.ct_table_match:
                    col_table = desc['column_to_table'][str(col)]
                    if col_table == table:
                        if col in desc['primary_keys']:
                            set_relation(i, j, 'ct_primary_key')
                        else:
                            set_relation(i, j, 'ct_table_match')
                    elif col_table is None:
                        # Column that is not attached to any table.
                        set_relation(i, j, 'ct_any_table')

        elif i_type[0] == 'table':
            if j_type[0] == 'question':
                link_key = f"{j},{i - t_base}"
                if link_key in sc_link["q_tab_match"]:
                    set_relation(i, j, "tq" + sc_link["q_tab_match"][link_key])
                else:
                    set_relation(i, j, 'tq_default')
            elif j_type[0] == 'column':
                table, col = i_type[1], j_type[1]
                set_relation(i, j, 'tc_default')
                if self.tc_foreign_key and self.match_foreign_key(desc, col, table):
                    set_relation(i, j, 'tc_foreign_key')
                if self.tc_table_match:
                    col_table = desc['column_to_table'][str(col)]
                    if col_table == table:
                        if col in desc['primary_keys']:
                            set_relation(i, j, 'tc_primary_key')
                        else:
                            set_relation(i, j, 'tc_table_match')
                    elif col_table is None:
                        # Column that is not attached to any table.
                        set_relation(i, j, 'tc_any_table')
            elif j_type[0] == 'table':
                table1, table2 = i_type[1], j_type[1]
                if table1 == table2:
                    # Two positions inside the same table span.
                    set_relation(i, j, ('tt_dist', clamp(j - i, self.tt_max_dist)))
                else:
                    set_relation(i, j, 'tt_default')
                    if self.tt_foreign_key:
                        forward = table2 in desc['foreign_keys_tables'].get(str(table1), ())
                        backward = table1 in desc['foreign_keys_tables'].get(str(table2), ())
                        if forward and backward:
                            set_relation(i, j, 'tt_foreign_key_both')
                        elif forward:
                            set_relation(i, j, 'tt_foreign_key_forward')
                        elif backward:
                            set_relation(i, j, 'tt_foreign_key_backward')

    return relations