def cache_preds()

in c3dm/tools/cache_preds.py
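
Runs model over every batch of loader and caches the resulting predictions. Each batch's network inputs are merged into the prediction dict, an optional stats object is updated per batch, and the cached tensors are gathered to the CPU while their total size (in GB) is tracked on the progress bar. With cat=True the per-batch caches are concatenated into a single dict; otherwise the list of per-batch dicts is returned.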


import sys
import time

import torch
from tqdm import tqdm

# get_net_input, gather_all, and concatenate_cache are helpers assumed to be
# provided elsewhere in the c3dm tools package: get_net_input moves a batch to
# the GPU, gather_all detaches the cached tensors to the CPU and reports their
# size in GB, and concatenate_cache merges the per-batch prediction dicts.


def cache_preds(model,
                loader, 
                cache_vars=None, 
                stats=None,
                n_extract=None, 
                cat=True, 
                eval_mode=True,
                strict_mode=False,
                ):
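    """Run model over loader and cache its predictions.

    Args:
        model: network whose forward pass returns a dict of predictions.
        loader: iterable yielding input batches.
        cache_vars: optional list of prediction keys to cache; caches all if None.
        stats: optional stats object updated with each batch of predictions.
        n_extract: if set, stop after this many batches.
        cat: if True, concatenate the per-batch caches into a single dict.
        eval_mode: run the model in eval mode; train mode if False.
        strict_mode: assert that predictions do not clash with input keys.
    """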

    print("caching model predictions: %s" % str(cache_vars))
    
    if eval_mode:
        model.eval()
    else:
        print('warning: caching predictions with the model in train mode!')
        model.train()
    
    trainmode = 'test'  # stat set under which cached predictions are logged

    t_start = time.time()

    cached_preds = []

    cache_size = 0.  # in GB; counts only the cached tensor sizes

    n_batches = len(loader)
    if n_extract is not None:
        n_batches = n_extract
    
    with tqdm(total=n_batches, file=sys.stdout) as pbar:
        for it, batch in enumerate(loader):

            last_iter = it == n_batches - 1

            # move the batch to the GPU
            net_input = get_net_input(batch)
            
            with torch.no_grad():
                preds = model(**net_input)

            if strict_mode:
                assert not any(k in preds for k in net_input.keys())
            preds.update(net_input)  # merge everything into one big dict
            
            if stats is not None:
                stats.update(preds, time_start=t_start, stat_set=trainmode)
                assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # restrict the variables to cache
            if cache_vars is not None:
                preds = {k: preds[k] for k in cache_vars if k in preds}

            # ... gather and log the size of the cache
            preds, preds_size = gather_all(preds)
            cache_size += preds_size

            cached_preds.append(preds)

            pbar.set_postfix(cache_size="%1.2f GB" % cache_size)
            pbar.update(1)

            if last_iter and n_extract is not None:
                break

    if cat:
        return concatenate_cache(cached_preds)
    else:
        return cached_preds
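
A minimal usage sketch. The model, val_loader, and prediction key names below are hypothetical stand-ins; the function only assumes the model's forward pass returns a dict of tensors and the loader yields batches that get_net_input can move to the GPU:

# hypothetical example: cache two prediction variables over 10 validation batches
cached = cache_preds(
    model,                                # a trained torch.nn.Module
    val_loader,                           # a torch DataLoader over the val set
    cache_vars=('depth', 'camera_pose'),  # hypothetical prediction keys
    n_extract=10,                         # stop after 10 batches
    cat=True,                             # concatenate per-batch dicts into one
)

With cat=True the result is a single dict of concatenated tensors; with cat=False it is a list with one prediction dict per batch.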