def word_decompositions()

in benchmarks/transducer_benchmark.py


import random
import sys

import torch

import transducer


def word_decompositions():
    # Load the word-piece token set and build a grapheme -> index mapping.
    tokens_path = "word_pieces_tokens_1000.txt"
    with open(tokens_path, "r") as fid:
        tokens = sorted(l.strip() for l in fid)
    graphemes = sorted(set(c for t in tokens for c in t))
    graphemes_to_index = {t: i for i, t in enumerate(graphemes)}
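    # Illustration (hypothetical file contents): if the file held the pieces
    # ["er", "ing", "the"], the graphemes would be
    # ['e', 'g', 'h', 'i', 'n', 'r', 't'], with 'e' -> 0, 'g' -> 1, and so
    # on in graphemes_to_index.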

    N = len(tokens) + 1  # output classes: one per word piece plus a blank
    T = 100  # input sequence length (frames)
    L = 15  # word pieces per target
    B = 1  # batch size, optionally overridden by the first CLI argument
    if len(sys.argv) > 1:
        B = int(sys.argv[1])

    # Create the inputs directly on the benchmark device. Calling .cuda()
    # on a leaf tensor that requires grad would benchmark a non-leaf copy
    # and ship gradients back to a CPU leaf on every backward pass.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = torch.randn(
        B, T, N, dtype=torch.float, device=device, requires_grad=True
    )

    # Sample L random word pieces per example and flatten each piece into
    # its grapheme index sequence.
    targets = []
    for b in range(B):
        pieces = (random.choice(tokens) for _ in range(L))
        target = [graphemes_to_index[l] for wp in pieces for l in wp]
        targets.append(torch.tensor(target))
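    # For example, the pieces ("the", "er") flatten to the grapheme indices
    # of ['t', 'h', 'e', 'e', 'r'].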

    # Transducer criterion over the word-piece token set.
    crit = transducer.Transducer(
        tokens,
        graphemes_to_index,
        blank="optional",
        allow_repeats=False,
        reduction="mean",
    )
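    # With word-piece tokens the loss marginalizes over every decomposition
    # of the target grapheme sequence into pieces (hence the benchmark's
    # name); blank="optional" allows, but does not force, a blank between
    # pieces.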

    # Time the loss forward + backward pass over 20 iterations.
    def fwd_bwd():
        loss = crit(inputs, targets)
        loss.backward()

    time_func(fwd_bwd, 20, "word decomps fwd + bwd")

    # Time Viterbi decoding of the same inputs.
    def viterbi():
        crit.viterbi(inputs)

    time_func(viterbi, 20, "word decomps viterbi")
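
time_func is the timing helper defined elsewhere in benchmarks/transducer_benchmark.py. A minimal sketch of such a helper, assuming it runs the callable a fixed number of iterations and reports the mean wall-clock time (the warm-up, synchronization, and reporting format here are illustrative, not the repo's implementation):

import time

import torch


def time_func(func, iterations, name):
    # Illustrative sketch only; the real helper lives in the benchmark file.
    func()  # warm-up run to exclude one-time costs such as CUDA init
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iterations):
        func()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    avg_ms = (time.time() - start) / iterations * 1000
    print(f"{name}: {avg_ms:.2f} ms per iteration")

Since B is read from sys.argv[1], the benchmark can presumably be run as "python transducer_benchmark.py 8" from the benchmarks directory to use a batch size of 8.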