in bring-your-own-container/fairseq_translation/fairseq/train_driver.py [0:0]
import math

import torch

from fairseq import tasks, utils
from fairseq.meters import StopwatchMeter
from fairseq.trainer import Trainer

# Helper functions referenced below (load_dataset_splits, load_checkpoint,
# train, validate, save_checkpoint) are assumed to be defined elsewhere in
# this file, as in fairseq's train.py.


def main(args):
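    """Single-process GPU training driver: build the task, model, criterion and
    trainer from ``args``, then run the epoch loop until the learning-rate,
    epoch, or update limit is reached."""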
    if args.max_tokens is None:
        args.max_tokens = 6000
    print(args)

    if not torch.cuda.is_available():
        raise NotImplementedError("Training on CPU is not supported")
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load dataset splits
    load_dataset_splits(task, ["train", "valid"])

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print("| model {}, criterion {}".format(args.arch, criterion.__class__.__name__))
    print("| num. model params: {}".format(sum(p.numel() for p in model.parameters())))

    # Make a dummy batch to (i) warm the caching allocator and (ii) serve as a
    # placeholder for DistributedDataParallel when there's an uneven number of
    # batches per worker.
    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = task.dataset("train").get_dummy_batch(args.max_tokens, max_positions)

    # Build trainer
    trainer = Trainer(args, task, model, criterion, dummy_batch)
    print("| training on {} GPUs".format(args.distributed_world_size))
    print(
        "| max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens,
            args.max_sentences,
        )
    )

    # Initialize dataloader
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    )
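
    # Restoring a checkpoint also restores optimizer state and the position of
    # epoch_itr, so interrupted jobs resume where they left off; with no
    # checkpoint, a single dummy step warms up the caching allocator instead.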
    # Load the latest checkpoint if one is available
    if not load_checkpoint(args, trainer, epoch_itr):
        trainer.dummy_train_step([dummy_batch])

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(",")
    while (
        lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print("| done training in {:.1f} seconds".format(train_meter.sum))