in retrieval_train.py [0:0]
def main(opt_):
    """Train a retrieval model, or evaluate an already-trained one.

    If ``opt_.pretrained`` is falsy, this simply calls ``train_model(opt_)``.
    Otherwise it does the following:

    1. Loads the checkpoint named by ``opt_.pretrained``.
    2. Overrides four data options on the loaded model's ``opt`` with the
       values in ``opt_``.
    3. Optionally moves the model to GPU inside ``torch.nn.DataParallel``.
    4. Runs ``validate`` (epoch index 0) on the validation split and on the
       hidden test split. This happens twice: first with
       ``build_valid_dataloader(False, ...)`` and then with
       ``build_valid_dataloader(True, ...)``.

    Args:
        opt_: parsed options. Fields read here: ``pretrained``,
            ``dataset_name``, ``reddit_folder``, ``reactonly``,
            ``max_hist_len``, ``cuda`` and ``hits_at_nb_cands``.
    """
    # Guard clause: with no checkpoint to evaluate, go straight to training.
    if not opt_.pretrained:
        train_model(opt_)
        return

    net, dictionary = load_model(opt_.pretrained, opt_)
    # Replace the data-related options saved with the checkpoint by the ones
    # given now (presumably so the same weights can be scored on other data).
    net.opt.dataset_name = opt_.dataset_name
    net.opt.reddit_folder = opt_.reddit_folder
    net.opt.reactonly = opt_.reactonly
    net.opt.max_hist_len = opt_.max_hist_len
    env = TrainEnvironment(net.opt, dictionary)
    if opt_.cuda:
        net = torch.nn.DataParallel(net.cuda())

    # Each evaluation pass is (dataloader flag, valid-set log line, test-set
    # log line). Judging by the log labels, the flag presumably toggles
    # shuffling; confirm in TrainEnvironment.build_valid_dataloader.
    evaluation_passes = (
        (
            False,
            "Validating on the valid set -unshuffled",
            "Validating on the hidden test set -unshuffled",
        ),
        (
            True,
            "Validating on the valid set -shuffle",
            "Validating on the hidden test set -shuffle",
        ),
    )
    for flag, valid_msg, test_msg in evaluation_passes:
        # Build the loaders outside no_grad, then evaluate inside it.
        valid_loader = env.build_valid_dataloader(flag)
        test_loader = env.build_valid_dataloader(flag, test=True)
        with torch.no_grad():
            logging.info(valid_msg)
            validate(
                0,
                net,
                valid_loader,
                is_test=False,
                nb_candidates=opt_.hits_at_nb_cands,
            )
            logging.info(test_msg)
            validate(
                0,
                net,
                test_loader,
                is_test=True,
                nb_candidates=opt_.hits_at_nb_cands,
            )