def __init__()

in relogic/pretrainkit/models/relationalsemparse/relational_transformer.py [0:0]


  def __init__(self, num_layers, num_heads, hidden_size,
               ff_size=None,
               dropout=0.1,
               merge_types=False,
               tie_layers=False,
               qq_max_dist=2,
               cc_foreign_key=True,
               cc_table_match=True,
               cc_max_dist=2,
               ct_foreign_key=True,
               ct_table_match=True,
               tc_table_match=True,
               tc_foreign_key=True,
               tt_max_dist=2,
               tt_foreign_key=True,
               sc_link=False,
               cv_link=False,
               ):
    """Register relation types and build the relation-aware encoder.

    Relation names are prefixed by an ordered pair of item kinds: ``q``,
    ``c`` and ``t`` (presumably question tokens, columns and tables --
    inferred from the names, confirm against the code that computes the
    relation matrix). Each relation receives an integer id in
    ``self.relation_ids``. Ids are allocated sequentially in registration
    order, so the registration order below must not change or previously
    saved relation embeddings will no longer line up.

    NOTE: the ``qc``/``qt``/``cq``/``tq`` ``*_token_match`` relations are
    disabled in this version and are never registered.

    Args:
      num_layers: number of layers passed to ``Encoder``.
      num_heads: head count for ``MultiHeadedAttentionWithRelations``.
      hidden_size: model width shared by every sub-module.
      ff_size: inner size of the feed-forward sublayer; ``None`` means
        ``4 * hidden_size``.
      dropout: dropout rate forwarded to every sub-module.
      merge_types: if True, every ``*_default`` relation plus any enabled
        schema-linking / value-linking relation is remapped to one shared
        ``xx_default`` id, and the ``cc_dist`` / ``tt_dist`` ids are
        remapped to the matching ``qq_dist`` ids. Requires every
        structural flag below to be False and all three max distances to
        be equal.
      tie_layers: forwarded to ``Encoder``.
      qq_max_dist, cc_max_dist, tt_max_dist: every relative offset in
        ``[-max_dist, max_dist]`` gets its own ``(name, offset)`` id.
      cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
      tc_table_match, tc_foreign_key, tt_foreign_key: enable the
        corresponding structural relations.
      sc_link: register the schema-linking relations (``*CEM``, ``*CPM``).
      cv_link: register the value-linking relations (``*NUMBER``,
        ``*TIME``, ``*CELLMATCH``).

    Raises:
      AssertionError: if ``merge_types`` is combined with any structural
        flag or with unequal max distances.
    """
    super().__init__()
    self.num_heads = num_heads

    self.qq_max_dist = qq_max_dist
    self.cc_foreign_key = cc_foreign_key
    self.cc_table_match = cc_table_match
    self.cc_max_dist = cc_max_dist
    self.ct_foreign_key = ct_foreign_key
    self.ct_table_match = ct_table_match
    self.tc_table_match = tc_table_match
    self.tc_foreign_key = tc_foreign_key
    self.tt_max_dist = tt_max_dist
    self.tt_foreign_key = tt_foreign_key

    # Relation name (a str, or a (str, offset) tuple for distance
    # relations) -> integer id.
    self.relation_ids = {}

    def add_relation(name):
      # The next free id is the number of relations registered so far.
      self.relation_ids[name] = len(self.relation_ids)

    def add_rel_dist(name, max_dist):
      for offset in range(-max_dist, max_dist + 1):
        add_relation((name, offset))

    # Optional linking relations, listed in id-allocation order. Kept in one
    # place so registration and the merge_types remapping cannot drift apart.
    sc_link_relations = ('qcCEM', 'cqCEM', 'qtTEM', 'tqTEM',
                         'qcCPM', 'cqCPM', 'qtTPM', 'tqTPM')
    cv_link_relations = ('qcNUMBER', 'cqNUMBER', 'qcTIME', 'cqTIME',
                         'qcCELLMATCH', 'cqCELLMATCH')

    add_rel_dist('qq_dist', qq_max_dist)

    add_relation('qc_default')
    add_relation('qt_default')
    add_relation('cq_default')

    add_relation('cc_default')
    if cc_foreign_key:
      add_relation('cc_foreign_key_forward')
      add_relation('cc_foreign_key_backward')
    if cc_table_match:
      add_relation('cc_table_match')
    add_rel_dist('cc_dist', cc_max_dist)

    add_relation('ct_default')
    if ct_foreign_key:
      add_relation('ct_foreign_key')
    if ct_table_match:
      add_relation('ct_primary_key')
      add_relation('ct_table_match')
      add_relation('ct_any_table')

    add_relation('tq_default')

    add_relation('tc_default')
    if tc_table_match:
      add_relation('tc_primary_key')
      add_relation('tc_table_match')
      add_relation('tc_any_table')
    if tc_foreign_key:
      add_relation('tc_foreign_key')

    add_relation('tt_default')
    if tt_foreign_key:
      add_relation('tt_foreign_key_forward')
      add_relation('tt_foreign_key_backward')
      add_relation('tt_foreign_key_both')
    add_rel_dist('tt_dist', tt_max_dist)

    if sc_link:
      for name in sc_link_relations:
        add_relation(name)

    if cv_link:
      for name in cv_link_relations:
        add_relation(name)

    if merge_types:
      assert not cc_foreign_key, 'merge_types requires cc_foreign_key=False'
      assert not cc_table_match, 'merge_types requires cc_table_match=False'
      assert not ct_foreign_key, 'merge_types requires ct_foreign_key=False'
      assert not ct_table_match, 'merge_types requires ct_table_match=False'
      assert not tc_foreign_key, 'merge_types requires tc_foreign_key=False'
      assert not tc_table_match, 'merge_types requires tc_table_match=False'
      assert not tt_foreign_key, 'merge_types requires tt_foreign_key=False'

      # Equal ranges are what make the offset-by-offset remapping below valid.
      assert cc_max_dist == qq_max_dist, \
        'merge_types requires cc_max_dist == qq_max_dist'
      assert tt_max_dist == qq_max_dist, \
        'merge_types requires tt_max_dist == qq_max_dist'

      add_relation('xx_default')
      merged_id = self.relation_ids['xx_default']

      merged_names = ['qc_default', 'qt_default', 'cq_default', 'cc_default',
                      'ct_default', 'tq_default', 'tc_default', 'tt_default']
      if sc_link:
        merged_names.extend(sc_link_relations)
      if cv_link:
        merged_names.extend(cv_link_relations)
      for name in merged_names:
        self.relation_ids[name] = merged_id

      # Column-column and table-table distances share the question-question
      # distance ids. (tt_dist was previously assigned to itself, a no-op,
      # so table distances were silently left unmerged.)
      for offset in range(-qq_max_dist, qq_max_dist + 1):
        shared_id = self.relation_ids['qq_dist', offset]
        self.relation_ids['cc_dist', offset] = shared_id
        self.relation_ids['tt_dist', offset] = shared_id

    if ff_size is None:
      ff_size = hidden_size * 4

    # Remapping only overwrites existing keys, so the key count still equals
    # the number of ids ever allocated: every id is < num_relation_kinds, and
    # ids orphaned by merge_types simply go unused.
    num_relation_kinds = len(self.relation_ids)

    self.encoder = Encoder(
      lambda: EncoderLayer(
        hidden_size,
        MultiHeadedAttentionWithRelations(
          num_heads,
          hidden_size,
          dropout),
        PositionwiseFeedForward(
          hidden_size,
          ff_size,
          dropout),
        num_relation_kinds,
        dropout),
      hidden_size,
      num_layers,
      tie_layers)

    # NOTE(review): the name suggests an alignment pointer over items; it
    # gets the same relation vocabulary size as the encoder layers.
    self.align_attn = PointerWithRelations(hidden_size,
                                           num_relation_kinds, dropout)