def __init__()

in rat-sql-gap/seq2struct/models/nl2code/decoder.py


    def __init__(
            self, 
            device,
            preproc,
            # hyperparameters
            rule_emb_size=128,
            node_embed_size=64,
            # TODO: This should be automatically inferred from encoder
            enc_recurrent_size=256,
            recurrent_size=256,
            dropout=0.,
            desc_attn='bahdanau',
            copy_pointer=None,
            multi_loss_type='logsumexp',
            sup_att=None,
            use_align_mat=False,
            use_align_loss=False,
            enumerate_order=False,
            loss_type="softmax"):
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.ast_wrapper = preproc.ast_wrapper
        self.terminal_vocab = preproc.vocab

        self.rule_emb_size = rule_emb_size
        self.node_emb_size = node_embed_size
        self.enc_recurrent_size = enc_recurrent_size
        self.recurrent_size = recurrent_size

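        # Index every grammar production so rules can be embedded and scored.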
        self.rules_index = {v: idx for idx, v in enumerate(self.preproc.all_rules)}
        self.use_align_mat = use_align_mat
        self.use_align_loss = use_align_loss
        self.enumerate_order = enumerate_order

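        # When alignment matrices are enabled, bind the Spider-specific alignment
        # loss and alignment-aware pointer computation to this decoder instance.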
        if use_align_mat:
            from seq2struct.models.spider import spider_dec_func
            self.compute_align_loss = lambda *args: \
                spider_dec_func.compute_align_loss(self, *args)
            self.compute_pointer_with_align = lambda *args: \
                spider_dec_func.compute_pointer_with_align(self, *args)

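        # Vocabulary of AST node types fed to the node-type embedding. With
        # seq_elem rules, sum-type constructors, field-presence patterns and
        # sequence lengths stand in for the raw ASDL sum/singular types.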
        if self.preproc.use_seq_elem_rules:
            self.node_type_vocab = vocab.Vocab(
                    sorted(self.preproc.primitive_types) +
                    sorted(self.ast_wrapper.custom_primitive_types) +
                    sorted(self.preproc.sum_type_constructors.keys()) +
                    sorted(self.preproc.field_presence_infos.keys()) +
                    sorted(self.preproc.seq_lengths.keys()),
                    special_elems=())
        else:
            self.node_type_vocab = vocab.Vocab(
                    sorted(self.preproc.primitive_types) +
                    sorted(self.ast_wrapper.custom_primitive_types) +
                    sorted(self.ast_wrapper.sum_types.keys()) +
                    sorted(self.ast_wrapper.singular_types.keys()) +
                    sorted(self.preproc.seq_lengths.keys()),
                    special_elems=())

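        # Decoder LSTM cell with recurrent (variational) dropout. Its input is the
        # concatenation of the previous action embedding, the parent action
        # embedding, the attention context over the encoder output, the parent
        # decoder state, and the current node-type embedding (hence the input_size
        # sum below).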
        self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
                input_size=self.rule_emb_size * 2 + self.enc_recurrent_size + self.recurrent_size + self.node_emb_size,
                hidden_size=self.recurrent_size,
                dropout=dropout)

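        # Attention over the encoded description: Bahdanau, multi-head (8 heads or
        # a single head), or separate single-head attention over question and schema.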
        self.attn_type = desc_attn
        if desc_attn == 'bahdanau':
            self.desc_attn = attention.BahdanauAttention(
                    query_size=self.recurrent_size,
                    value_size=self.enc_recurrent_size,
                    proj_size=50)
        elif desc_attn == 'mha':
            self.desc_attn = attention.MultiHeadedAttention(
                    h=8,
                    query_size=self.recurrent_size,
                    value_size=self.enc_recurrent_size)
        elif desc_attn == 'mha-1h':
            self.desc_attn = attention.MultiHeadedAttention(
                    h=1,
                    query_size=self.recurrent_size,
                    value_size=self.enc_recurrent_size)
        elif desc_attn == 'sep':
            self.question_attn = attention.MultiHeadedAttention(
                    h=1,
                    query_size=self.recurrent_size,
                    value_size=self.enc_recurrent_size)
            self.schema_attn = attention.MultiHeadedAttention(
                    h=1,
                    query_size=self.recurrent_size,
                    value_size=self.enc_recurrent_size)       
        else:
            # TODO: Figure out how to get right sizes (query, value) to module
            self.desc_attn = desc_attn
        self.sup_att = sup_att

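        # Grammar-rule head: scores every production from the decoder state, plus
        # an embedding table for previously applied rules.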
        self.rule_logits = torch.nn.Sequential(
                torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
                torch.nn.Tanh(),
                torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
        self.rule_embedding = torch.nn.Embedding(
                num_embeddings=len(self.rules_index),
                embedding_dim=self.rule_emb_size)

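        # Token-generation head: gen_logodds gates between generating a terminal
        # from the vocabulary and copying it from the input; terminal_logits and
        # terminal_embedding handle the vocabulary side.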
        self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
        self.terminal_logits = torch.nn.Sequential(
                torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
                torch.nn.Tanh(),
                torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
        self.terminal_embedding = torch.nn.Embedding(
                num_embeddings=len(self.terminal_vocab),
                embedding_dim=self.rule_emb_size)
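        # Pointer network used to choose which input token to copy.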
        if copy_pointer is None:
            self.copy_pointer = attention.BahdanauPointer(
                    query_size=self.recurrent_size,
                    key_size=self.enc_recurrent_size,
                    proj_size=50)
        else:
            # TODO: Figure out how to get right sizes (query, key) to module
            self.copy_pointer = copy_pointer
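        # Reduction over the log-probabilities of multiple valid ways to produce
        # the same token (e.g. a word that occurs at several input positions).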
        if multi_loss_type == 'logsumexp':
            self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(logprobs, dim=1)
        elif multi_loss_type == 'mean':
            self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs, dim=1)
        
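        # One pointer network per grammar pointer type (e.g. column and table
        # references in the Spider grammar), plus a projection of the pointed-to
        # encoder state into the action-embedding space.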
        self.pointers = torch.nn.ModuleDict()
        self.pointer_action_emb_proj = torch.nn.ModuleDict()
        for pointer_type in self.preproc.grammar.pointers:
            self.pointers[pointer_type] = attention.ScaledDotProductPointer(
                    query_size=self.recurrent_size,
                    key_size=self.enc_recurrent_size)
            self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
                    self.enc_recurrent_size, self.rule_emb_size)

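        # Embedding of the frontier node type, part of the LSTM input above.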
        self.node_type_embedding = torch.nn.Embedding(
                num_embeddings=len(self.node_type_vocab),
                embedding_dim=self.node_emb_size)

        # TODO batching
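        # Zero placeholders used when there is no previous action embedding or
        # parent state (batch size is fixed at 1 for now).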
        self.zero_rule_emb = torch.zeros(1, self.rule_emb_size, device=self._device)
        self.zero_recurrent_emb = torch.zeros(1, self.recurrent_size, device=self._device)
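        # Per-action classification loss: plain cross-entropy, entmax15, sparsemax,
        # or label-smoothed cross-entropy.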
        if loss_type == "softmax":
            self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
        elif loss_type == "entmax":
            self.xent_loss = entmax.entmax15_loss
        elif loss_type == "sparsemax":
            self.xent_loss = entmax.sparsemax_loss
        elif loss_type == "label_smooth":
            self.xent_loss = self.label_smooth_loss