def AverageLagging()

in simuleval/metrics/latency.py


import torch


def AverageLagging(delays, src_lens, tgt_lens, ref_lens=None, target_padding_mask=None):
    """
    Function to calculate Average Lagging from
    STACL: Simultaneous Translation with Implicit Anticipation
    and Controllable Latency using Prefix-to-Prefix Framework
    (https://arxiv.org/abs/1810.08398)
    Delays are monotonic steps, range from 1 to src_len.
    Give src x tgt y, AP is calculated as:
    AL = 1 / tau sum_i^tau delays_i - (i - 1) / gamma
    Where
    gamma = |y| / |x|
    tau = argmin_i(delays_i = |x|)

    When reference was given, |y| would be the reference length
    """
    bsz, max_tgt_len = delays.size()
    if ref_lens is not None:
        max_tgt_len = ref_lens.max().long()
        tgt_lens = ref_lens

    # tau = argmin_i(delays_i = |x|)
    # Mask out positions after the first delay that reaches src_lens
    lagging_padding_mask = delays >= src_lens
    # Shift the mask right by one step so that the first delay reaching
    # src_lens still contributes to the sum
    lagging_padding_mask = torch.nn.functional.pad(
        lagging_padding_mask, (1, 0))[:, :-1]

    if target_padding_mask is not None:
        lagging_padding_mask = lagging_padding_mask.masked_fill(
            target_padding_mask, True)

    # Oracle delays belong to an ideal system that moves diagonally,
    # reading (i - 1) * |x| / |y| source tokens before target step i
    oracle_delays = (
        torch.arange(max_tgt_len)
        .unsqueeze(0)
        .type_as(delays)
        .expand([bsz, max_tgt_len])
    ) * src_lens / tgt_lens

    if delays.size(1) < max_tgt_len:
        oracle_delays = oracle_delays[:, :delays.size(1)]

    if delays.size(1) > max_tgt_len:
        # Extend the oracle with its last column so the shapes match
        oracle_delays = torch.cat(
            [
                oracle_delays,
                oracle_delays[:, -1:]
                * oracle_delays.new_ones(
                    [delays.size(0), delays.size(1) - max_tgt_len]
                )
            ],
            dim=1
        )

    lagging = delays - oracle_delays
    lagging = lagging.masked_fill(lagging_padding_mask, 0)

    # tau is the number of steps before the cut-off (unmasked positions)
    tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=1)
    AL = lagging.sum(dim=1) / tau

    return AL
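

A minimal usage sketch (assumed shapes, not part of the original file): src_lens and tgt_lens are taken to be [bsz, 1] float tensors so they broadcast against the [bsz, max_tgt_len] delay matrix.

import torch

delays = torch.tensor([[1., 3., 3.]])  # read 1 source token, then all 3
src_lens = torch.tensor([[3.]])
tgt_lens = torch.tensor([[3.]])
print(AverageLagging(delays, src_lens, tgt_lens))  # tensor([1.5000])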