def main()

in retrieval_train.py [0:0]


def _validate_valid_and_test(env, net, shuffle, suffix, nb_candidates):
    """Run one evaluation pass of ``net`` on the valid split and the hidden test split.

    Both dataloaders are built first; then both ``validate`` calls run under
    ``torch.no_grad()``.

    Args:
        env: object exposing ``build_valid_dataloader(shuffle, test=...)``.
        net: model passed straight through to ``validate``.
        shuffle: first positional argument forwarded to
            ``build_valid_dataloader`` (the callers' log messages treat it as
            the shuffle flag).
        suffix: text appended to the log messages, e.g. ``"-shuffle"``.
        nb_candidates: forwarded to ``validate`` as ``nb_candidates``.
    """
    valid_data = env.build_valid_dataloader(shuffle)
    test_data = env.build_valid_dataloader(shuffle, test=True)
    with torch.no_grad():
        # The leading 0 is validate's first positional argument
        # (presumably an epoch index; verify against validate's definition).
        logging.info(f"Validating on the valid set {suffix}")
        validate(0, net, valid_data, is_test=False, nb_candidates=nb_candidates)
        logging.info(f"Validating on the hidden test set {suffix}")
        validate(0, net, test_data, is_test=True, nb_candidates=nb_candidates)


def main(opt_):
    """Evaluate a pretrained retrieval model, or train a new one.

    When ``opt_.pretrained`` is set:
      - the model and dictionary are loaded from that path;
      - four dataset-related options on the loaded model's ``opt`` are
        overwritten with the values from ``opt_``;
      - the model is evaluated on the valid and hidden test splits twice,
        first with unshuffled loaders and then with shuffled ones.

    Otherwise training is delegated to ``train_model(opt_)``.

    Args:
        opt_: parsed options. Attributes read here: ``pretrained``,
            ``dataset_name``, ``reddit_folder``, ``reactonly``,
            ``max_hist_len``, ``cuda`` and ``hits_at_nb_cands``.

    Returns:
        None.
    """
    if not opt_.pretrained:
        train_model(opt_)
        return

    net, dictionary = load_model(opt_.pretrained, opt_)
    # Override these options on the loaded model with the values from opt_
    # before the environment is built from net.opt.
    net.opt.dataset_name = opt_.dataset_name
    net.opt.reddit_folder = opt_.reddit_folder
    net.opt.reactonly = opt_.reactonly
    net.opt.max_hist_len = opt_.max_hist_len
    env = TrainEnvironment(net.opt, dictionary)
    if opt_.cuda:
        net = torch.nn.DataParallel(net.cuda())

    # Same evaluation twice: unshuffled first, then shuffled. The suffix only
    # changes the log message text.
    for shuffle, suffix in ((False, "-unshuffled"), (True, "-shuffle")):
        _validate_valid_and_test(env, net, shuffle, suffix, opt_.hits_at_nb_cands)