# sagemaker-python-sdk/mxnet_horovod_maskrcnn/source/train_mask_rcnn.py
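# NOTE: excerpt of the full training script. A minimal sketch of the file-level
# imports this function relies on (the real script may organize them differently):
#   import os
#   import time
#   import mxnet as mx
#   from mxnet import gluon
#   from mxnet.contrib import amp
#   import horovod.mxnet as hvd
# plus the script's own helpers used below: ForwardBackwardTask, Parallel,
# split_and_load, get_lr_at_iter, validate, save_params, and the custom
# RPN/RCNN/mask metrics.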
def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline"""
    # force the "device" kvstore when AMP is enabled, since the nccl kvstore
    # is not used together with AMP here
    args.kv_store = "device" if (args.amp and "nccl" in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)
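    # Freeze everything, then re-enable gradients only for the trainable subset
    # (frozen BatchNorm, frozen first stage, etc.); biases get no weight decay.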
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    for k, v in net.collect_params(".*bias").items():
        v.wd_mult = 0.0
    optimizer_params = {
        "learning_rate": args.lr,
        "wd": args.wd,
        "momentum": args.momentum,
    }
    if args.clip_gradient > 0.0:
        optimizer_params["clip_gradient"] = args.clip_gradient
    if args.amp:
        optimizer_params["multi_precision"] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )
    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division
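    # Branch losses; HuberLoss with rho=1/9 (RPN) and rho=1 (RCNN) reproduces
    # the usual Smooth L1 box-regression losses.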
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1.0 / 9.0)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=1.0)  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
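    # Running averages of the five loss terms, used only for logging.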
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
        mx.metric.Loss("RCNN_Mask"),
    ]
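    # Extra diagnostics (accuracies and raw L1) updated from (label, pred) pairs.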
    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [
        rpn_acc_metric,
        rpn_bbox_metric,
        rcnn_acc_metric,
        rcnn_bbox_metric,
        rcnn_mask_metric,
        rcnn_fgmask_metric,
    ]
    async_eval_processes = []
    logger.info(args)
    if args.verbose:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]  # one-element list so callees can update it in place
    base_lr = trainer.learning_rate
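    # The ForwardBackwardTask (and, without Horovod, a Parallel executor that
    # runs one forward/backward per GPU) is rebuilt at the start of each epoch.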
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(
            net,
            trainer,
            rpn_cls_loss,
            rpn_box_loss,
            rcnn_cls_loss,
            rcnn_box_loss,
            rcnn_mask_loss,
            args.amp,
        )
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        net.hybridize()
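        # Step decay: scale the learning rate by lr_decay at each scheduled epoch.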
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
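        # Load the first batch eagerly so each subsequent batch can be
        # prefetched while the GPUs are busy with the current one.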
        train_data_iter = iter(train_data)
        next_data_batch = next(train_data_iter)
        next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
        for i in range(len(train_data)):
            batch = next_data_batch
            # Linear LR warmup; note the condition uses the global iteration
            # count while the fraction below restarts from i each epoch, so
            # warmup is assumed to finish within the first epoch.
            if i + epoch * len(train_data) <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(
                    i / lr_warmup, args.lr_warmup_factor / args.num_gpus
                )
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            "[Epoch {} Iteration {}] Set learning rate to {}".format(
                                epoch, i, new_lr
                            )
                        )
                    trainer.set_learning_rate(new_lr)
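            # Per-batch buffers: one list per loss metric and per diagnostic.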
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    # no executor (Horovod: one GPU per process), so run the
                    # single context's shard inline
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            try:
                # prefetch next batch
                next_data_batch = next(train_data_iter)
                next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
            except StopIteration:
                pass
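            # Fold this batch's losses into the running metrics, then apply one
            # optimizer step for the whole global batch.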
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            if (
                (not args.horovod or hvd.rank() == 0)
                and args.log_interval
                and not (i + 1) % args.log_interval
            ):
                msg = ",".join(
                    ["{}={:.3f}".format(*metric.get()) for metric in metrics + metrics2]
                )
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
                        epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg
                    )
                )
                btic = time.time()
        # validate and save params
        if (not args.horovod) or hvd.rank() == 0:
            msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics])
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}".format(epoch, (time.time() - tic), msg)
            )
        if not (epoch + 1) % args.val_interval:
            # consider reducing the frequency of validation to save time
            validate(
                net, val_data, async_eval_processes, ctx, eval_metric, logger, epoch, best_map, args
            )
        elif (not args.horovod) or hvd.rank() == 0:
            current_map = 0.0
            save_params(
                net,
                logger,
                best_map,
                current_map,
                epoch,
                args.save_interval,
                os.path.join(args.sm_save, args.save_prefix),
            )
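    # Wait for any asynchronous evaluation workers that validate() may have started.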
    for thread in async_eval_processes:
        thread.join()