def generate()

in modeling/model.py [0:0]


import torch

def generate(args, batch, model, tokenizer, coref_pred):
	'''
		Greedy decoding of the rewritten query for a single example (batch_size must be 1).
	'''
	# basic info
	input_ids, attention_mask, token_type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
	batch_size = input_ids.size(0)
	ctx_len = input_ids.size(1)
	bos_id, eos_id, pad_id, sep_id = tokenizer.convert_tokens_to_ids(['<BOS>', '<EOS>', '<PAD>', '<SEP>'])
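	# <SEP> marks where the rewrite begins in the decoder input; <EOS> terminates generation (see the decoding loop below)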
	assert batch['curr_end_token_idx'][0] == ctx_len
	assert batch_size == 1 # batch sizes larger than 1 are not supported: decoding over padded inputs is not straightforward

	# add <SEP> token to start decoding
	tokens_to_add = input_ids.new(batch_size, 1).fill_(sep_id)
	input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
	attention_mask = _extend_mask(attention_mask)
	assert 0 not in attention_mask # since batch_size == 1, no padding happens

	past = None
	coref_attn = None
	finish_sent = [False for _ in range(batch_size)]
	binary_class, copy_not_rewrite, binary_class_pred = None, False, None
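	# greedy decoding loop: the first step feeds the full context; later steps reuse the cached
	# `past` states (and `coref_attn` when coreference attention is enabled) and feed only the
	# most recently generated token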
	for i in range(args.dec_max_len):
		if past: # with cached past states, the model only needs the current token as input
			input_ids_step = input_ids[:, -1].unsqueeze(-1)
			if args.task == 'qr_coref' and args.use_coref_attn:
				assert coref_attn is not None

		else: # only the first step enters here
			input_ids_step = input_ids

		bi_logits, logits, _, _, past, coref_attn = model(input_ids=input_ids_step, attention_mask=attention_mask,
														token_type_ids=token_type_ids, past=past, predict_mention=False,
														coref_attn=coref_attn, batch=batch, coref_links=coref_pred, step=i)

		if args.use_binary_cls and i == 0: # decide whether to run the rest of generation based on the binary classification result
			# bi_logits: (B, T, 2)
			binary_class_pred = torch.argmax(bi_logits[:, -1, :], dim=-1)
			binary_class_pred = binary_class_pred.tolist()
			assert len(binary_class_pred) == 1
			if binary_class_pred[0] == 0 and args.copy_not_rewrite: # not rewrite
				copy_not_rewrite = True
				break

		# logits: (B, T, V), T=1 when past is passed
		next_token_logits = logits[:, -1, :]
		next_token = torch.argmax(next_token_logits, dim=-1)
		input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
		attention_mask = _extend_mask(attention_mask)
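		# the token appended here becomes the sole input at the next decoding step (fed via `past`)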

		for bs_idx, token_id in enumerate(next_token):
			if finish_sent[bs_idx] is False and token_id.item() == eos_id: # mark as finished the first time <EOS> is produced
				finish_sent[bs_idx] = True

		if sum(finish_sent) == batch_size:
			break

	if copy_not_rewrite: # if the model predicts `not-rewrite`, return the input current utterance as the rewrite
		return binary_class_pred, batch['curr_utt']

	# post-process output sentence
	sentences = []
	for bs_idx in range(batch_size):
		gen = tokenizer.decode(input_ids[bs_idx, :]).split()
		gen = _post_proc(gen)
		sentences.append(' '.join(gen))
	assert len(sentences) == 1
	return binary_class_pred, sentences
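
The helpers `_extend_mask` and `_post_proc` are defined elsewhere in modeling/model.py and are not shown here. The sketch below is only an assumption of what they plausibly do, inferred from how they are used above; the actual implementations may differ.

import torch

def _extend_mask(mask):
	# hypothetical sketch: append one attended (value 1) position per sequence,
	# matching the token that was just concatenated to input_ids
	return torch.cat([mask, mask.new_ones(mask.size(0), 1)], dim=-1)

def _post_proc(tokens):
	# hypothetical sketch: keep only the tokens after the last <SEP> (the generated
	# rewrite) and drop the special tokens used for decoding control
	if '<SEP>' in tokens:
		last_sep = len(tokens) - 1 - tokens[::-1].index('<SEP>')
		tokens = tokens[last_sep + 1:]
	return [t for t in tokens if t not in ('<BOS>', '<EOS>', '<PAD>', '<SEP>')]

A typical call site processes one example at a time, e.g. `binary_class_pred, sentences = generate(args, batch, model, tokenizer, coref_pred)`, with the dataloader configured so that each `batch` holds a single example.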