def _run_benchmark_lookup_jit_for_loop()

in benchmark/benchmark_pytext_vocab.py [0:0]


def _run_benchmark_lookup_jit_for_loop(tokens: Union[List[str], List[List[str]]], vocab, num_iters=1):
    """Time TorchScript-compiled vocab lookups over ``tokens`` and print the elapsed time.

    Four TorchScript functions are compiled on every call, before the timer starts:
    - a per-token variant and a per-list variant for ``PytextScriptVocabulary``;
    - the same two variants for ``ExperimentalScriptVocabulary``.

    The variant matching ``tokens`` and ``vocab`` is run ``num_iters`` times. The
    time measured with ``time.monotonic()`` is printed as ``Lookup time: <seconds>``.

    Args:
        tokens: Either a flat list of tokens or a list of token lists.
            - A flat list is looked up one token at a time.
            - A list of token lists gets one ``lookup_indices_1d`` call per inner list.
            - An empty list is treated as a flat token list; the lookup loop is a no-op.
        vocab: A ``PytextScriptVocabulary``, an ``ExperimentalScriptVocabulary``, or a
            ``torch.jit._script.RecursiveScriptModule``. The latter is presumably a
            scripted experimental vocab; confirm against callers.
        num_iters: Number of passes over ``tokens``.

    Raises:
        RuntimeError: If ``tokens`` is not a list, or if ``vocab`` is not one of the
            supported types. The ``tokens`` check takes precedence.
    """

    @torch.jit.script
    def _run_benchmark_pytext_script_vocab(toks: List[str], v: PytextScriptVocabulary):
        # One single-element lookup call per token.
        for token in toks:
            v.lookup_indices_1d([token])

    @torch.jit.script
    def _run_benchmark_experimental_script_vocab(toks: List[str], v: ExperimentalScriptVocabulary):
        # One indexing (__getitem__) lookup per token; the result is discarded.
        for token in toks:
            v[token]

    @torch.jit.script
    def _run_benchmark_lists_pytext_script_vocab(tok_lists: List[List[str]], v: PytextScriptVocabulary):
        # One batched lookup call per inner token list.
        for tokens_list in tok_lists:
            v.lookup_indices_1d(tokens_list)

    @torch.jit.script
    def _run_benchmark_lists_experimental_script_vocab(tok_lists: List[List[str]], v: ExperimentalScriptVocabulary):
        # One batched lookup call per inner token list.
        for tokens_list in tok_lists:
            v.lookup_indices_1d(tokens_list)

    if not isinstance(tokens, list):
        raise RuntimeError("Received tokens of incorrect type {}.".format(type(tokens)))

    # Only probe tokens[0] when it exists; an empty list is a valid (empty) flat token list.
    is_list_of_lists = bool(tokens) and isinstance(tokens[0], list)

    # Pick the scripted runner up front so the timed region contains only the lookups.
    if isinstance(vocab, PytextScriptVocabulary):
        run = _run_benchmark_lists_pytext_script_vocab if is_list_of_lists else _run_benchmark_pytext_script_vocab
    elif isinstance(vocab, (ExperimentalScriptVocabulary, torch.jit._script.RecursiveScriptModule)):
        run = (
            _run_benchmark_lists_experimental_script_vocab
            if is_list_of_lists
            else _run_benchmark_experimental_script_vocab
        )
    else:
        raise RuntimeError("Received vocab of incorrect type {}.".format(type(vocab)))

    t0 = time.monotonic()
    for _ in range(num_iters):
        run(tokens, vocab)
    print("Lookup time:", time.monotonic() - t0)