in mdl.py [0:0]
import json
import pathlib
import random

import numpy as np
import torch

# Assumed fairseq imports; `train` and `validate` are the training and
# validation loops defined elsewhere in this module.
from fairseq import tasks, utils
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.trainer import Trainer


def main(args, init_distributed=False):
    utils.import_user_module(args)

    # Seed all random number generators for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Set up the task (the default should be translation)
    task = tasks.setup_task(args)
    # Build the model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)

    # Build the trainer
    trainer = Trainer(args, task, model, criterion)
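
    # Save the untrained model state so every step of the online code can
    # restart training from identical initial weights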
    initial_state_checkpoint = str(pathlib.Path(args.save_dir) / 'initial.pt')
    trainer.save_checkpoint(initial_state_checkpoint, {'epoch': 0})

    batches_per_epoch = args.mdl_batches_per_epoch
    batch_size = args.mdl_batch_size
    block_size = args.mdl_block_size

    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True)
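
    # Split the data into an initial training block of mdl_train_examples,
    # followed by fixed-size blocks of shuffled held-out examples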
    examples = list(range(len(epoch_itr.dataset)))
    train_examples = examples[:args.mdl_train_examples]
    test_examples = examples[args.mdl_train_examples:]
    random.shuffle(test_examples)

    blocks = [train_examples]
    blocks += [test_examples[i:i + block_size]
               for i in range(0, len(test_examples), block_size)]

    allowed_examples = []
    steps = len(blocks)
    block_cross_entropys = []
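
    # Online (prequential) code: at each step, retrain from the initial
    # weights on all blocks seen so far, then measure the cross-entropy of
    # the next, still-unseen block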
    for step in range(steps):
        trainer.load_checkpoint(initial_state_checkpoint, reset_optimizer=True,
                                reset_lr_scheduler=True)
        epoch_itr = trainer.get_train_iterator(epoch=step, load_dataset=False)
        allowed_examples += blocks[step]

        # If --mdl-batch-size is set, we sample batches with replacement;
        # otherwise, each batch contains all allowed_examples
        if batch_size:
            batches = tuple(random.choices(allowed_examples, k=batch_size)
                            for _ in range(batches_per_epoch))
        else:
            batches = tuple(allowed_examples for _ in range(batches_per_epoch))
        epoch_itr.frozen_batches = batches
        train(args, trainer, task, epoch_itr)
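
        # After training, estimate the codelength of the next block
        # (skipped on the final step, which has no next block)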
        if step < steps - 1:
            # Temporarily swap in a plain cross-entropy criterion so the
            # measured loss is a codelength, then restore the training criterion
            stashed_criterion = trainer.criterion
            trainer.criterion = CRITERION_REGISTRY['cross_entropy'](args, task)
            next_block = (blocks[step + 1],)
            next_block_cross_entropy = validate(args, trainer, task, epoch_itr,
                                                subsets=['train'],
                                                allowed_batches=next_block)
            trainer.criterion = stashed_criterion
            block_cross_entropys.append(next_block_cross_entropy)
        # Reset num_updates explicitly, since load_checkpoint does not
        # reliably reset it
        trainer.set_num_updates(0)
        state_checkpoint = str(pathlib.Path(args.save_dir) / f'{step}.pt')
        trainer.save_checkpoint(state_checkpoint, {'epoch': step})
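
    # Online description length: sum over held-out blocks of
    # (block size) * (mean cross-entropy of that block under the model
    # trained on all preceding blocks); the initial training block is
    # excluded, hence examples_seen[1:]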
    examples_seen = [len(b) for b in blocks]
    cross_entropy_sum = sum(
        n_examples * mean_cross_entropy
        for n_examples, mean_cross_entropy in zip(examples_seen[1:],
                                                  block_cross_entropys))
    stats = dict(online_cross_entropy=block_cross_entropys,
                 description_length=cross_entropy_sum,
                 examples_seen=examples_seen)
    print(json.dumps(stats))

    state_checkpoint = str(pathlib.Path(args.save_dir) / 'last.pt')
    trainer.save_checkpoint(state_checkpoint, {'epoch': step})
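
# A hypothetical invocation (flag names inferred from the args fields read
# above; the real names depend on how the script registers its options):
#
#   python mdl.py data-bin/mytask --save-dir checkpoints/mdl \
#       --mdl-train-examples 128 --mdl-block-size 128 \
#       --mdl-batch-size 64 --mdl-batches-per-epoch 100 --seed 1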