in main_train.py
def main(opts, **kwargs):
    num_gpus = getattr(opts, "dev.num_gpus", 0)  # defaults are for CPU
    dev_id = getattr(opts, "dev.device_id", torch.device('cpu'))
    device = getattr(opts, "dev.device", torch.device('cpu'))
    is_distributed = getattr(opts, "ddp.use_distributed", False)

    is_master_node = is_master(opts)

    # set-up data loaders
    train_loader, val_loader, train_sampler = create_train_val_loader(opts)

    # compute max iterations based on max epochs
    # useful for polynomial decay schedules
    is_iteration_based = getattr(opts, "scheduler.is_iteration_based", False)
    if is_iteration_based:
        max_iter = getattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS)
        if max_iter is None or max_iter <= 0:
            logger.log('Setting max. iterations to {}'.format(DEFAULT_ITERATIONS))
            setattr(opts, "scheduler.max_iterations", DEFAULT_ITERATIONS)
            max_iter = DEFAULT_ITERATIONS
        setattr(opts, "scheduler.max_epochs", DEFAULT_MAX_EPOCHS)
        if is_master_node:
            logger.log('Max. iterations for training: {}'.format(max_iter))
    else:
        max_epochs = getattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
        if max_epochs is None or max_epochs <= 0:
            logger.log('Setting max. epochs to {}'.format(DEFAULT_EPOCHS))
            setattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
        setattr(opts, "scheduler.max_iterations", DEFAULT_MAX_ITERATIONS)
        max_epochs = getattr(opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
        if is_master_node:
            logger.log('Max. epochs for training: {}'.format(max_epochs))
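
    # A polynomial decay scheduler typically anneals the learning rate over
    # max_iterations, e.g. (the exact form depends on the scheduler that
    # build_scheduler returns):
    #     lr_t = (base_lr - final_lr) * (1 - t / max_iterations) ** power + final_lr
    # which is why valid max_iterations / max_epochs values are resolved above.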

    # set-up the model
    model = get_model(opts)

    if num_gpus == 0:
        logger.error('Need at least 1 GPU for training. Got {} GPUs'.format(num_gpus))
    elif num_gpus == 1:
        model = model.to(device=device)
    elif is_distributed:
        model = model.to(device=device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
        if is_master_node:
            logger.log('Using DistributedDataParallel for training')
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device=device)
        if is_master_node:
            logger.log('Using DataParallel for training')

    # set up the loss criteria
    criteria = build_loss_fn(opts)
    criteria = criteria.to(device=device)

    # create the optimizer
    optimizer = build_optimizer(model, opts=opts)

    # create the gradient scaler for mixed-precision training
    gradient_scalar = GradScaler(
        enabled=getattr(opts, "common.mixed_precision", False)
    )
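
    # When mixed precision is disabled, the scaler is effectively a no-op. Inside
    # the training loop (which lives in the Trainer below), it is typically used as:
    #     with torch.cuda.amp.autocast(enabled=mixed_precision):
    #         loss = criteria(model(inputs), targets)
    #     gradient_scalar.scale(loss).backward()
    #     gradient_scalar.step(optimizer)
    #     gradient_scalar.update()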

    # LR scheduler
    scheduler = build_scheduler(opts=opts)

    model_ema = None
    use_ema = getattr(opts, "ema.enable", False)

    if use_ema:
        ema_momentum = getattr(opts, "ema.momentum", 0.0001)
        model_ema = EMA(
            model=model,
            ema_momentum=ema_momentum,
            device=device
        )
        if is_master_node:
            logger.log('Using EMA')

    best_metric = 0.0 if getattr(opts, "stats.checkpoint_metric_max", False) else math.inf

    start_epoch = 0
    start_iteration = 0
    resume_loc = getattr(opts, "common.resume", None)
    finetune_loc = getattr(opts, "common.finetune", None)
    auto_resume = getattr(opts, "common.auto_resume", False)
    if resume_loc is not None or auto_resume:
        model, optimizer, gradient_scalar, start_epoch, start_iteration, best_metric, model_ema = load_checkpoint(
            opts=opts,
            model=model,
            optimizer=optimizer,
            model_ema=model_ema,
            gradient_scalar=gradient_scalar
        )
    elif finetune_loc is not None:
        model, model_ema = load_model_state(opts=opts, model=model, model_ema=model_ema)
        if is_master_node:
            logger.log('Finetuning model from checkpoint {}'.format(finetune_loc))

    training_engine = Trainer(
        opts=opts,
        model=model,
        validation_loader=val_loader,
        training_loader=train_loader,
        optimizer=optimizer,
        criterion=criteria,
        scheduler=scheduler,
        start_epoch=start_epoch,
        start_iteration=start_iteration,
        best_metric=best_metric,
        model_ema=model_ema,
        gradient_scalar=gradient_scalar
    )

    training_engine.run(train_sampler=train_sampler)
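
# Minimal launch sketch. `get_training_arguments` is an assumed argument-parsing
# helper (any function that returns a populated `opts` namespace works); for DDP,
# the launcher is expected to set dev.device_id / ddp.use_distributed per process
# before calling main():
#     opts = get_training_arguments()
#     main(opts=opts)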