in src/train.py [0:0]
def rollout(model, dataset, opts):
    """Run ``model`` over ``dataset`` with greedy decoding and gather the costs.

    The decode type is set to ``"greedy"`` and the model is switched to eval
    mode; neither setting is restored before returning.

    Args:
        model: Callable invoked as ``model(nodes, graph)`` that returns a pair
            whose first element is the per-batch cost tensor.
        dataset: Dataset wrapped in a ``DataLoader``; every batch is indexed
            with the keys ``'nodes'`` and ``'graph'``.
        opts: Options object read for ``device``, ``batch_size``,
            ``num_workers`` and ``no_progress_bar``.

    Returns:
        The cost tensors of all batches, moved to the CPU and concatenated
        along dimension 0 in dataset order (the loader does not shuffle).
    """
    # Put the model in greedy evaluation mode.
    set_decode_type(model, "greedy")
    model.eval()

    loader = DataLoader(
        dataset,
        batch_size=opts.batch_size,
        shuffle=False,
        num_workers=opts.num_workers,
    )
    progress = tqdm(loader, disable=opts.no_progress_bar, ascii=True)

    batch_costs = []
    for batch in progress:
        # Gradient tracking is disabled only around the transfer + forward pass.
        with torch.no_grad():
            nodes = move_to(batch['nodes'], opts.device)
            graph = move_to(batch['graph'], opts.device)
            cost, _ = model(nodes, graph)
        batch_costs.append(cost.data.cpu())

    return torch.cat(batch_costs, 0)