in simuleval/metrics/latency.py [0:0]
def latency_metric(func):
    """Decorate a latency metric so it accepts loosely-typed inputs.

    The decorated ``func`` is called as
    ``func(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)`` with
    tensors. The returned wrapper instead takes
    ``(delays, src_lens, ref_lens=None, target_padding_mask=None)``. It
    normalizes those arguments to tensors, derives ``tgt_lens`` and forwards
    everything positionally to ``func``.

    Args:
        func: the metric implementation to wrap.

    Returns:
        A wrapper returning whatever ``func`` returns. Its ``__name__``,
        ``__qualname__``, ``__doc__`` and ``__module__`` are copied from
        ``func``.
    """
    import functools

    def _to_column(values, like):
        """Turn per-sentence lengths into a column tensor typed/placed like ``like``."""
        if not torch.is_tensor(values):
            # Lists, tuples, Python/NumPy numbers and NumPy arrays.
            values = torch.as_tensor(values, dtype=torch.float)
        if values.dim() <= 1:
            # A scalar or flat vector holds one length per sentence -> (n, 1).
            values = values.reshape(-1, 1)
        # type_as also moves the lengths onto the device of ``like``.
        return values.type_as(like)

    def prepare_latency_metric(
        delays,
        src_lens,
        ref_lens=None,
        target_padding_mask=None,
    ):
        """Normalize metric inputs to tensors and derive target lengths.

        Expected shapes (after normalization):
            delays: bsz, tgt_len
            src_lens: bsz (returned as bsz, 1)
            target_padding_mask: bsz, tgt_len

        Returns:
            ``(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)``.
            ``tgt_lens`` is a column tensor with the same dtype as ``delays``.
            When a padding mask is given, padded positions of ``delays`` are
            set to 0.
        """
        if not torch.is_tensor(delays):
            delays = torch.as_tensor(delays, dtype=torch.float)
        if delays.dim() <= 1:
            # A single sentence's delays -> batch of one. Nested (2-D)
            # inputs are already batched and must not get an extra dim.
            delays = delays.reshape(1, -1)

        src_lens = _to_column(src_lens, delays)
        if ref_lens is not None:
            ref_lens = _to_column(ref_lens, delays)

        if target_padding_mask is not None:
            if not torch.is_tensor(target_padding_mask):
                target_padding_mask = torch.as_tensor(
                    target_padding_mask, dtype=torch.bool, device=delays.device
                )
            if target_padding_mask.dim() == 1:
                # Mirror the 1-D -> (1, tgt_len) promotion applied to delays.
                target_padding_mask = target_padding_mask.unsqueeze(0)
            # Cast so tgt_lens has the same dtype as in the unmasked branch
            # (mask.sum() would otherwise yield an integer tensor).
            tgt_lens = (
                delays.size(-1) - target_padding_mask.sum(dim=1)
            ).type_as(delays)
            delays = delays.masked_fill(target_padding_mask, 0)
        else:
            tgt_lens = torch.ones_like(src_lens) * delays.size(1)
        tgt_lens = tgt_lens.view(-1, 1)
        return delays, src_lens, tgt_lens, ref_lens, target_padding_mask

    def latency_wrapper(
        delays, src_lens, ref_lens=None, target_padding_mask=None
    ):
        delays, src_lens, tgt_lens, ref_lens, target_padding_mask = (
            prepare_latency_metric(delays, src_lens, ref_lens, target_padding_mask)
        )
        return func(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)

    # Keep each metric's own name/docstring instead of "latency_wrapper".
    # __annotations__ and __wrapped__ are intentionally not kept: the wrapper's
    # parameters differ from func's (no tgt_lens), so inspect.signature() must
    # not follow through to func.
    functools.update_wrapper(
        latency_wrapper,
        func,
        assigned=("__module__", "__name__", "__qualname__", "__doc__"),
    )
    del latency_wrapper.__wrapped__
    return latency_wrapper