in tools/cache_preds.py [0:0]
import sys
import time

import torch
from tqdm import tqdm

# get_net_input, gather_all and concatenate_cache are project-local helpers;
# their imports are omitted in this excerpt.


def cache_preds(model, loader, cache_vars=None, stats=None, n_extract=None):
    print("caching model predictions: %s" % str(cache_vars))
    model.eval()

    trainmode = 'test'
    t_start = time.time()

    cached_preds = []
    cache_size = 0.  # running cache size in GB (counts only cached tensors)

    n_batches = len(loader)
    if n_extract is not None:
        n_batches = n_extract

    with tqdm(total=n_batches, file=sys.stdout) as pbar:
        for it, batch in enumerate(loader):
            last_iter = it == n_batches - 1

            # move the batch to the GPU and cast to network-input tensors
            net_input = get_net_input(batch)

            # forward pass without gradient tracking
            with torch.no_grad():
                preds = model(**net_input)

            # merge predictions and inputs into one big dict; keys must not clash
            assert not any(k in preds for k in net_input.keys())
            preds.update(net_input)

            if stats is not None:
                stats.update(preds, time_start=t_start, stat_set=trainmode)
                assert stats.it[trainmode] == it, \
                    "inconsistent stat iteration number!"

            # restrict the cached variables to the requested subset
            if cache_vars is not None:
                preds = {k: preds[k] for k in cache_vars if k in preds}

            # gather the predictions and log the running size of the cache
            preds, preds_size = gather_all(preds)
            cache_size += preds_size
            cached_preds.append(preds)

            pbar.set_postfix(cache_size="%1.2f GB" % cache_size)
            pbar.update(1)

            if last_iter and n_extract is not None:
                break

    # concatenate the per-batch caches into a single dictionary
    cached_preds_cat = concatenate_cache(cached_preds)
    return cached_preds_cat
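
A minimal usage sketch, assuming a model that returns a dict of tensors and a standard torch DataLoader; `SceneModel`, `eval_dataset` and the `cache_vars` keys below are hypothetical names, not part of this file:

# hypothetical usage sketch -- model, dataset and variable names are assumptions
model = SceneModel().cuda()
loader = torch.utils.data.DataLoader(eval_dataset, batch_size=8, shuffle=False)
cached = cache_preds(model, loader,
                     cache_vars=['shape_pred', 'camera_pred'],  # assumed keys of the model's output dict
                     n_extract=10)  # cache only the first 10 batches
# `cached` maps each cached variable name to its values concatenated over all batches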