def generate()

in code/src/model/attention.py [0:0]


    def generate(self, encoded, attr, max_len=200, sample=False, temperature=None):
        """
        Generate a sentence from a given initial state.
        Input:
            - FloatTensor of size (batch_size, hidden_dim) representing
              sentences encoded in the latent space
        Output:
            - LongTensor of size (seq_len, batch_size), word indices
            - LongTensor of size (batch_size,), sentence x_len
        """
        latent = encoded.dec_input
        x_len = encoded.input_len
        is_cuda = latent.is_cuda
        one_hot = None  # only populated when Gumbel-softmax sampling is enabled (disabled below)

        # check inputs
        assert latent.size() == (x_len.max(), x_len.size(0), self.emb_dim)
        assert attr.size() == (x_len.size(0), len(self.attributes))
        assert (sample is True) ^ (temperature is None)

        # initialize generated sentences batch
        slen, bs = latent.size(0), latent.size(1)
        assert x_len.max() == slen and x_len.size(0) == bs
        cur_len = 1
        decoded = torch.LongTensor(max_len, bs).fill_(self.pad_index)
        decoded = decoded.cuda() if is_cuda else decoded
        decoded[0] = self.bos_index

        # attention mask over source padding positions, and zero initial hidden state
        mask = get_mask(x_len, True, cuda=is_cuda) == 0
        h_c_1, h_c_2 = None, None
        hidden_states = [latent.data.new(1, bs, self.hidden_dim).zero_()]

        while cur_len < max_len:
            # previous word embeddings
            if cur_len == 1 and self.bos_attr != '':
                embeddings = self.get_bos_attr(attr)
            else:
                embeddings = self.embeddings(decoded[cur_len - 1])
            embeddings = F.dropout(embeddings, p=self.dropout, training=self.training)

            # attention layer
            attention = self.get_attention(latent, hidden_states[-1][0], embeddings, mask)

            # lstm step
            lstm_input = embeddings.unsqueeze(0)
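            # with input feeding, the attention context is concatenated to the word embedding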
            if self.input_feeding:
                lstm_input = torch.cat([lstm_input, attention], 2)
            lstm_output, h_c_1 = self.lstm1(lstm_input, h_c_1)
            assert lstm_output.size() == (1, bs, self.hidden_dim)
            hidden_states.append(lstm_output)

            # lstm (layers > 1)
            if self.n_dec_layers > 1:
                lstm_output = F.dropout(lstm_output, p=self.dropout, training=self.training)
                if not self.input_feeding:
                    lstm_output = torch.cat([lstm_output, attention], 2)
                lstm_output, h_c_2 = self.lstm2(lstm_output, h_c_2)
                assert lstm_output.size() == (1, bs, self.hidden_dim)

            # word scores
            output = F.dropout(lstm_output, p=self.dropout, training=self.training).view(-1, self.hidden_dim)
            if self.lstm_proj_layer is not None:
                output = F.relu(self.lstm_proj_layer(output))
            scores = self.proj(output).view(bs, self.n_words)
            if self.bias_attr != '':
                scores = scores + self.get_bias_attr(attr)
            scores = scores.data

            # select next words: sample from the softmax (with temperature) or take the argmax
            if sample:
                # if temperature is not None:
                #     gumbel = gumbel_softmax(scores, temperature, hard=True)
                #     next_words = gumbel.max(1)[1]
                #     one_hot.append(gumbel)
                # else:
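                # scale the logits by the temperature before the softmax: lower values give sharper distributions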
                next_words = torch.multinomial(F.softmax(scores / temperature, dim=1), 1).squeeze(1)
            else:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            assert next_words.size() == (bs,)
            decoded[cur_len] = next_words
            cur_len += 1

            # stop when there is a </s> in each sentence
            if decoded.eq(self.eos_index).sum(0).ne(0).sum() == bs:
                break

        # compute the length of each generated sentence, and
        # put some padding after the end of each sentence
        lengths = torch.LongTensor(bs).fill_(cur_len)
        for i in range(bs):
            for j in range(cur_len):
                if decoded[j, i] == self.eos_index:
                    if j + 1 < max_len:
                        decoded[j + 1:, i] = self.pad_index
                    lengths[i] = j + 1
                    break
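            # if no </s> was generated (only possible when the loop ran to max_len), force one at the last position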
            if lengths[i] == max_len:
                decoded[-1, i] = self.eos_index

        if one_hot is not None:
            one_hot = torch.cat([x.unsqueeze(0) for x in one_hot], 0)
            assert one_hot.size() == (cur_len - 1, bs, self.n_words)
        return decoded[:cur_len], lengths, one_hot
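
Usage sketch (illustrative, not from the source): it assumes a decoder instance exposing generate() and an encoder whose output carries the dec_input and input_len fields read above; encoder, decoder, sentences, lengths and the all-zero attribute tensor are placeholder names and values.

    import torch

    # hypothetical call site; only generate()'s own arguments come from the code above
    encoded = encoder(sentences, lengths)    # encoded.dec_input: (slen, bs, emb_dim), encoded.input_len: (bs,)
    attr = torch.zeros(lengths.size(0), len(decoder.attributes)).long()  # placeholder attribute ids

    # greedy decoding: sample=False (default), temperature must stay None
    decoded, dec_lengths, _ = decoder.generate(encoded, attr, max_len=100)

    # stochastic decoding: sample=True with a softmax temperature
    decoded, dec_lengths, _ = decoder.generate(encoded, attr, max_len=100,
                                               sample=True, temperature=0.7)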