in c3dm/tools/cache_preds.py [0:0]
import sys
import time

import torch
from tqdm import tqdm

# NOTE: these helpers are defined elsewhere in this repo; the import path
# below is an assumption - adjust it to the actual module.
from tools.utils import get_net_input, gather_all, concatenate_cache


def cache_preds(
    model,
    loader,
    cache_vars=None,
    stats=None,
    n_extract=None,
    cat=True,
    eval_mode=True,
    strict_mode=False,
):
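    """Run `model` on all batches of `loader` and cache its predictions.

    Only the keys listed in `cache_vars` are kept (all predictions if
    None); `n_extract` optionally limits the number of cached batches.
    With `cat=True` the per-batch dicts are concatenated into a single
    dict, otherwise the list of per-batch dicts is returned.
    """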
print("caching model predictions: %s" % str(cache_vars) )
if eval_mode:
model.eval()
else:
print('TRAINING EVAL MODE!!!')
model.train()
trainmode = 'test'
t_start = time.time()
iterator = loader.__iter__()
cached_preds = []
cache_size = 0. # in GB ... counts only cached tensor sizes
n_batches = len(loader)
if n_extract is not None:
n_batches = n_extract
    with tqdm(total=n_batches, file=sys.stdout) as pbar:
        for it, batch in enumerate(loader):
            last_iter = it == n_batches - 1
            # move the batch tensors to the gpu
            net_input = get_net_input(batch)
            with torch.no_grad():
                preds = model(**net_input)
            if strict_mode:
                # the model must not output any of its input fields
                assert not any(k in preds for k in net_input)
            preds.update(net_input)  # merge everything into one big dict
            if stats is not None:
                stats.update(preds, time_start=t_start, stat_set=trainmode)
                assert stats.it[trainmode] == it, \
                    "inconsistent stat iteration number!"
            # restrict the variables to cache
            if cache_vars is not None:
                preds = {k: preds[k] for k in cache_vars if k in preds}
            # gather the predictions and log the size of the cache
            preds, preds_size = gather_all(preds)
            cache_size += preds_size
            cached_preds.append(preds)
            pbar.set_postfix(cache_size="%1.2f GB" % cache_size)
            pbar.update(1)
            if last_iter and n_extract is not None:
                break
    if cat:
        # concatenate the list of per-batch dicts into a single dict
        return concatenate_cache(cached_preds)
    else:
        return cached_preds
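

# For reference: a minimal sketch of the behaviour `gather_all` is assumed
# to have above (detach cached tensors to the cpu and report their size in
# GB). This is an illustration only, not the repo's actual implementation.
def _gather_all_sketch(preds):
    size_gb = 0.
    for k, v in preds.items():
        if torch.is_tensor(v):
            preds[k] = v.detach().cpu()  # free gpu memory before caching
            size_gb += preds[k].numel() * preds[k].element_size() / 1e9
    return preds, size_gb

# Minimal usage sketch; `net`, `val_loader` and the cached keys below are
# hypothetical placeholders, not names defined by this file:
#
#   preds = cache_preds(net, val_loader,
#                       cache_vars=('embed', 'depth'),  # keep only these
#                       n_extract=10)                   # first 10 batches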