def cache_preds()

in tools/cache_preds.py [0:0]


import sys
import time

import torch
from tqdm import tqdm

# get_net_input, gather_all and concatenate_cache are assumed to be provided
# by the repository's utility modules (see the sketch after this function)


def cache_preds(model, loader, cache_vars=None, stats=None, n_extract=None):

    print("caching model predictions: %s" % str(cache_vars))

    model.eval()

    trainmode = 'test'  # stats are logged under the 'test' stat set

    t_start = time.time()

    cached_preds = []

    cache_size = 0.  # in GB ... counts only cached tensor sizes

    # optionally cache only the first n_extract batches
    n_batches = len(loader)
    if n_extract is not None:
        n_batches = n_extract

    with tqdm(total=n_batches, file=sys.stdout) as pbar:
        for it, batch in enumerate(loader):

            last_iter = it == n_batches-1

            # move to gpu and cast to Var
            net_input = get_net_input(batch)

            with torch.no_grad():
                preds = model(**net_input)

            assert not any(k in preds for k in net_input.keys())
            preds.update(net_input)  # merge everything into one big dict

            if stats is not None:
                stats.update(preds, time_start=t_start, stat_set=trainmode)
                assert stats.it[trainmode] == it, \
                    "inconsistent stat iteration number!"

            # restrict the variables to cache
            if cache_vars is not None:
                preds = {k: preds[k] for k in cache_vars if k in preds}

            # ... gather and log the size of the cache
            preds, preds_size = gather_all(preds)
            cache_size += preds_size

            cached_preds.append(preds)

            pbar.set_postfix(cache_size="%1.2f GB" % cache_size)
            pbar.update(1)

            # stop early once the requested number of batches has been cached
            if last_iter and n_extract is not None:
                break

    # merge the per-batch prediction dicts into a single cache dict
    cached_preds_cat = concatenate_cache(cached_preds)

    return cached_preds_cat
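
The helpers get_net_input, gather_all and concatenate_cache are project utilities that are not shown in this file. Below is a minimal sketch of what cache_preds assumes about the latter two, inferred only from how they are used above; the detach-to-CPU step and the GB accounting are assumptions, not the repository's actual implementation.


def gather_all_sketch(preds):
    # assumed behavior: detach every tensor and move it to cpu so it can be
    # cached without holding gpu memory, and report the total size in GB
    gathered, size_gb = {}, 0.
    for k, v in preds.items():
        if torch.is_tensor(v):
            v = v.detach().cpu()
            size_gb += v.element_size() * v.numel() / 1e9
        gathered[k] = v
    return gathered, size_gb


def concatenate_cache_sketch(cached_preds):
    # assumed behavior: merge the list of per-batch dicts into a single dict
    # whose tensors are concatenated along the batch (first) dimension;
    # non-tensor entries are kept as plain per-batch lists
    out = {}
    for k in cached_preds[0]:
        vals = [c[k] for c in cached_preds]
        out[k] = torch.cat(vals, dim=0) if torch.is_tensor(vals[0]) else vals
    return out


# illustrative call (the loader, cache_vars keys and n_extract value are hypothetical):
# cached = cache_preds(model, test_loader,
#                      cache_vars=['preds', 'embed'], n_extract=50)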