in rat-sql-gap/seq2struct/models/spider/spider_enc_modules.py [0:0]
def __init__(self, device, num_layers, num_heads, hidden_size,
             ff_size=None,
             dropout=0.1,
             merge_types=False,
             tie_layers=False,
             qq_max_dist=2,
             cc_foreign_key=True,
             cc_table_match=True,
             cc_max_dist=2,
             ct_foreign_key=True,
             ct_table_match=True,
             tc_table_match=True,
             tc_foreign_key=True,
             tt_max_dist=2,
             tt_foreign_key=True,
             sc_link=False,
             cv_link=False,
             ):
    """Build the relation vocabulary and the relation-aware encoder.

    Every relation name (a string, or a ``(name, distance)`` tuple for
    distance relations) is mapped to an integer id in
    ``self.relation_ids``.  Ids are handed out in registration order
    starting at 0, so the order of the ``add_relation`` calls below
    determines the id of each relation and must not be changed.

    Args:
        device: stored as ``self._device``.
        num_layers: forwarded to ``transformer.Encoder``.
        num_heads: number of attention heads; stored and forwarded to
            ``transformer.MultiHeadedAttentionWithRelations``.
        hidden_size: model width forwarded to every sub-module.
        ff_size: feed-forward width; defaults to ``4 * hidden_size``
            when ``None``.
        dropout: dropout value forwarded to every sub-module.
        merge_types: when true, all ``*_default`` relations (and the
            ``sc_link`` / ``cv_link`` relations, if enabled) share the id
            of a single ``'xx_default'`` relation, and the ``cc_dist`` /
            ``tt_dist`` relations share the ids of ``qq_dist``.  Requires
            all foreign-key / table-match flags to be false and
            ``cc_max_dist == tt_max_dist == qq_max_dist``.
        tie_layers: forwarded to ``transformer.Encoder``.
        qq_max_dist, cc_max_dist, tt_max_dist: a distance relation
            ``(name, i)`` is registered for every ``i`` in
            ``[-max_dist, max_dist]``.
        cc_foreign_key, cc_table_match, ct_foreign_key, ct_table_match,
        tc_table_match, tc_foreign_key, tt_foreign_key: enable the
            correspondingly named structural relations.
        sc_link: register the ``*CEM`` / ``*TEM`` / ``*CPM`` / ``*TPM``
            schema-linking relations.
        cv_link: register the ``*NUMBER`` / ``*TIME`` / ``*CELLMATCH``
            relations.

    Raises:
        AssertionError: if ``merge_types`` is set together with an
            incompatible flag or with differing max distances.
    """
    super().__init__()
    self._device = device
    self.num_heads = num_heads

    self.qq_max_dist = qq_max_dist
    self.cc_foreign_key = cc_foreign_key
    self.cc_table_match = cc_table_match
    self.cc_max_dist = cc_max_dist
    self.ct_foreign_key = ct_foreign_key
    self.ct_table_match = ct_table_match
    self.tc_table_match = tc_table_match
    self.tc_foreign_key = tc_foreign_key
    self.tt_max_dist = tt_max_dist
    self.tt_foreign_key = tt_foreign_key

    self.relation_ids = {}

    def add_relation(name):
        # The next free id is the current number of registered names.
        self.relation_ids[name] = len(self.relation_ids)

    def add_rel_dist(name, max_dist):
        # One relation per signed distance in [-max_dist, max_dist].
        for i in range(-max_dist, max_dist + 1):
            add_relation((name, i))

    # NOTE: the qc/qt/cq/tq ``*_token_match`` options are disabled; no
    # relation ids are reserved for them.

    # question -> question
    add_rel_dist('qq_dist', qq_max_dist)

    # question -> column / table, column -> question
    add_relation('qc_default')
    add_relation('qt_default')
    add_relation('cq_default')

    # column -> column
    add_relation('cc_default')
    if cc_foreign_key:
        add_relation('cc_foreign_key_forward')
        add_relation('cc_foreign_key_backward')
    if cc_table_match:
        add_relation('cc_table_match')
    add_rel_dist('cc_dist', cc_max_dist)

    # column -> table
    add_relation('ct_default')
    if ct_foreign_key:
        add_relation('ct_foreign_key')
    if ct_table_match:
        add_relation('ct_primary_key')
        add_relation('ct_table_match')
        add_relation('ct_any_table')

    # table -> question
    add_relation('tq_default')

    # table -> column
    add_relation('tc_default')
    if tc_table_match:
        add_relation('tc_primary_key')
        add_relation('tc_table_match')
        add_relation('tc_any_table')
    if tc_foreign_key:
        add_relation('tc_foreign_key')

    # table -> table
    add_relation('tt_default')
    if tt_foreign_key:
        add_relation('tt_foreign_key_forward')
        add_relation('tt_foreign_key_backward')
        add_relation('tt_foreign_key_both')
    add_rel_dist('tt_dist', tt_max_dist)

    # Schema-linking relations, registered in forward/backward pairs.
    # The tuple order fixes the ids, so keep it stable.
    sc_link_names = (
        'qcCEM', 'cqCEM', 'qtTEM', 'tqTEM',
        'qcCPM', 'cqCPM', 'qtTPM', 'tqTPM',
    )
    cv_link_names = (
        'qcNUMBER', 'cqNUMBER',
        'qcTIME', 'cqTIME',
        'qcCELLMATCH', 'cqCELLMATCH',
    )
    if sc_link:
        for name in sc_link_names:
            add_relation(name)
    if cv_link:
        for name in cv_link_names:
            add_relation(name)

    if merge_types:
        # Merging is only defined when no type-specific structural
        # relation is enabled and all distance ranges coincide.
        assert not cc_foreign_key
        assert not cc_table_match
        assert not ct_foreign_key
        assert not ct_table_match
        assert not tc_foreign_key
        assert not tc_table_match
        assert not tt_foreign_key

        assert cc_max_dist == qq_max_dist
        assert tt_max_dist == qq_max_dist

        add_relation('xx_default')
        shared_id = self.relation_ids['xx_default']

        merged_names = [
            'qc_default', 'qt_default', 'cq_default', 'cc_default',
            'ct_default', 'tq_default', 'tc_default', 'tt_default',
        ]
        if sc_link:
            merged_names.extend(sc_link_names)
        if cv_link:
            merged_names.extend(cv_link_names)
        # Re-point existing keys at the shared id.  The keys stay in the
        # dict, so len(self.relation_ids) below is unaffected.
        for name in merged_names:
            self.relation_ids[name] = shared_id

        for i in range(-qq_max_dist, qq_max_dist + 1):
            self.relation_ids['cc_dist', i] = self.relation_ids['qq_dist', i]
            # Fix: this used to assign ('tt_dist', i) to itself, leaving
            # table-table distances unmerged.  Alias them to qq_dist,
            # exactly like cc_dist.
            self.relation_ids['tt_dist', i] = self.relation_ids['qq_dist', i]

    if ff_size is None:
        ff_size = hidden_size * 4

    # presumably the relation count sizes a relation-embedding table in
    # the transformer module; it counts registered names, which is an
    # upper bound on the ids in use even after merging -- TODO confirm.
    num_relations = len(self.relation_ids)

    self.encoder = transformer.Encoder(
        lambda: transformer.EncoderLayer(
            hidden_size,
            transformer.MultiHeadedAttentionWithRelations(
                num_heads,
                hidden_size,
                dropout),
            transformer.PositionwiseFeedForward(
                hidden_size,
                ff_size,
                dropout),
            num_relations,
            dropout),
        hidden_size,
        num_layers,
        tie_layers)

    self.align_attn = transformer.PointerWithRelations(
        hidden_size, num_relations, dropout)