def main()

in main.py

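# Module-level imports assumed by this excerpt; project helpers such as
# configs, train_fns, MyLogger, get_dataset, and init_model are imported from
# the surrounding package (their exact module paths are not shown here).
import os
import shutil
import time

import torch
import torch.distributed as dist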

def main(config):

    # Set up MultiGPU training
    if config['multigpu']:
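        # Derive the global rank, world size, and per-node local rank from the
        # SLURM environment (one process per GPU is assumed).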
        rank = int(os.environ.get("SLURM_PROCID"))
        n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES"))
        world_size = int(os.environ.get("SLURM_NTASKS_PER_NODE"))*n_nodes
        local_rank = int(os.environ.get("SLURM_LOCALID"))
        print('LOCAL RANK {}'.format(local_rank))

        dist.init_process_group(
            backend=config['dist_backend'],
            init_method=config['dist_url'],
            world_size=world_size,
            rank=rank
        )

        config['rank'] = rank
        config['local_rank'] = local_rank

    # Load configuration and set up experiment
    if not config['multigpu'] or rank == 0:
        # config['out_dir'] = os.path.join(config['out_dir'], config['exp_name'])
        os.makedirs(config['out_dir'], exist_ok=True)
        os.makedirs(os.path.join(config['out_dir'], 'checkpoints'), exist_ok=True)
        config_path = os.path.join(config['out_dir'], 'config.json')
        # Start each run from a clean log directory
        log_dir = os.path.join(config['out_dir'], 'logs')
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
        log = MyLogger(log_dir)
        print('Output directory prepared')
    else:
        log = None

    # Set device for the model
    if config['multigpu']:
        config['device'] = 'cuda:{}'.format(local_rank)
    else:
        config['device'] = 'cuda' if config['gpu'] else 'cpu'

    # Load dataset
    train_loader, val_loader = get_dataset(config)
    if not config['multigpu'] or rank == 0:
        print('Dataset loaded')

    # Load model
    model = init_model(config)
    if not config['multigpu'] or rank == 0:
        print(model)
        print('Model loaded')

    # Load optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config['lr'],
        weight_decay=1e-4,
    )
    if not config['multigpu'] or rank == 0:
        print('Optimizer initialized')


    # Set up apex and distributed training
    if config['apex']:
        from apex import amp, parallel
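        # opt_level 'O1' runs mixed-precision training with dynamic loss scaling;
        # apex's DistributedDataParallel takes no device_ids argument, so the
        # current CUDA device is set with the context manager below before wrapping.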
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

        if config['multigpu']:
            with torch.cuda.device(config['local_rank']):
                model = parallel.DistributedDataParallel(model)
            
    else:
        if config['multigpu']:
            model = torch.nn.parallel.DistributedDataParallel(
                model, 
                device_ids=[config['local_rank']], 
                output_device=config['local_rank'],
            )

    # Save updated config
    if not config['multigpu'] or rank == 0:
        configs.save_config(config_path, config)

    # TODO: move this to train functions
    def generate_samples(frames, split):
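        # Logs one mean-based rollout plus five stochastic rollouts, concatenated
        # with the target frames along the last dimension, followed by a straight
        # reconstruction of the input frames.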
        all_preds = []
        (preds, targets), _ = train_fns.sample_step(model, config, frames, use_mean=True)
        all_preds.append(preds)
        for idx in range(5):
            (preds, targets), _ = train_fns.sample_step(model, config, frames)
            all_preds.append(preds)
        preds = torch.cat(all_preds, -1)
        video = torch.cat([targets, preds], -1)
        log.video('{}_sample'.format(split), video)

        (preds, targets), _ = train_fns.reconstruction_step(model, config, frames)
        video = torch.cat([targets, preds], -1)
        log.video('{}_reconstruction'.format(split), video)


    # Main loop
    abs_batch_idx = 1
    beta2 = config['beta']
    for epoch_idx in range(config['max_epochs']):

        t1 = time.time()

        # Train iterations
        model.train()
        for batch_idx, train_batch in enumerate(train_loader):

            # if batch_idx == 1:
            #     break

            # Change learning rate
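            # Linear-scaling schedule: the target learning rate grows with the
            # global batch size (batch_size * world_size / 16) and is reached by
            # linear warmup over roughly five epochs' worth of iterations.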
            if config['multigpu']:

                warmup_iters = config['batches_per_epoch']*5/world_size
                lr1 = config['lr']
                lr2 = config['lr']*config['batch_size']*world_size/16
                warmup_step = (lr2 - lr1) / warmup_iters

                if abs_batch_idx < warmup_iters:
                    # Linear warmup from the base to the scaled learning rate
                    lr = abs_batch_idx*warmup_step + lr1
                else:
                    # Hold the scaled learning rate after warmup
                    lr = lr2

            else:
                lr = config['lr']

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # Warmup beta: linearly anneal from 0 to its configured value
            # over bn_epochs epochs' worth of iterations
            if config['beta_wu']:
                if not config['multigpu']:
                    world_size = 1
                beta1 = 0
                bn_epochs = 20
                bwarmup_iters = config['batches_per_epoch']*bn_epochs/world_size
                bwarmup_step = (beta2 - beta1)/bwarmup_iters

                if abs_batch_idx < bwarmup_iters:
                    config['beta'] = abs_batch_idx*bwarmup_step + beta1
                else:
                    config['beta'] = beta2

            t2 = time.time()

            frames, example_idxs = train_fns.prepare_batch(train_batch, config)
            train_fns.train_step(
                model,
                config, 
                frames, 
                optimizer,
                batch_idx,
                log
            )

            t3 = time.time()

            # print('TIME {:.4f}/{:.4f}'.format(t2 -t1, t3 - t2))
            t1 = time.time()

            abs_batch_idx += 1

            if not config['multigpu'] or rank == 0:
                log.increase_train()

                if (batch_idx + 1) % config['log_freq'] == 0:
                    log.dump_scalars('train')
                    train_fns.train_print_status(log)

        if not config['multigpu'] or rank == 0:
            log.dump_scalars('train')
            train_fns.train_print_status(log)

        # Validation iterations
        model.eval()
        for batch_idx, val_batch in enumerate(val_loader):

            if batch_idx == config['test_batches']:
                break

            frames, example_idxs = train_fns.prepare_batch(val_batch, config)

            train_fns.test_step(
                model,
                config, 
                frames,
                log
            )

            if not config['multigpu'] or rank == 0:
                log.increase_test()

        # Bookkeeping
        if not config['multigpu'] or rank == 0:

            # Train reconstruction and samples
            frames, example_idxs = train_fns.prepare_batch(train_batch, config)
            generate_samples(frames, 'train')
            log.print('Train samples saved')

            # Test reconstructions and samples
            frames, example_idxs = train_fns.prepare_batch(val_batch, config)
            generate_samples(frames, 'val')
            log.print('Test samples saved')

            # Print information and increase epoch
            log.dump_scalars('test')
            train_fns.test_print_status(log)
            log.increase_epoch()

            # Save model 
            if (epoch_idx + 1) % config['save_freq'] == 0:
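                # In multi-GPU mode the model is DDP-wrapped, so the saved
                # state_dict keys carry a 'module.' prefix.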
                ckpt_path = os.path.join(
                    config['out_dir'], 'checkpoints', '{:0>5d}.pth'.format(epoch_idx + 1)
                )
                torch.save(model.state_dict(), ckpt_path)
                log.print('Model saved')