in simuleval/metrics/latency.py [0:0]
def latency_metric(func):
    """Decorate a latency metric so it accepts loosely-typed inputs.

    The returned wrapper has the signature
    ``(delays, src_lens, ref_lens=None, target_padding_mask=None)``. It
    normalizes its arguments to tensors, derives per-sentence target lengths,
    and then calls
    ``func(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)``.
    """
    import functools

    def _as_column(values, like):
        """Convert per-sentence lengths to a column tensor typed like ``like``.

        Non-tensor inputs (lists, tuples, numpy arrays, plain numbers) are
        converted to float tensors. 1-D results are reshaped to (bsz, 1).
        """
        if not torch.is_tensor(values):
            values = torch.as_tensor(values, dtype=torch.float)
            if values.dim() == 0:
                # A bare scalar length describes a single sentence.
                values = values.view(1)
        if values.dim() == 1:
            values = values.view(-1, 1)
        return values.type_as(like)

    def prepare_latency_metric(
        delays,
        src_lens,
        ref_lens=None,
        target_padding_mask=None,
    ):
        """Normalize metric inputs.

        delays: bsz, tgt_len (a 1-D input is treated as a single sentence)
        src_lens: bsz
        ref_lens: bsz, optional
        target_padding_mask: bsz, tgt_len, optional
            (presumably a bool tensor, as required by ``masked_fill`` -- confirm)

        Returns ``(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)``.
        The length tensors are shaped (bsz, 1) and typed like ``delays``.
        Padded delay positions are zero-filled.
        """
        if not torch.is_tensor(delays):
            # Convert first, so a nested (batched) list keeps its
            # (bsz, tgt_len) shape instead of gaining a spurious leading dim.
            delays = torch.as_tensor(delays, dtype=torch.float)
        if delays.dim() == 1:
            # A single hypothesis: add the batch dimension.
            delays = delays.view(1, -1)
        src_lens = _as_column(src_lens, delays)
        if ref_lens is not None:
            ref_lens = _as_column(ref_lens, delays)
        if target_padding_mask is not None:
            # Target length = number of non-padded positions in each row.
            tgt_lens = delays.size(-1) - target_padding_mask.sum(dim=1)
            delays = delays.masked_fill(target_padding_mask, 0)
        else:
            tgt_lens = torch.ones_like(src_lens) * delays.size(1)
        # Keep tgt_lens in the same dtype as delays in both branches
        # (mask.sum() would otherwise yield an integer tensor).
        tgt_lens = tgt_lens.type_as(delays).view(-1, 1)
        return delays, src_lens, tgt_lens, ref_lens, target_padding_mask

    @functools.wraps(func)
    def latency_wrapper(
        delays, src_lens, ref_lens=None, target_padding_mask=None
    ):
        """Normalize the inputs, then evaluate the wrapped latency metric."""
        delays, src_lens, tgt_lens, ref_lens, target_padding_mask = prepare_latency_metric(
            delays, src_lens, ref_lens, target_padding_mask)
        return func(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)

    return latency_wrapper