in fairnr_cli/train.py [0:0]
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    metrics.reset()

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)
    task.setup_trainer(trainer)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (
        lr > args.min_lr
        and epoch_itr.next_epoch_idx <= max_epoch
    ):
        # train for one epoch
        should_end_training = train(args, trainer, task, epoch_itr)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )

        if should_end_training:
            break

    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
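For context, here is a minimal sketch of how main() could be driven from the command line, assuming the standard fairseq options API (options.get_training_parser and options.parse_args_and_arch). The file's actual entry point is not shown in this excerpt and would typically also handle multi-GPU/distributed launching; this hypothetical cli_main covers only the single-process case.

from fairseq import options

def cli_main():
    # Build the standard fairseq training argument parser
    parser = options.get_training_parser()
    # Parse CLI arguments and resolve architecture-specific defaults (--arch)
    args = options.parse_args_and_arch(parser)
    # Single-process, non-distributed run of the training loop above
    main(args)

if __name__ == '__main__':
    cli_main()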