def latency_metric()

in simuleval/metrics/latency.py [0:0]
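
A decorator that converts delays, src_lens, and ref_lens from Python numbers, lists,
or 1-D tensors into 2-D float tensors, derives per-sentence target lengths (honoring
target_padding_mask), and passes the normalized tensors to the wrapped latency metric.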


import numbers

import torch


def latency_metric(func):
    def prepare_latency_metric(
        delays,
        src_lens,
        ref_lens=None,
        target_padding_mask=None,
    ):
        """
        delays: bsz, tgt_len
        src_lens: bsz
        target_padding_mask: bsz, tgt_len
        """
        # Accept Python lists / 1-D tensors and promote delays to shape (bsz, tgt_len)
        if isinstance(delays, list):
            delays = torch.FloatTensor(delays).unsqueeze(0)

        if len(delays.size()) == 1:
            delays = delays.view(1, -1)

        # Normalize src_lens to a (bsz, 1) float tensor matching the dtype of delays
        if isinstance(src_lens, list):
            src_lens = torch.FloatTensor(src_lens)
        if isinstance(src_lens, numbers.Number):
            src_lens = torch.FloatTensor([src_lens])
        if len(src_lens.size()) == 1:
            src_lens = src_lens.view(-1, 1)
        src_lens = src_lens.type_as(delays)

        # ref_lens (reference target lengths) gets the same normalization when provided
        if ref_lens is not None:
            if isinstance(ref_lens, list):
                ref_lens = torch.FloatTensor(ref_lens)
            if isinstance(ref_lens, numbers.Number):
                ref_lens = torch.FloatTensor([ref_lens])
            if len(ref_lens.size()) == 1:
                ref_lens = ref_lens.view(-1, 1)
            ref_lens = ref_lens.type_as(delays)

        if target_padding_mask is not None:
            # True target lengths exclude padded positions; zero out padded delays
            # so they do not contribute to the wrapped metric
            tgt_lens = delays.size(-1) - target_padding_mask.sum(dim=1)
            delays = delays.masked_fill(target_padding_mask, 0)
        else:
            # Without a mask, every sequence is assumed to use the full target length
            tgt_lens = torch.ones_like(src_lens) * delays.size(1)

        tgt_lens = tgt_lens.view(-1, 1)

        return delays, src_lens, tgt_lens, ref_lens, target_padding_mask

    def latency_wrapper(
        delays, src_lens, ref_lens=None, target_padding_mask=None
    ):
        # Normalize the inputs, then call the wrapped metric with the prepared tensors
        delays, src_lens, tgt_lens, ref_lens, target_padding_mask = prepare_latency_metric(
            delays, src_lens, ref_lens, target_padding_mask
        )
        return func(delays, src_lens, tgt_lens, ref_lens, target_padding_mask)

    return latency_wrapper
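
A minimal usage sketch: a function decorated with latency_metric receives the
normalized tensors (delays, src_lens, tgt_lens, ref_lens, target_padding_mask) and can
rely on 2-D float shapes, so plain Python lists and numbers work at the call site. The
metric name average_proportion and the formula below are illustrative, not necessarily
the library's own implementation.

import torch

from simuleval.metrics.latency import latency_metric


@latency_metric
def average_proportion(delays, src_lens, tgt_lens, ref_lens, target_padding_mask):
    # Per-sentence average proportion: sum(delays) / (src_len * tgt_len).
    # Padded delays have already been zeroed by the decorator.
    total_delay = torch.sum(delays, dim=1, keepdim=True)
    return total_delay / (src_lens * tgt_lens)


# Lists and scalars are accepted; the decorator converts them to batched tensors.
# average_proportion([1, 3, 5], 6) -> tensor([[0.5]])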