def train()

in sagemaker-python-sdk/mxnet_horovod_fasterrcnn/source/train_faster_rcnn.py


def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Training pipeline"""
    if args.amp and "nccl" in args.kv_store:
        args.kv_store = "device"
    kv = mx.kvstore.create(args.kv_store)
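    # Freeze every parameter, then re-enable gradients only for the trainable
    # subset returned by collect_train_params (frozen batchnorm, first stage,
    # etc. are excluded).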
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    optimizer_params = {"learning_rate": args.lr, "wd": args.wd, "momentum": args.momentum}
    if args.amp:
        optimizer_params["multi_precision"] = True
    if args.horovod:
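        # Broadcast rank 0's initial weights so every worker starts from
        # identical parameters; DistributedTrainer then averages gradients
        # across workers on each step.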
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )

    if args.amp:
        amp.init_trainer(trainer)

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    # TODO(zhreshold) losses?
    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
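    # gluon's HuberLoss(rho) equals smooth-L1 with the quadratic/linear switch
    # at rho; rho=1/9 reproduces the common RPN setting sigma=3 (1 / sigma^2).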
    rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.0)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(rho=1.0)  # == smoothl1
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

    logger.info(args)

    if args.verbose:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(
            net,
            trainer,
            rpn_cls_loss,
            rpn_box_loss,
            rcnn_cls_loss,
            rcnn_box_loss,
            mix_ratio=1.0,
            amp_enabled=args.amp,
        )
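        # Parallel runs forward/backward for each context in worker threads;
        # under Horovod each process drives a single GPU, so no executor is used.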
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        mix_ratio = 1.0
        net.hybridize()

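        # Step decay: scale the learning rate by lr_decay at every boundary in
        # lr_decay_epoch that this epoch has reached (boundaries are consumed).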
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        base_lr = trainer.learning_rate
        rcnn_task.mix_ratio = mix_ratio
        for i, batch in enumerate(train_data):
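            # Linear warmup over the first lr_warmup iterations of epoch 0,
            # ramping from a fraction of base_lr up to the full base_lr.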
            if epoch == 0 and i <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(
                    i / lr_warmup, args.lr_warmup_factor / args.num_gpus
                )
                if new_lr != trainer.learning_rate:
                    if i % args.log_interval == 0:
                        logger.info(
                            "[Epoch 0 Iteration {}] Set learning rate to {}".format(i, new_lr)
                        )
                    trainer.set_learning_rate(new_lr)
            batch = split_and_load(batch, ctx_list=ctx)
            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
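            # Dispatch one device slice per context to the executor (Horovod
            # runs forward/backward inline instead), then collect per-device
            # loss terms; only rank 0 accumulates them into the metrics.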
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for j in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if (not args.horovod) or hvd.rank() == 0:
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
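            # One optimizer update per batch; step(batch_size) rescales the
            # accumulated gradients by 1/batch_size.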
            trainer.step(batch_size)

            # update metrics
            if (
                (not args.horovod or hvd.rank() == 0)
                and args.log_interval
                and not (i + 1) % args.log_interval
            ):
                msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics + metrics2])
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
                        epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg
                    )
                )
                btic = time.time()

        if (not args.horovod) or hvd.rank() == 0:
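            # Only rank 0 logs epoch summaries, runs validation, and saves
            # checkpoints, to avoid duplicated work across Horovod workers.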
            msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics])
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}".format(epoch, (time.time() - tic), msg)
            )
            if not (epoch + 1) % args.val_interval:
                # consider reducing the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric, args)
                val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.0
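            # save_params tracks best_map in-place and writes a checkpoint
            # whenever validation mAP improves (and every save_interval epochs).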
            save_params(
                net,
                logger,
                best_map,
                current_map,
                epoch,
                args.save_interval,
                os.path.join(args.sm_save, args.save_prefix),
                args,
            )
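
For context, a minimal sketch of how train() might be invoked. This is an assumption-laden illustration, not the script's actual entry point: parse_args and build_dataloaders are hypothetical stand-ins for the hyperparameter parsing and data-pipeline setup that train_faster_rcnn.py performs before calling train().

import logging
import mxnet as mx
from gluoncv import model_zoo

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    args = parse_args()  # hypothetical: CLI/SageMaker hyperparameters
    ctx = [mx.gpu(i) for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]

    # A pretrained COCO Faster R-CNN from the GluonCV model zoo, moved to ctx.
    net = model_zoo.get_model("faster_rcnn_resnet50_v1b_coco", pretrained=True)
    net.collect_params().reset_ctx(ctx)

    # hypothetical: gluon DataLoaders plus a COCO-style eval metric
    train_data, val_data, eval_metric = build_dataloaders(args)

    train(net, train_data, val_data, eval_metric, args.batch_size, ctx, logger, args)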