rat-sql-gap/seq2struct/models/nl2code/decoder.py
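# Imports assumed from the top of the file (not shown in this excerpt); the
# seq2struct module paths are a best guess inferred from the identifiers used
# below (vocab.Vocab, attention.*, variational_lstm.*, entmax.*).
import entmax
import torch

from seq2struct.models import attention, variational_lstm
from seq2struct.utils import vocab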
def __init__(
        self,
        device,
        preproc,
        # Embedding sizes for rules and AST node types
        rule_emb_size=128,
        node_embed_size=64,
        # TODO: This should be automatically inferred from encoder
        enc_recurrent_size=256,
        recurrent_size=256,
        dropout=0.,
        desc_attn='bahdanau',
        copy_pointer=None,
        multi_loss_type='logsumexp',
        sup_att=None,
        use_align_mat=False,
        use_align_loss=False,
        enumerate_order=False,
        loss_type="softmax"):
    super().__init__()
    self._device = device
    self.preproc = preproc
    self.ast_wrapper = preproc.ast_wrapper
    self.terminal_vocab = preproc.vocab
    self.rule_emb_size = rule_emb_size
    self.node_emb_size = node_embed_size
    self.enc_recurrent_size = enc_recurrent_size
    self.recurrent_size = recurrent_size
    self.rules_index = {v: idx for idx, v in enumerate(self.preproc.all_rules)}
    self.use_align_mat = use_align_mat
    self.use_align_loss = use_align_loss
    self.enumerate_order = enumerate_order
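    # When question/schema alignment matrices are used, the Spider-specific
    # helpers are imported on demand and bound to this instance via closures,
    # so the import cost is only paid when the feature is enabled.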
    if use_align_mat:
        from seq2struct.models.spider import spider_dec_func
        self.compute_align_loss = lambda *args: \
            spider_dec_func.compute_align_loss(self, *args)
        self.compute_pointer_with_align = lambda *args: \
            spider_dec_func.compute_pointer_with_align(self, *args)
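    # Vocabulary of AST node types. With seq-elem rules, the grammar's
    # sum-type constructors, field-presence signatures, and sequence lengths
    # become node types of their own; otherwise the raw ASDL sum and singular
    # types are used directly.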
    if self.preproc.use_seq_elem_rules:
        self.node_type_vocab = vocab.Vocab(
            sorted(self.preproc.primitive_types) +
            sorted(self.ast_wrapper.custom_primitive_types) +
            sorted(self.preproc.sum_type_constructors.keys()) +
            sorted(self.preproc.field_presence_infos.keys()) +
            sorted(self.preproc.seq_lengths.keys()),
            special_elems=())
    else:
        self.node_type_vocab = vocab.Vocab(
            sorted(self.preproc.primitive_types) +
            sorted(self.ast_wrapper.custom_primitive_types) +
            sorted(self.ast_wrapper.sum_types.keys()) +
            sorted(self.ast_wrapper.singular_types.keys()) +
            sorted(self.preproc.seq_lengths.keys()),
            special_elems=())
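    # Decoder state update. Per the input_size below, the LSTM cell consumes
    # the previous-action and parent-action embeddings (rule_emb_size each),
    # the attention context over the encoder memory (enc_recurrent_size),
    # the parent decoder state (recurrent_size), and the current node-type
    # embedding (node_emb_size), following the NL2Code decoding scheme.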
    self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
        input_size=self.rule_emb_size * 2 + self.enc_recurrent_size
                   + self.recurrent_size + self.node_emb_size,
        hidden_size=self.recurrent_size,
        dropout=dropout)
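    # Attention from the decoder state over the encoder memory. 'sep' keeps
    # separate single-head attentions for the question and the schema; any
    # other value is assumed to be a pre-constructed attention module.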
    self.attn_type = desc_attn
    if desc_attn == 'bahdanau':
        self.desc_attn = attention.BahdanauAttention(
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size,
            proj_size=50)
    elif desc_attn == 'mha':
        self.desc_attn = attention.MultiHeadedAttention(
            h=8,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'mha-1h':
        self.desc_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'sep':
        self.question_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
        self.schema_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    else:
        # TODO: Figure out how to get right sizes (query, value) to module
        self.desc_attn = desc_attn
    self.sup_att = sup_att
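    # APPLY_RULE head: projects the decoder state to a distribution over all
    # grammar rules; rule_embedding feeds chosen rules back in as the next
    # step's previous-action embedding.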
    self.rule_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
    self.rule_embedding = torch.nn.Embedding(
        num_embeddings=len(self.rules_index),
        embedding_dim=self.rule_emb_size)
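    # GEN_TOKEN head: gen_logodds scores generating a terminal from the
    # vocabulary versus copying it from the input; terminal_logits and
    # terminal_embedding handle the vocabulary side.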
    self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
    self.terminal_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
    self.terminal_embedding = torch.nn.Embedding(
        num_embeddings=len(self.terminal_vocab),
        embedding_dim=self.rule_emb_size)
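    # Copy mechanism: a pointer over the encoder memory, used when a terminal
    # is copied from the input rather than generated from the vocabulary.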
    if copy_pointer is None:
        self.copy_pointer = attention.BahdanauPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size,
            proj_size=50)
    else:
        # TODO: Figure out how to get right sizes (query, key) to module
        self.copy_pointer = copy_pointer
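    # When a target token has several correct predictions (e.g. it can be
    # generated from the vocabulary or copied from multiple input positions),
    # reduce their log-probabilities by logsumexp (marginalizing) or by
    # averaging.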
    if multi_loss_type == 'logsumexp':
        self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(logprobs, dim=1)
    elif multi_loss_type == 'mean':
        self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs, dim=1)
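    # One pointer network per grammar pointer type (e.g. columns and tables
    # for Spider); the selected encoder item is projected into rule-embedding
    # space so it can be fed back as the previous action at the next step.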
    self.pointers = torch.nn.ModuleDict()
    self.pointer_action_emb_proj = torch.nn.ModuleDict()
    for pointer_type in self.preproc.grammar.pointers:
        self.pointers[pointer_type] = attention.ScaledDotProductPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size)
        self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
            self.enc_recurrent_size, self.rule_emb_size)

    self.node_type_embedding = torch.nn.Embedding(
        num_embeddings=len(self.node_type_vocab),
        embedding_dim=self.node_emb_size)

    # TODO batching
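    # Zero placeholders (batch size 1) used at the first decoding step, when
    # there is no previous action embedding or parent state yet.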
    self.zero_rule_emb = torch.zeros(1, self.rule_emb_size, device=self._device)
    self.zero_recurrent_emb = torch.zeros(1, self.recurrent_size, device=self._device)
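    # Per-element classification loss for rule and token prediction: standard
    # cross-entropy, the sparse entmax/sparsemax losses, or label smoothing
    # (label_smooth_loss is defined elsewhere on this class).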
    if loss_type == "softmax":
        self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
    elif loss_type == "entmax":
        self.xent_loss = entmax.entmax15_loss
    elif loss_type == "sparsemax":
        self.xent_loss = entmax.sparsemax_loss
    elif loss_type == "label_smooth":
        self.xent_loss = self.label_smooth_loss