def _run_rl()

in src/run.py [0:0]


def _run_rl(opts):
    
    # distributed args
    opts.world_size = dist.get_world_size()
    opts.rank = rank = dist.get_rank()
    opts.local_rank = local_rank = dist.get_local_rank()
    # Scale the global batch size down to a per-node batch size
    # (world_size // 8 assumes 8 GPUs per node; guard against fewer than 8 processes)
    opts.batch_size //= max(opts.world_size // 8, 1)
    opts.batch_size = max(opts.batch_size, 1)

    # Pretty print the run args
    pp.pprint(vars(opts))

    # Seed each worker with its rank so that different processes sample different data
    # (note: the global opts.seed is not used here)
    #torch.manual_seed(opts.seed)
    #np.random.seed(opts.seed)
    torch.manual_seed(opts.rank)
    np.random.seed(opts.rank)

    # Optionally configure tensorboard
    tb_logger = None
    if not opts.no_tensorboard:
        tb_logger = TbLogger(os.path.join(
            opts.log_dir, "{}_{}-{}".format(opts.problem, opts.min_size, opts.max_size), opts.run_name))
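        # Note: a logger is created on every rank; the log directory name encodes the problem and graph size range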

    if opts.rank == 0:
        # Save arguments (from the master process only) so the exact configuration can always be found
        os.makedirs(opts.model_dir, exist_ok=True)
        with open(os.path.join(opts.model_dir, "args.json"), 'w') as f:
            json.dump(vars(opts), f, indent=True)

    # Set the device for this process before building the model,
    # so that parameters end up on the correct GPU for DDP
    #opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu")
    torch.cuda.set_device(local_rank)
    opts.device = torch.device("cuda", local_rank)

    # Figure out what's the problem
    problem = load_problem(opts.problem)
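    # The problem object provides the dataset constructor (make_dataset) and problem name (NAME) used below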

    # Load data from load_path
    load_data = {}
    assert opts.load_path is None or opts.resume is None, "Only one of load path and resume can be given"
    load_path = opts.load_path if opts.load_path is not None else opts.resume
    if load_path is not None:
        print('\nLoading data from {}'.format(load_path))
        load_data = torch_load_cpu(load_path)
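        # torch_load_cpu keeps the checkpoint on the CPU; tensors are moved to the GPU where needed below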

    # Initialize model
    model_class = {
        'attention': AttentionModel,
        'nar': NARModel,
        # 'pointer': PointerNetwork
    }.get(opts.model, None)
    assert model_class is not None, "Unknown model: {}".format(opts.model)
    encoder_class = {
        'gnn': GNNEncoder,
        'gat': GraphAttentionEncoder,
        'mlp': MLPEncoder
    }.get(opts.encoder, None)
    assert encoder_class is not None, "Unknown encoder: {}".format(encoder_class)
    model = DDP(model_class(
        problem=problem,
        embedding_dim=opts.embedding_dim,
        encoder_class=encoder_class,
        n_encode_layers=opts.n_encode_layers,
        aggregation=opts.aggregation,
        aggregation_graph=opts.aggregation_graph,
        normalization=opts.normalization,
        learn_norm=opts.learn_norm,
        track_norm=opts.track_norm,
        gated=opts.gated,
        n_heads=opts.n_heads,
        tanh_clipping=opts.tanh_clipping,
        mask_inner=True,
        mask_logits=True,
        mask_graph=False,
        checkpoint_encoder=opts.checkpoint_encoder,
        shrink_size=opts.shrink_size
    ).to(opts.device), device_ids=[local_rank])
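    # DDP keeps the per-process replicas in sync by all-reducing gradients during the backward pass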

    #if opts.use_cuda and torch.cuda.device_count() > 1:
    #    model = torch.nn.DataParallel(model)

    ## Compute number of network parameters
    #print(model)
    #nb_param = 0
    #for param in model.parameters():
    #    nb_param += np.prod(list(param.data.size()))
    #print('Number of parameters: ', nb_param)

    # Overwrite model parameters by parameters to load (if a checkpoint was given)
    model_ = get_inner_model(model)
    model_.load_state_dict({**model_.state_dict(), **load_data.get('model', {})})

    # Initialize baseline
    if opts.baseline == 'exponential':
        baseline = ExponentialBaseline(opts.exp_beta)
    
    elif opts.baseline == 'critic' or opts.baseline == 'critic_lstm':
        assert problem.NAME == 'tsp', "Critic only supported for TSP"
        baseline = CriticBaseline(
            DDP(
                CriticNetwork(
                    embedding_dim=opts.embedding_dim,
                    encoder_class=encoder_class,
                    n_encode_layers=opts.n_encode_layers,
                    aggregation=opts.aggregation,
                    normalization=opts.normalization,
                    learn_norm=opts.learn_norm,
                    track_norm=opts.track_norm,
                    gated=opts.gated,
                    n_heads=opts.n_heads
                ).to(opts.device),
                device_ids=[local_rank]
            )
        )
        
        print(baseline.critic)
        nb_param = 0
        for param in baseline.get_learnable_parameters():
            nb_param += np.prod(list(param.data.size()))
        print('Number of parameters (BL): ', nb_param)
        
    elif opts.baseline == 'rollout':
        baseline = RolloutBaseline(model, problem, opts)
    
    else:
        assert opts.baseline is None, "Unknown baseline: {}".format(opts.baseline)
        baseline = NoBaseline()

    if opts.bl_warmup_epochs > 0:
        baseline = WarmupBaseline(baseline, opts.bl_warmup_epochs, warmup_exp_beta=opts.exp_beta)
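        # The warmup presumably blends an exponential moving-average baseline (warmup_exp_beta) into the chosen baseline over bl_warmup_epochs epochs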

    # Load baseline from data, make sure script is called with same type of baseline
    if 'baseline' in load_data:
        baseline.load_state_dict(load_data['baseline'])

    # Initialize optimizer
    optimizer = optim.Adam(
        [{'params': model.parameters(), 'lr': opts.lr_model}]
        + (
            [{'params': baseline.get_learnable_parameters(), 'lr': opts.lr_critic}]
            if len(baseline.get_learnable_parameters()) > 0
            else []
        )
    )
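    # Model and critic parameters are optimized with separate learning rates (lr_model / lr_critic)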

    # Load optimizer state
    if 'optimizer' in load_data:
        optimizer.load_state_dict(load_data['optimizer'])
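        # Move the loaded optimizer state tensors (e.g. Adam's moment estimates) onto this process's device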
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(opts.device)

    # Initialize learning rate scheduler, decay by lr_decay once per epoch!
    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: opts.lr_decay ** epoch)
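    # LambdaLR multiplies the base learning rate by lr_decay ** epoch, i.e. multiplicative decay each epoch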

    # Load/generate datasets
    val_datasets = []
    for val_filename in opts.val_datasets:
        val_datasets.append(
            problem.make_dataset(
                filename=val_filename, batch_size=opts.batch_size, num_samples=opts.val_size, 
                neighbors=opts.neighbors, knn_strat=opts.knn_strat, supervised=True, nar=False
            ))
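    # One dataset is built per validation file, so results can be tracked separately (e.g. per graph size)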

    if opts.resume:
        # Parse the epoch number from the checkpoint filename (e.g. "epoch-99.pt")
        epoch_resume = int(os.path.splitext(os.path.split(opts.resume)[-1])[0].split("-")[1])

        # Restore the random states
        torch.set_rng_state(load_data['rng_state'])
        if opts.use_cuda:
            torch.cuda.set_rng_state_all(load_data['cuda_rng_state'])
        # Dumping of state was done before the epoch callback, so do that now (model is loaded)
        baseline.epoch_callback(model, epoch_resume)
        print("Resuming after epoch {}".format(epoch_resume))
        opts.epoch_start = epoch_resume + 1

    # Start training loop
    for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs):
        ## Generate new training data for each epoch
        #train_dataset = baseline.wrap_dataset(
        #    problem.make_dataset(
        #        min_size=opts.min_size, max_size=opts.max_size, batch_size=opts.batch_size, 
        #        num_samples=opts.epoch_size, distribution=opts.data_distribution, 
        #        neighbors=opts.neighbors, knn_strat=opts.knn_strat
        #    ))
        #train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
        #                                                                num_replicas=opts.world_size,
        #                                                                rank=rank
        #)
        #train_dataloader = DataLoader(
        #    train_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=opts.num_workers,
        #      pin_memory=True,
        #      sampler=train_sampler)
    
        train_epoch(
            model,
            optimizer,
            baseline,
            lr_scheduler,
            epoch,
            #train_dataloader,
            val_datasets,
            problem,
            tb_logger,
            opts
        )
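        # Note: per-epoch training data generation (see the commented-out block above) is presumably handled inside train_epoch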