def main() in mdl.py
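
# Imports assumed by this excerpt (they are not shown in the source file);
# they follow fairseq's layout, and train()/validate() are expected to be
# defined elsewhere in mdl.py, as in fairseq's train.py.
import json
import pathlib
import random

import numpy as np
import torch

from fairseq import tasks, utils
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.trainer import Trainer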


def main(args, init_distributed=False):
    utils.import_user_module(args)

    # Seed the RNGs for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Set up the task (should be the default, translation)
    task = tasks.setup_task(args)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
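
    # Save the randomly initialized weights; each online-code step below
    # restarts training from this same checkpoint.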
    initial_state_checkpoint = str(pathlib.Path(args.save_dir) / 'initial.pt')
    trainer.save_checkpoint(initial_state_checkpoint, {'epoch': 0})

    batches_per_epoch = args.mdl_batches_per_epoch  # training batches per online-code step
    batch_size = args.mdl_batch_size                # if set, batches are sampled with replacement
    block_size = args.mdl_block_size                # examples per transmitted block
    
    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True)

    examples = list(range(len(epoch_itr.dataset)))

    train_examples = examples[:args.mdl_train_examples]
    test_examples = examples[args.mdl_train_examples:]

    random.shuffle(test_examples)
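
    # Block 0 holds the first mdl_train_examples examples; the remaining
    # (shuffled) examples are split into transmission blocks of block_size.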
    blocks = [train_examples]
    blocks += [test_examples[i:i + block_size] for i in range(0, len(test_examples), block_size)]

    allowed_examples = []
    steps = len(blocks)
    block_cross_entropies = []
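    # Online (prequential) code: at step t the model is reset to its initial
    # weights, trained on blocks 0..t, and scored on block t + 1; the summed
    # per-block codelengths give the description length reported at the end.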

    for step in range(steps):
        trainer.load_checkpoint(initial_state_checkpoint, reset_optimizer=True, reset_lr_scheduler=True)

        epoch_itr = trainer.get_train_iterator(epoch=step, load_dataset=False)

        allowed_examples += blocks[step]

        # If mdl-batch-size is set, sample batches with replacement;
        # otherwise each batch contains all of allowed_examples.
        if batch_size:
            batches = tuple([random.choices(allowed_examples, k=batch_size) for _ in range(batches_per_epoch)])
        else:
            batches = tuple([allowed_examples for _ in range(batches_per_epoch)])

        epoch_itr.frozen_batches = batches
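        # fairseq's EpochBatchIterator draws its batches of example indices
        # from frozen_batches, so the assignment above restricts this epoch
        # to the allowed examples.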

        train(args, trainer, task, epoch_itr)

        if step < steps - 1:
            # Temporarily swap in a plain cross-entropy criterion to measure
            # the codelength of the next block, then restore the training
            # criterion.
            stashed_criterion = trainer.criterion
            trainer.criterion = CRITERION_REGISTRY['cross_entropy'](args, task)
            next_block = (blocks[step + 1],)
            next_block_cross_entropy = validate(args, trainer, task, epoch_itr,
                                                subsets=['train'],
                                                allowed_batches=next_block)
            trainer.criterion = stashed_criterion
            block_cross_entropies.append(next_block_cross_entropy)

        trainer.set_num_updates(0)  # reset num_updates, as load_checkpoint does not systematically update it
        state_checkpoint = str(pathlib.Path(args.save_dir) / f'{step}.pt')
        trainer.save_checkpoint(state_checkpoint, {'epoch': step})


    examples_seen = [len(b) for b in blocks]
    # Description length of the online code: each transmitted block costs its
    # size times the mean cross-entropy it received before being trained on.
    cross_entropy_sum = sum(n_examples * mean_cross_entropy
                            for n_examples, mean_cross_entropy
                            in zip(examples_seen[1:], block_cross_entropies))
    stats = dict(online_cross_entropy=block_cross_entropies,
                 description_length=cross_entropy_sum,
                 examples_seen=examples_seen)
    print(json.dumps(stats))
    
    state_checkpoint = str(pathlib.Path(args.save_dir) / 'last.pt')
    trainer.save_checkpoint(state_checkpoint, {'epoch': step})