in simuleval/metrics/latency.py [0:0]
def AverageLagging(delays, src_lens, tgt_lens, ref_lens=None, target_padding_mask=None):
    """
    Compute Average Lagging (AL) from
    "STACL: Simultaneous Translation with Implicit Anticipation
    and Controllable Latency using Prefix-to-Prefix Framework"
    (https://arxiv.org/abs/1810.08398).
    Delays are monotonic steps, range from 1 to src_len.
    Given src x and tgt y, AL is calculated as:
        AL = 1 / tau * sum_i^tau (delays_i - (i - 1) / gamma)
    Where
        gamma = |y| / |x|
        tau = argmin_i(delays_i = |x|)
    When a reference is given, |y| is the reference length.

    Args:
        delays: (bsz, max_tgt_len) tensor of per-token delays.
        src_lens: source lengths, broadcastable against ``delays``
            (presumably shape (bsz, 1) — confirm against callers).
        tgt_lens: target lengths, same shape convention as ``src_lens``.
        ref_lens: optional reference lengths; when given they replace
            ``tgt_lens`` as |y| in the gamma computation.
        target_padding_mask: optional (bsz, max_tgt_len) bool mask, True at
            padded target positions, which are excluded from the average.

    Returns:
        (bsz,) tensor of AL values.
    """
    bsz, max_tgt_len = delays.size()
    if ref_lens is not None:
        # Use the reference length as |y|; convert to a plain int so it is
        # safe to pass to arange/expand below.
        max_tgt_len = int(ref_lens.max().long().item())
        tgt_lens = ref_lens
    # tau = argmin_i(delays_i = |x|)
    # Only consider the delays that are already larger than src_lens.
    lagging_padding_mask = delays >= src_lens
    # Pad one token at the beginning (shifting the mask right by one) so at
    # least one delay that reached src_lens stays inside the average.
    lagging_padding_mask = torch.nn.functional.pad(
        lagging_padding_mask, (1, 0))[:, :-1]
    if target_padding_mask is not None:
        lagging_padding_mask = lagging_padding_mask.masked_fill(
            target_padding_mask, True)
    # Oracle delays: the delays of an ideal system that reads diagonally,
    # i.e. (i - 1) / gamma with gamma = |y| / |x|.
    oracle_delays = (
        torch.arange(max_tgt_len)
        .unsqueeze(0)
        .type_as(delays)
        .expand([bsz, max_tgt_len])
    ) * src_lens / tgt_lens
    # Align oracle_delays with delays along the time dimension.
    if delays.size(1) < max_tgt_len:
        oracle_delays = oracle_delays[:, :delays.size(1)]
    if delays.size(1) > max_tgt_len:
        # Extend by repeating the last oracle delay. NOTE: slice with -1:
        # (not -1) to keep a (bsz, 1) column that broadcasts against the
        # (bsz, extra) ones tensor; a flat (bsz,) vector would only
        # broadcast when extra happens to equal bsz.
        oracle_delays = torch.cat(
            [
                oracle_delays,
                oracle_delays[:, -1:]
                * oracle_delays.new_ones(
                    [delays.size(0), delays.size(1) - max_tgt_len]
                )
            ],
            dim=1
        )
    lagging = delays - oracle_delays
    lagging = lagging.masked_fill(lagging_padding_mask, 0)
    # tau is the cut-off step: number of positions kept in the average.
    tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=1)
    AL = lagging.sum(dim=1) / tau
    return AL