def train_epoch()

in src/train.py


def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_datasets, problem, tb_logger, opts):
    print("\nStart train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
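    # Global step counter at the start of this epoch; incremented once per batch below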
    step = epoch * (opts.epoch_size // opts.batch_size)
    start_time = time.time()

    if not opts.no_tensorboard:
        tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)
    
    # Generate new training data for each epoch
    train_dataset = baseline.wrap_dataset(
        problem.make_dataset(
            min_size=opts.min_size, max_size=opts.max_size, batch_size=opts.batch_size, 
            num_samples=opts.epoch_size, distribution=opts.data_distribution, 
            neighbors=opts.neighbors, knn_strat=opts.knn_strat
        ))
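    # No shuffling needed: the dataset above is freshly sampled every epoch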
    train_dataloader = DataLoader(
        train_dataset, batch_size=opts.batch_size, shuffle=False, num_workers=opts.num_workers)

    # Put model in train mode!
    model.train()
    optimizer.zero_grad()
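    # Decode stochastically ("sampling") so training sees exploratory rollouts;
    # greedy decoding is typically reserved for evaluation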
    set_decode_type(model, "sampling")

    for batch_id, batch in enumerate(train_dataloader):
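        # train_batch (defined elsewhere in src/train.py) is expected to run the forward
        # pass, compute the baseline-corrected policy-gradient loss and update the model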

        train_batch(
            model,
            optimizer,
            baseline,
            epoch,
            batch_id,
            step,
            batch,
            tb_logger,
            opts
        )

        step += 1
    
    # Step the learning-rate scheduler once per epoch
    # (passing the epoch explicitly is deprecated in recent PyTorch releases)
    lr_scheduler.step(epoch)

    epoch_duration = time.time() - start_time
    print("Finished epoch {}, took {} (h:mm:ss)".format(epoch, time.strftime('%H:%M:%S', time.gmtime(epoch_duration))))

    if opts.rank == 0:
        # Per-epoch checkpointing is disabled; only the final model is saved.
        if epoch == opts.n_epochs - 1:
            print('Saving model and state...')
            torch.save(
                {
                    'model': get_inner_model(model).state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state_all(),
                    # 'baseline': baseline.state_dict()  # baseline state is not checkpointed
                },
                os.path.join(opts.model_dir, 'model.pt')
            )

        # Validate on each held-out dataset and log the average reward and optimality gap
        for val_idx, val_dataset in enumerate(val_datasets):
            avg_reward, avg_opt_gap = validate(model, val_dataset, problem, opts)
            if not opts.no_tensorboard:
                tb_logger.log_value('val{}/avg_reward'.format(val_idx+1), avg_reward, step)
                tb_logger.log_value('val{}/opt_gap'.format(val_idx+1), avg_opt_gap, step)
    
    # Let the baseline update itself at the end of the epoch
    # (e.g. a rollout baseline may re-evaluate and replace its frozen policy here)
    baseline.epoch_callback(model, epoch)
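
For context, a minimal sketch of how train_epoch might be driven from an outer training loop. The run() wrapper name is illustrative and not part of the original source; it assumes the model, optimizer, baseline, LR scheduler, validation datasets, problem and tb_logger have already been constructed (with module-level imports of time, os, torch and DataLoader as used above), and that epochs are numbered 0 to opts.n_epochs - 1, as the final-epoch save condition suggests.

def run(model, optimizer, baseline, lr_scheduler, val_datasets, problem, tb_logger, opts):
    # Illustrative outer loop: one call to train_epoch per epoch
    for epoch in range(opts.n_epochs):
        train_epoch(
            model, optimizer, baseline, lr_scheduler, epoch,
            val_datasets, problem, tb_logger, opts
        )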