in benchmarks/transducer_benchmark.py [0:0]
def ngram_ctc():
    """Time the transducer criterion for ngram sizes 0, 1 and 2.

    For each ngram size a ``transducer.Transducer`` is built over ``N``
    single-element tokens. Two callables are handed to ``time_func``
    (defined elsewhere in this file):

    * a forward pass of the criterion followed by ``loss.backward()``;
    * ``crit.viterbi`` on the same inputs.

    The batch size ``B`` defaults to 1 and is overridden by the first
    command-line argument when one is given (``sys.argv[1]``).
    """
    N = 81  # number of tokens; also the last dimension of ``inputs``
    T = 250  # second dimension of ``inputs`` (presumably frames -- confirm)
    L = 44  # length of each target sequence
    B = 1
    if len(sys.argv) > 1:
        B = int(sys.argv[1])
    ITERATIONS = 20

    # One single-element token per index; each grapheme maps to itself.
    tokens = [(i,) for i in range(N)]
    graphemes_to_index = {i: i for i in range(N)}

    inputs = torch.randn(B, T, N, dtype=torch.float, requires_grad=True)
    # Split the (B, L) random target matrix into B row tensors. squeeze(0)
    # drops only the split dimension, so every target stays 1-D even if
    # L == 1 (a bare squeeze() would collapse it to a 0-d scalar).
    targets = [
        tgt.squeeze(0) for tgt in torch.randint(N, size=(B, L)).split(1)
    ]

    for ngram in [0, 1, 2]:
        crit = transducer.Transducer(
            tokens,
            graphemes_to_index,
            ngram=ngram,
            blank="optional",
            allow_repeats=False,
            reduction="mean",
        )

        # The closures below read ``crit`` late, which is fine here because
        # time_func invokes them within the same loop iteration.
        def fwd_bwd():
            loss = crit(inputs, targets)
            loss.backward()

        time_func(
            fwd_bwd,
            iterations=ITERATIONS,
            name=f"ctc fwd + bwd, ngram={ngram}",
        )

        def viterbi():
            crit.viterbi(inputs)

        time_func(
            viterbi,
            iterations=ITERATIONS,
            name=f"ctc viterbi, ngram={ngram}",
        )