def generate()

in code/src/model/lm.py [0:0]


    def generate(self, attr, max_len=200, temperature=-1):
        """
        Generate sentences from attributes.

        Words are decoded one step at a time: the previously generated word
        (or, on the first step, the attribute-specific start embedding when
        `self.bos_attr` is set) is embedded, passed through the LSTM and
        projected to vocabulary scores. Decoding stops as soon as every
        sentence of the batch contains `self.eos_index`, or once `max_len`
        positions have been filled.

        Args:
            attr: attribute tensor. Its first dimension is the batch size and
                its device is the device decoding runs on. It is forwarded
                unchanged to `self.get_bos_attr` / `self.get_bias_attr`.
            max_len: maximum number of positions per sentence, including the
                leading `self.bos_index`.
            temperature: -1 for greedy (argmax) decoding, or a positive value
                to sample from `softmax(scores / temperature)`.

        Returns:
            A `(decoded, lengths)` tuple where
            - `decoded` is a LongTensor of shape `(cur_len, bs)` on
              `attr.device`: one sentence per column, starting with
              `self.bos_index`, with `self.pad_index` after the first
              `self.eos_index`. A sentence that fills all `max_len`
              positions always has `self.eos_index` in its last position.
            - `lengths` is a CPU LongTensor of shape `(bs,)` with the number
              of positions of each sentence up to and including its first
              `self.eos_index` (`cur_len` when it has none).
        """
        assert temperature > 0 or temperature == -1

        bs = attr.size(0)
        cur_len = 1
        # allocate the buffer directly on the target device
        decoded = torch.full((max_len, bs), self.pad_index, dtype=torch.long, device=attr.device)
        decoded[0] = self.bos_index
        h_c = None

        # decoding
        while cur_len < max_len:

            # previous word embeddings
            if cur_len == 1 and self.bos_attr != '':
                embeddings = self.get_bos_attr(attr)
            else:
                embeddings = self.embeddings(decoded[cur_len - 1])
            embeddings = F.dropout(embeddings, p=self.dropout, training=self.training)

            # single LSTM step, the recurrent state is carried in h_c
            lstm_output, h_c = self.lstm(embeddings.unsqueeze(0), h_c)
            output = F.dropout(lstm_output, p=self.dropout, training=self.training).view(bs, self.hidden_dim)
            scores = self.proj(output)
            if self.bias_attr != '':
                scores = scores + self.get_bias_attr(attr)
            assert scores.size() == (bs, self.n_words)

            # select next words: sample or argmax
            if temperature > 0:
                probs = F.softmax(scores / temperature, dim=1)
                next_words = torch.multinomial(probs, 1).squeeze(1)
            else:
                next_words = scores.max(1)[1]
            assert next_words.size() == (bs,)
            decoded[cur_len] = next_words
            cur_len += 1

            # stop when there is a </s> in each sentence
            if decoded.eq(self.eos_index).sum(0).ne(0).all():
                break

        # compute the length of each generated sentence, and put some padding
        # after the end of each sentence. This is done with tensor operations
        # on the whole batch rather than element by element, which avoids one
        # host/device synchronization per inspected token.
        generated = decoded[:cur_len]  # view: in-place edits also update `decoded`
        is_eos = generated.eq(self.eos_index).long()
        eos_seen = is_eos.cumsum(0)  # number of </s> at or before each position

        # every position strictly after the first </s> becomes padding
        generated.masked_fill_((eos_seen - is_eos).gt(0), self.pad_index)

        # positions before the first </s>, plus the </s> itself; a sentence
        # without </s> would get cur_len + 1, hence the clamp
        lengths = (eos_seen.eq(0).long().sum(0) + 1).clamp(max=cur_len)

        # a sentence that uses the whole buffer always ends with </s>
        decoded[-1].masked_fill_(lengths.eq(max_len), self.eos_index)

        return generated, lengths.cpu()