in modeling/model.py [0:0]
def generate(args, batch, model, tokenizer, coref_pred):
    '''
    Greedy decoding of the rewritten query for a single example.
    '''
# basic info
input_ids, attention_mask, token_type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
batch_size = input_ids.size(0)
ctx_len = input_ids.size(1)
bos_id, eos_id, pad_id, sep_id = tokenizer.convert_tokens_to_ids(['<BOS>', '<EOS>', '<PAD>', '<SEP>'])
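    # special token ids: <SEP> marks the start of the rewrite, <EOS> marks its end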
assert batch['curr_end_token_idx'][0] == ctx_len
    assert batch_size == 1 # batch sizes larger than 1 are not supported: with padding, batched decoding is not straightforward
# add <SEP> token to start decoding
tokens_to_add = input_ids.new(batch_size, 1).fill_(sep_id)
input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
attention_mask = _extend_mask(attention_mask)
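    # extend the attention mask to match the new sequence length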
assert 0 not in attention_mask # since batch_size == 1, no padding happens
past = None
coref_attn = None
finish_sent = [False for _ in range(batch_size)]
binary_class, copy_not_rewrite, binary_class_pred = None, False, None
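    # decoding-loop state: finish_sent flags which sequences have emitted <EOS>;
    # binary_class_pred stores the step-0 rewrite/not-rewrite decision when binary classification is enabled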
for i in range(args.dec_max_len):
        if past: # with cached past, the model only needs the latest token as input
input_ids_step = input_ids[:, -1].unsqueeze(-1)
if args.task == 'qr_coref' and args.use_coref_attn:
assert coref_attn is not None
else: # only the first step enters here
input_ids_step = input_ids
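        # single decoding step: the model returns binary-classification logits, LM logits, the updated past, and coref attention (two unused outputs are discarded)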
        bi_logits, logits, _, _, past, coref_attn = model(
            input_ids=input_ids_step, attention_mask=attention_mask,
            token_type_ids=token_type_ids, past=past, predict_mention=False,
            coref_attn=coref_attn, batch=batch, coref_links=coref_pred, step=i)
        if args.use_binary_cls and i == 0: # decide whether to run the rest of generation based on the binary classification result
# bi_logits: (B, T, 2)
binary_class_pred = torch.argmax(bi_logits[:, -1, :], dim=-1)
binary_class_pred = binary_class_pred.tolist()
assert len(binary_class_pred) == 1
if binary_class_pred[0] == 0 and args.copy_not_rewrite: # not rewrite
copy_not_rewrite = True
break
# logits: (B, T, V), T=1 when past is passed
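        # greedy decoding: pick the argmax token, append it to input_ids, and extend the attention mask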
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
attention_mask = _extend_mask(attention_mask)
for bs_idx, token_id in enumerate(next_token):
            if finish_sent[bs_idx] is False and token_id.item() == eos_id: # first time this sequence produces <EOS>
finish_sent[bs_idx] = True
if sum(finish_sent) == batch_size:
break
    if copy_not_rewrite: # return the current utterance as the rewrite when the classifier predicts `not-rewrite`
return binary_class_pred, batch['curr_utt']
# post-process output sentence
sentences = []
for bs_idx in range(batch_size):
gen = tokenizer.decode(input_ids[bs_idx, :]).split()
gen = _post_proc(gen)
sentences.append(' '.join(gen))
assert len(sentences) == 1
return binary_class_pred, sentences
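
# Usage sketch (not part of the original file; `eval_loader` and `predict_coref` are
# hypothetical placeholders, and the exact contents of `args`/`batch` are assumptions
# inferred from how they are accessed above):
#
#   model.eval()
#   with torch.no_grad():
#       for batch in eval_loader:                 # one example per batch (batch_size == 1)
#           coref_pred = predict_coref(batch)     # precomputed coreference links
#           bi_pred, rewrite = generate(args, batch, model, tokenizer, coref_pred)
#           print(bi_pred, rewrite)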