in modules/SwissArmyTransformer/sat/generation/autoregressive_sampling.py [0:0]
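# Imports this excerpt relies on (paths are a best-guess sketch of sat's layout;
# update_mems and get_masks_and_position_ids_default are defined earlier in this file):
import torch
from sat.helpers import print_rank0
from .sampling_strategies import BaseStrategy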
def filling_sequence(
model,
seq,
batch_size,
strategy=BaseStrategy(),
max_memory_length=100000,
log_attention_weights=None,
get_masks_and_position_ids=get_masks_and_position_ids_default,
mems=None,
**kw_args
):
    '''
    Autoregressively fill the -1 positions of `seq`, letting `strategy` choose each new token.

    seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
        a 1-D tensor of token ids; every -1 slot will be generated.
    mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
        cached hidden states covering the first mems.shape[2] tokens of the context.
        mems are first-class citizens here, but we make no assumption about what is memorized.
        Input mems are used for multi-phase generation.
    '''
assert len(seq.shape) == 1
if hasattr(strategy, 'num_beams') and batch_size < strategy.num_beams:
batch_size = strategy.num_beams
        print_rank0(f'Adjusted batch_size to {batch_size} to match num_beams. Mute this warning by setting batch_size == num_beams.', level='DEBUG')
# building the initial tokens, attention_mask, and position_ids
context_length = 0
while seq[context_length] >= 0:
context_length += 1 # [0, context_length-1] are given
assert context_length > 0
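    # build the full-length attention mask and position ids up front; only the given
    # context slice of `tokens` is kept for now, generated tokens are appended below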
tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
tokens = tokens[..., :context_length]
    if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(model.parameters()))  # cast to the model's dtype (e.g. fp16)
# initialize generation
    counter = context_length - 1  # last already-decided index is `counter`
index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
mems_cross = None
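    # mems_cross caches cross-attention memories (encoder-decoder models): it is filled
    # from the first forward pass and fed back as `encoder_outputs` on later steps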
# step-by-step generation
while counter < len(seq) - 1:
# Now, we want to generate seq[counter + 1],
# token[:, index: counter+1] needs forwarding.
if seq[counter + 1] >= 0: # provided
tokens = torch.cat(
(
tokens,
seq[counter+1: counter+2].expand(tokens.shape[0], 1)
), dim=1
)
counter += 1
continue
# forward
if log_attention_weights is not None:
log_attention_weights_part = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
else:
log_attention_weights_part = None
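        # hand cached cross-attention memories to the model as `encoder_outputs` once they
        # exist; otherwise keep a caller-provided value (or default it to None)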
if mems_cross is not None or 'encoder_outputs' not in kw_args:
kw_args['encoder_outputs'] = mems_cross
logits, *output_per_layers = model(
input_ids=tokens[:, index:],
position_ids=position_ids[..., index: counter+1],
attention_mask=attention_mask[..., index: counter+1, :counter+1], # TODO memlen
mems=mems,
log_attention_weights=log_attention_weights_part,
**kw_args
)
if len(output_per_layers) > 0 and 'mem_cross' in output_per_layers[0]:
mems_cross = [mem['mem_cross'] for mem in output_per_layers]
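        # merge the freshly computed per-layer key/value states into the rolling cache,
        # keeping at most max_memory_length past positions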
mem_kv = [o['mem_kv'] for o in output_per_layers]
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
index = counter
# sampling
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
tokens = tokens.expand(batch_size, -1)
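        # the strategy picks the next token(s) and may also reorder tokens/mems (e.g. beam search)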
tokens, mems = strategy.forward(logits, tokens, mems)
if strategy.is_done:
break
return strategy.finalize(tokens, mems)
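
# Minimal usage sketch (illustrative only; `model` and the concrete token ids/device handling
# below are assumptions, not part of this module). The return value is whatever
# strategy.finalize yields:
#
#   context = torch.tensor([2, 3, 5], device=next(model.parameters()).device)
#   seq = torch.cat([context, torch.full((32,), -1, device=context.device)])  # 32 slots to fill
#   output = filling_sequence(model, seq, batch_size=1, strategy=BaseStrategy())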