# -*- encoding: utf-8 -*-
'''
@File : autoregressive_sampling.py
@Time : 2021/10/08 15:43:59
@Author : Ming Ding
@Contact : dm18@mails.tsinghua.edu.cn
'''
# imports
import os
import sys
import math
import random
import torch
from .sampling_strategies import BaseStrategy
from sat.helpers import print_rank0
def get_masks_and_position_ids_default(seq):
tokens = seq.unsqueeze(0)
attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
attention_mask.tril_()
attention_mask.unsqueeze_(1)
position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device)
position_ids = position_ids.unsqueeze(0)
return tokens, attention_mask, position_ids
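
# Example (illustrative sketch, kept in comments so nothing runs at import time):
# for a 3-token prompt the default helper returns a batch of size 1, a
# lower-triangular (causal) mask, and plain 0..n-1 position ids.
#
#     seq = torch.tensor([101, 102, 103])
#     tokens, attention_mask, position_ids = get_masks_and_position_ids_default(seq)
#     # tokens.shape         == (1, 3)
#     # attention_mask.shape == (1, 1, 3, 3)   lower-triangular along the last two dims
#     # position_ids.shape   == (1, 3)         i.e. [[0, 1, 2]]
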
def update_mems(hiddens, mems, max_memory_length):
'''
hiddens: list (num_layers) of [batch, query_length, 2d]
mems: None or [num_layers, batch, memory_length, 2d]
'''
if hiddens is None:
return None
hiddens = torch.stack(hiddens)
memory_length = mems.shape[2] if mems is not None else 0
query_length = hiddens.shape[2]
new_memory_length = min(max_memory_length, memory_length + query_length)
with torch.no_grad():
if new_memory_length <= query_length:
return hiddens[:, :, -new_memory_length:]
else:
if mems.shape[1] < hiddens.shape[1]:
mems = mems.expand(-1, hiddens.shape[1], -1, -1)
return torch.cat(
(mems[:, :, -new_memory_length+query_length:], hiddens),
dim=2
)
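
# Shape sketch for update_mems (illustrative only; the concrete sizes are made up):
# with an existing cache `mems` of shape [num_layers, batch, 7, 2d] and a forward
# pass over 3 new query tokens (each layer contributing [batch, 3, 2d] to `hiddens`),
# the returned cache is [num_layers, batch, 10, 2d] when max_memory_length >= 10;
# with max_memory_length == 8 only the last 8 positions along dim 2 are kept.
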
def filling_sequence(
model,
seq,
batch_size,
strategy=BaseStrategy(),
max_memory_length=100000,
log_attention_weights=None,
get_masks_and_position_ids=get_masks_and_position_ids_default,
mems=None,
**kw_args
):
'''
        seq: [2, 3, 5, ..., -1 (to be generated), -1, ...]
        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
            cache covering the first mems.shape[2] tokens of the context.
            mems are first-class citizens here, but we make no assumption about what is memorized.
            Input mems are used for multi-phase generation.
'''
assert len(seq.shape) == 1
if hasattr(strategy, 'num_beams') and batch_size < strategy.num_beams:
batch_size = strategy.num_beams
        print_rank0(f'Adjusting batch_size to {batch_size} to match num_beams. Silence this warning by setting batch_size == num_beams.', level='DEBUG')
# building the initial tokens, attention_mask, and position_ids
context_length = 0
while seq[context_length] >= 0:
context_length += 1 # [0, context_length-1] are given
assert context_length > 0
tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
tokens = tokens[..., :context_length]
if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(model.parameters())) # cast the mask to the model's dtype (e.g. fp16)
    # initialize generation
    counter = context_length - 1 # the last given (fixed) index is `counter`
    index = 0 if mems is None else mems.shape[2] # next forward starting index; also the current cache length
mems_cross = None
# step-by-step generation
while counter < len(seq) - 1:
        # Now we want to generate seq[counter + 1];
        # tokens[:, index: counter+1] still need to be forwarded.
if seq[counter + 1] >= 0: # provided
tokens = torch.cat(
(
tokens,
seq[counter+1: counter+2].expand(tokens.shape[0], 1)
), dim=1
)
counter += 1
continue
# forward
if log_attention_weights is not None:
log_attention_weights_part = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
else:
log_attention_weights_part = None
        if mems_cross is not None or 'encoder_outputs' not in kw_args:
            # reuse cached cross-attention memories; a user-supplied encoder_outputs is only kept until mem_cross is returned
            kw_args['encoder_outputs'] = mems_cross
logits, *output_per_layers = model(
input_ids=tokens[:, index:],
position_ids=position_ids[..., index: counter+1],
attention_mask=attention_mask[..., index: counter+1, :counter+1], # TODO memlen
mems=mems,
log_attention_weights=log_attention_weights_part,
**kw_args
)
if len(output_per_layers) > 0 and 'mem_cross' in output_per_layers[0]:
mems_cross = [mem['mem_cross'] for mem in output_per_layers]
mem_kv = [o['mem_kv'] for o in output_per_layers]
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
index = counter
# sampling
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
tokens = tokens.expand(batch_size, -1)
tokens, mems = strategy.forward(logits, tokens, mems)
if strategy.is_done:
break
return strategy.finalize(tokens, mems)
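
# Usage sketch for filling_sequence (hedged: `model` and `tokenizer` are assumed to be an
# already-loaded sat autoregressive model and a matching tokenizer; BaseStrategy is used
# only through the forward/is_done/finalize interface called above, and its constructor
# arguments may differ between sat versions):
#
#     from sat.generation.sampling_strategies import BaseStrategy
#     context = tokenizer.encode('The quick brown fox')   # hypothetical tokenizer call
#     seq = torch.tensor(context + [-1] * 20, device=next(model.parameters()).device)
#     output = filling_sequence(model, seq, batch_size=1, strategy=BaseStrategy())[0]
#     # index [0] mirrors the stream docstring below: finalize returns the tokens first.
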
def stream_filling_sequence(
model,
seq,
batch_size,
strategy=BaseStrategy(),
max_memory_length=100000,
log_attention_weights=None,
get_masks_and_position_ids=get_masks_and_position_ids_default,
mems=None,
**kw_args
):
'''
        seq: [2, 3, 5, ..., -1 (to be generated), -1, ...]
        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
            cache covering the first mems.shape[2] tokens of the context.
            mems are first-class citizens here, but we make no assumption about what is memorized.
            Input mems are used for multi-phase generation.
        usage example:
            filling_stream = stream_filling_sequence(
                model, seq,
                batch_size=1
            )
            for tokens, mems in filling_stream:
                pass # you can do something with the intermediate tokens, or nothing at all
            output = strategy.finalize(tokens, mems)[0]
'''
assert len(seq.shape) == 1
if hasattr(strategy, 'num_beams') and batch_size < strategy.num_beams:
batch_size = strategy.num_beams
        print_rank0(f'Adjusting batch_size to {batch_size} to match num_beams. Silence this warning by setting batch_size == num_beams.', level='DEBUG')
# building the initial tokens, attention_mask, and position_ids
context_length = 0
while seq[context_length] >= 0:
context_length += 1 # [0, context_length-1] are given
assert context_length > 0
tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
tokens = tokens[..., :context_length]
if attention_mask.dtype != torch.bool:
        attention_mask = attention_mask.type_as(next(model.parameters())) # cast the mask to the model's dtype (e.g. fp16)
    # initialize generation
    counter = context_length - 1 # the last given (fixed) index is `counter`
    index = 0 if mems is None else mems.shape[2] # next forward starting index; also the current cache length
mems_cross = None
# step-by-step generation
while counter < len(seq) - 1:
        # Now we want to generate seq[counter + 1];
        # tokens[:, index: counter+1] still need to be forwarded.
if seq[counter + 1] >= 0: # provided
tokens = torch.cat(
(
tokens,
seq[counter+1: counter+2].expand(tokens.shape[0], 1)
), dim=1
)
counter += 1
continue
# forward
if log_attention_weights is not None:
log_attention_weights_part = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
else:
log_attention_weights_part = None
        if mems_cross is not None or 'encoder_outputs' not in kw_args:
            # reuse cached cross-attention memories; a user-supplied encoder_outputs is only kept until mem_cross is returned
            kw_args['encoder_outputs'] = mems_cross
logits, *output_per_layers = model(
input_ids=tokens[:, index:],
position_ids=position_ids[..., index: counter+1],
attention_mask=attention_mask[..., index: counter+1, :counter+1], # TODO memlen
mems=mems,
log_attention_weights=log_attention_weights_part,
**kw_args
)
if len(output_per_layers) > 0 and 'mem_cross' in output_per_layers[0]:
mems_cross = [mem['mem_cross'] for mem in output_per_layers]
mem_kv = [o['mem_kv'] for o in output_per_layers]
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
index = counter
# sampling
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
tokens = tokens.expand(batch_size, -1)
tokens, mems = strategy.forward(logits, tokens, mems)
yield tokens, mems
if strategy.is_done:
break
def evaluate_perplexity(model, tokens, attention_mask, position_ids, loss_mask, invalid_slices=[], reduction='mean'):
# sanity check
    assert len(tokens.shape) <= 2 and len(loss_mask.shape) <= 2
if len(tokens.shape) == 1:
tokens = tokens.unsqueeze(0)
if len(loss_mask.shape) == 1:
loss_mask = loss_mask.unsqueeze(0).expand(tokens.shape)
pad_pos = tokens < 0
if pad_pos.any():
        print_rank0('Found -1 in tokens; automatically ignoring those positions.', level='DEBUG')
tokens[pad_pos] = 0
loss_mask[pad_pos] = 0
attention_mask = attention_mask.type_as(next(model.parameters()))
logits = model(tokens, position_ids, attention_mask)[0]
logits = logits.float()
for slc in invalid_slices:
logits[..., slc] = -float('Inf')
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
pred = log_probs[:, :-1, :]
target = tokens[:, 1:].unsqueeze(-1)
loss_mask = loss_mask[..., 1:]
scores = -1 * torch.gather(pred, dim=2, index=target).squeeze(-1) # [batch_size, seq_len-1]
if reduction == 'mean':
return (scores * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)
elif reduction == 'none':
return (scores * loss_mask)
else:
        raise ValueError(f'Unknown reduction type: {reduction}')
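
# Usage sketch for evaluate_perplexity (hedged: `model` is an already-loaded sat model
# and the token ids below are placeholders):
#
#     seq = torch.tensor([101, 102, 103, 104, 105])
#     _, attention_mask, position_ids = get_masks_and_position_ids_default(seq)
#     # score only the last two tokens, e.g. when the first three form the prompt
#     loss_mask = torch.tensor([0., 0., 0., 1., 1.])
#     nll = evaluate_perplexity(model, seq, attention_mask, position_ids, loss_mask)
#     # `nll` is the mean negative log-likelihood over the masked positions (shape [1]);
#     # perplexity = torch.exp(nll)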