def _run_benchmark_lookup()

in benchmark/benchmark_pytext_vocab.py [0:0]


def _run_benchmark_lookup(tokens, vocab, num_iters=1):
    """Time ``num_iters`` passes of token lookups against ``vocab`` and print the result.

    Args:
        tokens: For the script vocabs, either a flat list of tokens (one lookup
            per token) or a list of token lists (one ``lookup_indices_1d`` call
            per inner list). For ``PytextVocabulary`` each element is passed
            directly to ``lookup_all``.
        vocab: A ``PytextVocabulary``, ``PytextScriptVocabulary``,
            ``ExperimentalScriptVocabulary`` or a
            ``torch.jit._script.RecursiveScriptModule`` (presumably a scripted
            experimental vocab -- confirm against callers).
        num_iters: Number of full passes over ``tokens`` to time.

    Raises:
        RuntimeError: If ``vocab`` is of an unsupported type, or if ``tokens`` is
            not a list when benchmarking a script vocab.
    """

    def _run_benchmark_pytext_vocab(toks, v: PytextVocabulary):
        # The loop variable name suggests lookup_all handles both a single token
        # and a list of tokens; no type check is done on this path.
        for token_or_tokens_list in toks:
            v.lookup_all(token_or_tokens_list)

    def _run_benchmark_pytext_script_vocab(toks, v: PytextScriptVocabulary):
        # list lookup; ``toks and`` avoids IndexError on toks[0] for an empty list
        if isinstance(toks, list) and toks and isinstance(toks[0], list):
            for tokens_list in toks:
                v.lookup_indices_1d(tokens_list)
        # single token lookup (an empty list performs no lookups)
        elif isinstance(toks, list):
            for token in toks:
                v.lookup_indices_1d([token])
        else:
            raise RuntimeError("Received tokens of incorrect type {}.".format(type(toks)))

    def _run_benchmark_experimental_script_vocab(toks, v: ExperimentalScriptVocabulary):
        # list lookup; ``toks and`` avoids IndexError on toks[0] for an empty list
        if isinstance(toks, list) and toks and isinstance(toks[0], list):
            for tokens_list in toks:
                v.lookup_indices_1d(tokens_list)
        # single token lookup via indexing (an empty list performs no lookups)
        elif isinstance(toks, list):
            for token in toks:
                v[token]
        else:
            raise RuntimeError("Received tokens of incorrect type {}.".format(type(toks)))

    # Pick the runner before starting the clock so type dispatch is not timed.
    if isinstance(vocab, PytextVocabulary):
        run_once = _run_benchmark_pytext_vocab
    elif isinstance(vocab, PytextScriptVocabulary):
        run_once = _run_benchmark_pytext_script_vocab
    elif isinstance(vocab, (ExperimentalScriptVocabulary, torch.jit._script.RecursiveScriptModule)):
        run_once = _run_benchmark_experimental_script_vocab
    else:
        raise RuntimeError("Received vocab of incorrect type {}.".format(type(vocab)))

    # perf_counter is the highest-resolution clock intended for short benchmarks.
    t0 = time.perf_counter()
    for _ in range(num_iters):
        run_once(tokens, vocab)

    print("Lookup time:", time.perf_counter() - t0)