def benchmark_new_vocab_lookup()

in benchmark/benchmark_vocab.py [0:0]


def benchmark_new_vocab_lookup(vocab_file_path=None, dataset='AG_NEWS'):
    """Benchmark construction and lookup of the new vocab in eager and jit mode.

    The training split of ``dataset`` is tokenized with the "basic_english"
    tokenizer to produce the lookup workload. The vocab is then built either
    from ``vocab_file_path`` or from the token frequencies of the dataset,
    and three lookup patterns are timed for both the eager and the
    TorchScript-scripted vocab:

    1. one ``vocab[token]`` call per token,
    2. a single ``lookup_indices`` call over all tokens,
    3. one ``lookup_indices`` call per tokenized sample.

    All timings are printed to stdout; nothing is returned.

    Args:
        vocab_file_path: Optional path to a vocab file. When truthy, the
            vocab is loaded with ``load_vocab_from_file``; otherwise it is
            built with ``VocabNew`` from the dataset's token counts.
        dataset: Key into ``DATASETS`` selecting the dataset to tokenize.
    """
    def _run_benchmark_lookup(tokens, vocab):
        """Time looking up ``tokens`` in ``vocab`` and print the elapsed time.

        ``tokens`` is either a flat list of tokens (looked up one by one) or
        a list of token lists (each looked up via ``lookup_indices``).

        Raises:
            RuntimeError: If ``tokens`` is not a list.
        """
        t0 = time.monotonic()
        # list lookup; the truthiness check keeps an empty list from raising
        # IndexError on ``tokens[0]`` (it falls through to the no-op loop below)
        if isinstance(tokens, list) and tokens and isinstance(tokens[0], list):
            for tokens_list in tokens:
                vocab.lookup_indices(tokens_list)
        # single token lookup
        elif isinstance(tokens, list):
            for token in tokens:
                vocab[token]
        else:
            raise RuntimeError("Received tokens of incorrect type {}.".format(type(tokens)))
        print("Lookup time:", time.monotonic() - t0)

    # The dataset is always tokenized, even when the vocab comes from a file,
    # because its tokens are the lookup workload.
    tokens = []
    tokens_lists = []
    tokenizer = get_tokenizer("basic_english")
    for (_, text) in DATASETS[dataset](split='train'):
        cur_tokens = tokenizer(text)
        tokens_lists.append(cur_tokens)
        tokens += cur_tokens

    if vocab_file_path:
        print("Loading Vocab from file {}".format(vocab_file_path))

        # new Vocab construction
        print("Vocab New")
        t0 = time.monotonic()
        # Opening the file stays inside the timed region, as before; the
        # context manager guarantees the handle is closed afterwards.
        with open(vocab_file_path, 'r') as f:
            v_new = load_vocab_from_file(f)
        print("Construction time:", time.monotonic() - t0)
    else:
        print("Loading Vocab from {}".format(dataset))
        counter = Counter(tokens)
        sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)

        # new Vocab construction
        print("Vocab New")
        t0 = time.monotonic()
        v_new = VocabNew(ordered_dict)
        print("Construction time:", time.monotonic() - t0)

    # new Vocab eager lookup
    print("Vocab New - Eager Mode")
    _run_benchmark_lookup(tokens, v_new)
    _run_benchmark_lookup([tokens], v_new)
    _run_benchmark_lookup(tokens_lists, v_new)

    # Script the vocab once, right before the jit measurements.
    jit_v_new = torch.jit.script(v_new)
    # new Vocab jit lookup
    print("Vocab New - Jit Mode")
    _run_benchmark_lookup(tokens, jit_v_new)
    _run_benchmark_lookup([tokens], jit_v_new)
    _run_benchmark_lookup(tokens_lists, jit_v_new)