def ngram_ctc()

in benchmarks/transducer_benchmark.py [0:0]


def ngram_ctc():
    """Benchmark CTC-style ``transducer.Transducer`` loss and Viterbi decoding.

    For each n-gram order in ``[0, 1, 2]`` a ``Transducer`` is built over
    ``N`` single-grapheme tokens. ``time_func`` then times two operations on
    the same random inputs:
    (a) a forward + backward pass of the loss, and
    (b) ``crit.viterbi``.

    The batch size defaults to 1. It can be overridden by the first
    command-line argument (``sys.argv[1]``).
    """
    N = 81  # number of tokens; also the size of the last dim of ``inputs``
    T = 250  # size of the second dim of ``inputs`` (presumably time frames)
    L = 44  # length of each target sequence
    B = 1  # batch size (overridable from the command line)
    if len(sys.argv) > 1:
        B = int(sys.argv[1])

    # Every token is a one-grapheme tuple; graphemes map to themselves.
    tokens = [(i,) for i in range(N)]
    graphemes_to_index = {i: i for i in range(N)}

    ITERATIONS = 20  # timed repetitions per benchmark case
    inputs = torch.randn(B, T, N, dtype=torch.float, requires_grad=True)

    # One 1-D target tensor of length L per batch element. ``squeeze(0)``
    # drops only the dimension introduced by ``split(1)``. A bare
    # ``squeeze()`` would also collapse a length-1 target into a 0-d scalar.
    targets = [
        tgt.squeeze(0) for tgt in torch.randint(N, size=(B, L)).split(1)
    ]

    for ngram in [0, 1, 2]:
        crit = transducer.Transducer(
            tokens,
            graphemes_to_index,
            ngram=ngram,
            blank="optional",
            allow_repeats=False,
            reduction="mean",
        )

        def fwd_bwd():
            loss = crit(inputs, targets)
            # Gradients accumulate into ``inputs.grad`` across timed calls;
            # they are never reset here.
            loss.backward()

        time_func(
            fwd_bwd,
            iterations=ITERATIONS,
            name=f"ctc fwd + bwd, ngram={ngram}",
        )

        def viterbi():
            crit.viterbi(inputs)

        time_func(
            viterbi,
            iterations=ITERATIONS,
            name=f"ctc viterbi, ngram={ngram}",
        )