in benchmarks/transducer_benchmark.py [0:0]
import random
import sys

import torch

import transducer


def word_decompositions():
    # Load the word-piece lexicon and build a grapheme -> index mapping.
    tokens_path = "word_pieces_tokens_1000.txt"
    with open(tokens_path, "r") as fid:
        tokens = sorted([l.strip() for l in fid])
    graphemes = sorted(set(c for t in tokens for c in t))
    graphemes_to_index = {t: i for i, t in enumerate(graphemes)}
    # N: output scores per frame (one per token plus the optional blank);
    # T: number of input frames; L: word pieces per target; B: batch size.
    N = len(tokens) + 1
    T = 100
    L = 15
    B = 1
    if len(sys.argv) > 1:
        B = int(sys.argv[1])
    # Create the inputs directly on the benchmark device so the leaf tensor
    # (and hence its gradient) lives on the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = torch.randn(B, T, N, dtype=torch.float, device=device, requires_grad=True)
    # One random target per batch element: sample L word pieces and flatten
    # each piece into its grapheme indices.
    targets = []
    for _ in range(B):
        pieces = (random.choice(tokens) for _ in range(L))
        target = [graphemes_to_index[g] for wp in pieces for g in wp]
        targets.append(torch.tensor(target))
    # Criterion over the token set with an optional blank; the loss sums over
    # the decompositions of each grapheme target into word pieces.
    crit = transducer.Transducer(
        tokens,
        graphemes_to_index,
        blank="optional",
        allow_repeats=False,
        reduction="mean",
    )
    # Benchmark the forward + backward pass and Viterbi decoding separately.
    def fwd_bwd():
        loss = crit(inputs, targets)
        loss.backward()

    time_func(fwd_bwd, 20, "word decomps fwd + bwd")

    def viterbi():
        crit.viterbi(inputs)

    time_func(viterbi, 20, "word decomps viterbi")
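
# time_func is defined elsewhere in this file (only an excerpt is shown above).
# A minimal sketch of a compatible helper follows, assuming only the
# (function, iterations, name) signature used above; the CUDA synchronization
# and the reporting format are assumptions, not the repo's actual code.
import time


def time_func(fn, iterations, name):
    # Synchronize before and after the timed loop so pending CUDA kernels
    # are included in the measurement.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    ms = (time.perf_counter() - start) / iterations * 1000
    print(f"{name}: {ms:.2f} ms per iteration")


if __name__ == "__main__":
    # Usage: python transducer_benchmark.py [batch_size]
    word_decompositions()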