rat-sql-gap/seq2struct/models/spider/spider_enc_modules.py [312:449]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        self.num_heads = num_heads

        self.qq_max_dist    = qq_max_dist
        #self.qc_token_match = qc_token_match
        #self.qt_token_match = qt_token_match
        #self.cq_token_match = cq_token_match
        self.cc_foreign_key = cc_foreign_key
        self.cc_table_match = cc_table_match
        self.cc_max_dist    = cc_max_dist
        self.ct_foreign_key = ct_foreign_key
        self.ct_table_match = ct_table_match
        #self.tq_token_match = tq_token_match
        self.tc_table_match = tc_table_match
        self.tc_foreign_key = tc_foreign_key
        self.tt_max_dist    = tt_max_dist
        self.tt_foreign_key = tt_foreign_key

        self.relation_ids = {}
        def add_relation(name):
            self.relation_ids[name] = len(self.relation_ids)
        def add_rel_dist(name, max_dist):
            for i in range(-max_dist, max_dist + 1):
                add_relation((name, i))

        add_rel_dist('qq_dist', qq_max_dist)

        add_relation('qc_default')
        #if qc_token_match:
        #    add_relation('qc_token_match')

        add_relation('qt_default')
        #if qt_token_match:
        #    add_relation('qt_token_match')

        add_relation('cq_default')
        #if cq_token_match:
        #    add_relation('cq_token_match')

        add_relation('cc_default')
        if cc_foreign_key:
            add_relation('cc_foreign_key_forward')
            add_relation('cc_foreign_key_backward')
        if cc_table_match:
            add_relation('cc_table_match')
        add_rel_dist('cc_dist', cc_max_dist)

        add_relation('ct_default')
        if ct_foreign_key:
            add_relation('ct_foreign_key')
        if ct_table_match:
            add_relation('ct_primary_key')
            add_relation('ct_table_match')
            add_relation('ct_any_table')

        add_relation('tq_default')
        #if cq_token_match:
        #    add_relation('tq_token_match')

        add_relation('tc_default')
        if tc_table_match:
            add_relation('tc_primary_key')
            add_relation('tc_table_match')
            add_relation('tc_any_table')
        if tc_foreign_key:
            add_relation('tc_foreign_key')

        add_relation('tt_default')
        if tt_foreign_key:
            add_relation('tt_foreign_key_forward')
            add_relation('tt_foreign_key_backward')
            add_relation('tt_foreign_key_both')
        add_rel_dist('tt_dist', tt_max_dist)

        # schema linking relations
        # forward_backward
        if sc_link:
            add_relation('qcCEM')
            add_relation('cqCEM')
            add_relation('qtTEM')
            add_relation('tqTEM')
            add_relation('qcCPM')
            add_relation('cqCPM')
            add_relation('qtTPM')
            add_relation('tqTPM')
        
        if cv_link:
            add_relation("qcNUMBER")
            add_relation("cqNUMBER")
            add_relation("qcTIME")
            add_relation("cqTIME")
            add_relation("qcCELLMATCH")
            add_relation("cqCELLMATCH")

        if merge_types:
            assert not cc_foreign_key
            assert not cc_table_match
            assert not ct_foreign_key
            assert not ct_table_match
            assert not tc_foreign_key
            assert not tc_table_match
            assert not tt_foreign_key

            assert cc_max_dist == qq_max_dist
            assert tt_max_dist == qq_max_dist

            add_relation('xx_default')
            self.relation_ids['qc_default'] = self.relation_ids['xx_default']
            self.relation_ids['qt_default'] = self.relation_ids['xx_default']
            self.relation_ids['cq_default'] = self.relation_ids['xx_default']
            self.relation_ids['cc_default'] = self.relation_ids['xx_default']
            self.relation_ids['ct_default'] = self.relation_ids['xx_default']
            self.relation_ids['tq_default'] = self.relation_ids['xx_default']
            self.relation_ids['tc_default'] = self.relation_ids['xx_default']
            self.relation_ids['tt_default'] = self.relation_ids['xx_default']

            if sc_link:
                self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
                self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
                self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
                self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
                self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
            if cv_link:
                self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
                self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
                self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
                self.relation_ids["qcCELLMATCH"] = self.relation_ids['xx_default']
                self.relation_ids["cqCELLMATCH"] = self.relation_ids['xx_default']

            for i in range(-qq_max_dist, qq_max_dist + 1):
                self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
                self.relation_ids['tt_dist', i] = self.relation_ids['tt_dist', i]

        if ff_size is None:
            ff_size = hidden_size * 4
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



relogic/pretrainkit/models/relationalsemparse/relational_transformer.py [414:553]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    self.num_heads = num_heads

    self.qq_max_dist = qq_max_dist
    # self.qc_token_match = qc_token_match
    # self.qt_token_match = qt_token_match
    # self.cq_token_match = cq_token_match
    self.cc_foreign_key = cc_foreign_key
    self.cc_table_match = cc_table_match
    self.cc_max_dist = cc_max_dist
    self.ct_foreign_key = ct_foreign_key
    self.ct_table_match = ct_table_match
    # self.tq_token_match = tq_token_match
    self.tc_table_match = tc_table_match
    self.tc_foreign_key = tc_foreign_key
    self.tt_max_dist = tt_max_dist
    self.tt_foreign_key = tt_foreign_key

    self.relation_ids = {}

    def add_relation(name):
      self.relation_ids[name] = len(self.relation_ids)

    def add_rel_dist(name, max_dist):
      for i in range(-max_dist, max_dist + 1):
        add_relation((name, i))

    add_rel_dist('qq_dist', qq_max_dist)

    add_relation('qc_default')
    # if qc_token_match:
    #    add_relation('qc_token_match')

    add_relation('qt_default')
    # if qt_token_match:
    #    add_relation('qt_token_match')

    add_relation('cq_default')
    # if cq_token_match:
    #    add_relation('cq_token_match')

    add_relation('cc_default')
    if cc_foreign_key:
      add_relation('cc_foreign_key_forward')
      add_relation('cc_foreign_key_backward')
    if cc_table_match:
      add_relation('cc_table_match')
    add_rel_dist('cc_dist', cc_max_dist)

    add_relation('ct_default')
    if ct_foreign_key:
      add_relation('ct_foreign_key')
    if ct_table_match:
      add_relation('ct_primary_key')
      add_relation('ct_table_match')
      add_relation('ct_any_table')

    add_relation('tq_default')
    # if cq_token_match:
    #    add_relation('tq_token_match')

    add_relation('tc_default')
    if tc_table_match:
      add_relation('tc_primary_key')
      add_relation('tc_table_match')
      add_relation('tc_any_table')
    if tc_foreign_key:
      add_relation('tc_foreign_key')

    add_relation('tt_default')
    if tt_foreign_key:
      add_relation('tt_foreign_key_forward')
      add_relation('tt_foreign_key_backward')
      add_relation('tt_foreign_key_both')
    add_rel_dist('tt_dist', tt_max_dist)

    # schema linking relations
    # forward_backward
    if sc_link:
      add_relation('qcCEM')
      add_relation('cqCEM')
      add_relation('qtTEM')
      add_relation('tqTEM')
      add_relation('qcCPM')
      add_relation('cqCPM')
      add_relation('qtTPM')
      add_relation('tqTPM')

    if cv_link:
      add_relation("qcNUMBER")
      add_relation("cqNUMBER")
      add_relation("qcTIME")
      add_relation("cqTIME")
      add_relation("qcCELLMATCH")
      add_relation("cqCELLMATCH")

    if merge_types:
      assert not cc_foreign_key
      assert not cc_table_match
      assert not ct_foreign_key
      assert not ct_table_match
      assert not tc_foreign_key
      assert not tc_table_match
      assert not tt_foreign_key

      assert cc_max_dist == qq_max_dist
      assert tt_max_dist == qq_max_dist

      add_relation('xx_default')
      self.relation_ids['qc_default'] = self.relation_ids['xx_default']
      self.relation_ids['qt_default'] = self.relation_ids['xx_default']
      self.relation_ids['cq_default'] = self.relation_ids['xx_default']
      self.relation_ids['cc_default'] = self.relation_ids['xx_default']
      self.relation_ids['ct_default'] = self.relation_ids['xx_default']
      self.relation_ids['tq_default'] = self.relation_ids['xx_default']
      self.relation_ids['tc_default'] = self.relation_ids['xx_default']
      self.relation_ids['tt_default'] = self.relation_ids['xx_default']

      if sc_link:
        self.relation_ids['qcCEM'] = self.relation_ids['xx_default']
        self.relation_ids['qcCPM'] = self.relation_ids['xx_default']
        self.relation_ids['qtTEM'] = self.relation_ids['xx_default']
        self.relation_ids['qtTPM'] = self.relation_ids['xx_default']
        self.relation_ids['cqCEM'] = self.relation_ids['xx_default']
        self.relation_ids['cqCPM'] = self.relation_ids['xx_default']
        self.relation_ids['tqTEM'] = self.relation_ids['xx_default']
        self.relation_ids['tqTPM'] = self.relation_ids['xx_default']
      if cv_link:
        self.relation_ids["qcNUMBER"] = self.relation_ids['xx_default']
        self.relation_ids["cqNUMBER"] = self.relation_ids['xx_default']
        self.relation_ids["qcTIME"] = self.relation_ids['xx_default']
        self.relation_ids["cqTIME"] = self.relation_ids['xx_default']
        self.relation_ids["qcCELLMATCH"] = self.relation_ids['xx_default']
        self.relation_ids["cqCELLMATCH"] = self.relation_ids['xx_default']

      for i in range(-qq_max_dist, qq_max_dist + 1):
        self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
        self.relation_ids['tt_dist', i] = self.relation_ids['tt_dist', i]

    if ff_size is None:
      ff_size = hidden_size * 4
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



