def compute_relations()

in relogic/pretrainkit/models/relationalsemparse/relational_transformer.py [0:0]


  def compute_relations(self, desc, enc_length, q_enc_length, c_enc_length, c_boundaries, t_boundaries):
    """Build the pairwise relation-id matrix over question, column and table positions.

    The encoded sequence is laid out as ``[question | columns | tables]``:
    positions ``[0, q_enc_length)`` are question positions, columns start at
    ``q_enc_length`` and tables start at ``q_enc_length + c_enc_length``.
    ``c_boundaries`` / ``t_boundaries`` hold offsets relative to the start of
    their own section; column (table) ``k`` covers
    ``[boundaries[k], boundaries[k + 1])`` within that section.

    Args:
      desc: Schema / linking description. Entries read: ``'sc_link'``
        (optional; ``'q_col_match'``, ``'q_tab_match'``), ``'cv_link'``
        (optional; ``'num_date_match'``, ``'cell_match'``), and, depending on
        the ``cc_*`` / ``ct_*`` / ``tc_*`` / ``tt_*`` flags on ``self``,
        ``'foreign_keys'``, ``'foreign_keys_tables'``, ``'column_to_table'``
        and ``'primary_keys'``.
      enc_length: Total number of encoded positions (size of each matrix axis).
      q_enc_length: Number of question positions.
      c_enc_length: Number of column positions.
      c_boundaries: Column span boundaries within the column section.
      t_boundaries: Table span boundaries within the table section.

    Returns:
      An ``(enc_length, enc_length)`` int64 ``np.ndarray`` whose entry
      ``[i, j]`` is ``self.relation_ids[name]`` for the relation from
      position ``i`` to position ``j``.

    Raises:
      KeyError: if a position in ``range(enc_length)`` is not covered by the
        question length or the column/table boundaries, if a ``desc`` entry
        needed for some pair is missing, or if a relation name is absent from
        ``self.relation_ids``.
    """
    sc_link = desc.get('sc_link', {'q_col_match': {}, 'q_tab_match': {}})
    cv_link = desc.get('cv_link', {'num_date_match': {}, 'cell_match': {}})

    # Link dictionaries consulted, in priority order, for question<->column and
    # question<->table pairs. They are looked up with keys
    # "<question index>,<column/table id>", and each hit's value is appended to
    # the relation prefix (e.g. "qc" + value).
    col_links = ((sc_link, "q_col_match"), (cv_link, "cell_match"), (cv_link, "num_date_match"))
    tab_links = ((sc_link, "q_tab_match"),)

    # Catalogue which things are where: position -> ('question',),
    # ('column', column_id) or ('table', table_id).
    loc_types = {}
    for i in range(q_enc_length):
      loc_types[i] = ('question',)

    c_base = q_enc_length
    for c_id, (c_start, c_end) in enumerate(zip(c_boundaries, c_boundaries[1:])):
      for i in range(c_start + c_base, c_end + c_base):
        loc_types[i] = ('column', c_id)
    t_base = q_enc_length + c_enc_length
    for t_id, (t_start, t_end) in enumerate(zip(t_boundaries, t_boundaries[1:])):
      for i in range(t_start + t_base, t_end + t_base):
        loc_types[i] = ('table', t_id)

    def linked_relation(prefix, default, key, links):
      # First link dictionary containing `key` wins; otherwise use `default`.
      # Dictionaries are fetched lazily so later ones are only touched if needed.
      for source, field in links:
        matches = source[field]
        if key in matches:
          return prefix + matches[key]
      return default

    def column_column_relation(i, j, col1, col2):
      # Relation between two column positions; later checks override earlier ones.
      if col1 == col2:
        return ('cc_dist', clamp(j - i, self.cc_max_dist))
      name = 'cc_default'
      if self.cc_foreign_key:
        foreign_keys = desc['foreign_keys']
        if foreign_keys.get(str(col1)) == col2:
          name = 'cc_foreign_key_forward'
        if foreign_keys.get(str(col2)) == col1:
          name = 'cc_foreign_key_backward'
      if (self.cc_table_match and
            desc['column_to_table'][str(col1)] == desc['column_to_table'][str(col2)]):
        name = 'cc_table_match'
      return name

    def column_table_relation(prefix, col, table, use_foreign_key, use_table_match):
      # Shared by the column->table ('ct') and table->column ('tc') directions;
      # later checks override earlier ones.
      name = prefix + '_default'
      if use_foreign_key and self.match_foreign_key(desc, col, table):
        name = prefix + '_foreign_key'
      if use_table_match:
        col_table = desc['column_to_table'][str(col)]
        if col_table == table:
          name = prefix + ('_primary_key' if col in desc['primary_keys'] else '_table_match')
        elif col_table is None:
          name = prefix + '_any_table'
      return name

    def table_table_relation(i, j, table1, table2):
      # Relation between two table positions.
      if table1 == table2:
        return ('tt_dist', clamp(j - i, self.tt_max_dist))
      if not self.tt_foreign_key:
        return 'tt_default'
      fk_tables = desc['foreign_keys_tables']
      forward = table2 in fk_tables.get(str(table1), ())
      backward = table1 in fk_tables.get(str(table2), ())
      if forward and backward:
        return 'tt_foreign_key_both'
      if forward:
        return 'tt_foreign_key_forward'
      if backward:
        return 'tt_foreign_key_backward'
      return 'tt_default'

    def relation_name(i, j):
      # Resolve the key of self.relation_ids describing the edge i -> j.
      # Link lookups use the column/table id stored in loc_types (not the
      # position offset inside the section), so multi-position columns/tables
      # are matched against the right schema item.
      i_type, j_type = loc_types[i], loc_types[j]
      i_kind, j_kind = i_type[0], j_type[0]
      if i_kind == 'question':
        if j_kind == 'question':
          return ('qq_dist', clamp(j - i, self.qq_max_dist))
        if j_kind == 'column':
          return linked_relation("qc", 'qc_default', f"{i},{j_type[1]}", col_links)
        return linked_relation("qt", 'qt_default', f"{i},{j_type[1]}", tab_links)
      if i_kind == 'column':
        if j_kind == 'question':
          return linked_relation("cq", 'cq_default', f"{j},{i_type[1]}", col_links)
        if j_kind == 'column':
          return column_column_relation(i, j, i_type[1], j_type[1])
        return column_table_relation('ct', i_type[1], j_type[1],
                                     self.ct_foreign_key, self.ct_table_match)
      # i is a table position.
      if j_kind == 'question':
        return linked_relation("tq", 'tq_default', f"{j},{i_type[1]}", tab_links)
      if j_kind == 'column':
        return column_table_relation('tc', j_type[1], i_type[1],
                                     self.tc_foreign_key, self.tc_table_match)
      return table_table_relation(i, j, i_type[1], j_type[1])

    relations = np.empty((enc_length, enc_length), dtype=np.int64)
    for i, j in itertools.product(range(enc_length), repeat=2):
      relations[i, j] = self.relation_ids[relation_name(i, j)]
    return relations