def train()

in sagemaker-python-sdk/mxnet_horovod_maskrcnn/source/train_mask_rcnn.py [0:0]


def train(net, train_data, val_data, eval_metric, batch_size, ctx, logger, args):
    """Run the Mask R-CNN training loop with periodic validation/checkpointing.

    Builds an SGD trainer (Horovod ``DistributedTrainer`` when
    ``args.horovod`` is set, otherwise a ``gluon.Trainer`` backed by a
    kvstore), then for every epoch in ``[args.start_epoch, args.epochs)``:
    applies step-wise learning-rate decay, applies linear learning-rate
    warmup, runs forward/backward on every batch, updates the loss/accuracy
    metrics, logs progress, and finally either validates or saves parameters.

    Args:
        net: Network to train; must provide ``collect_params``,
            ``collect_train_params`` and ``hybridize``.
        train_data: Sized, re-iterable loader of training batches.
        val_data: Validation loader, forwarded unchanged to ``validate``.
        eval_metric: Evaluation metric, forwarded unchanged to ``validate``.
        batch_size: Value passed to ``trainer.step`` on every iteration.
        ctx: List of device contexts each batch is split across.
        logger: Logger receiving progress messages.
        args: Parsed options. Fields read here: ``kv_store``, ``amp``,
            ``lr``, ``wd``, ``momentum``, ``clip_gradient``, ``horovod``,
            ``lr_decay``, ``lr_decay_epoch``, ``lr_warmup``,
            ``lr_warmup_factor``, ``num_gpus``, ``verbose``, ``start_epoch``,
            ``epochs``, ``executor_threads``, ``log_interval``,
            ``batch_size``, ``val_interval``, ``save_interval``, ``sm_save``
            and ``save_prefix``. ``args.kv_store`` is overwritten in place.

    Returns:
        None.
    """
    # With AMP enabled, an NCCL kvstore is replaced by the "device" kvstore.
    args.kv_store = "device" if (args.amp and "nccl" in args.kv_store) else args.kv_store
    kv = mx.kvstore.create(args.kv_store)

    # Freeze everything first, then re-enable gradients for trainable params.
    net.collect_params().setattr("grad_req", "null")
    net.collect_train_params().setattr("grad_req", "write")
    # Bias parameters are excluded from weight decay.
    for param in net.collect_params(".*bias").values():
        param.wd_mult = 0.0

    optimizer_params = {
        "learning_rate": args.lr,
        "wd": args.wd,
        "momentum": args.momentum,
    }
    if args.clip_gradient > 0.0:
        optimizer_params["clip_gradient"] = args.clip_gradient
    if args.amp:
        optimizer_params["multi_precision"] = True
    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
        )
    else:
        trainer = gluon.Trainer(
            net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
            "sgd",
            optimizer_params,
            update_on_kvstore=(False if args.amp else None),
            kvstore=kv,
        )

    if args.amp:
        amp.init_trainer(trainer)

    # Only this process logs / records metrics / saves when running under
    # Horovod; without Horovod the single process always does.
    is_master = (not args.horovod) or hvd.rank() == 0

    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(",") if ls.strip()])
    lr_warmup = float(args.lr_warmup)  # avoid int division

    rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    rpn_box_loss = mx.gluon.loss.HuberLoss(1.0 / 9.0)  # == smoothl1
    rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
    rcnn_box_loss = mx.gluon.loss.HuberLoss(1.0)  # == smoothl1
    rcnn_mask_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    metrics = [
        mx.metric.Loss("RPN_Conf"),
        mx.metric.Loss("RPN_SmoothL1"),
        mx.metric.Loss("RCNN_CrossEntropy"),
        mx.metric.Loss("RCNN_SmoothL1"),
        mx.metric.Loss("RCNN_Mask"),
    ]

    rpn_acc_metric = RPNAccMetric()
    rpn_bbox_metric = RPNL1LossMetric()
    rcnn_acc_metric = RCNNAccMetric()
    rcnn_bbox_metric = RCNNL1LossMetric()
    rcnn_mask_metric = MaskAccMetric()
    rcnn_fgmask_metric = MaskFGAccMetric()
    metrics2 = [
        rpn_acc_metric,
        rpn_bbox_metric,
        rcnn_acc_metric,
        rcnn_bbox_metric,
        rcnn_mask_metric,
        rcnn_fgmask_metric,
    ]
    async_eval_processes = []
    logger.info(args)

    if args.verbose:
        logger.info("Trainable parameters:")
        logger.info(net.collect_train_params().keys())
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    base_lr = trainer.learning_rate
    for epoch in range(args.start_epoch, args.epochs):
        rcnn_task = ForwardBackwardTask(
            net,
            trainer,
            rpn_cls_loss,
            rpn_box_loss,
            rcnn_cls_loss,
            rcnn_box_loss,
            rcnn_mask_loss,
            args.amp,
        )
        executor = Parallel(args.executor_threads, rcnn_task) if not args.horovod else None
        net.hybridize()

        # Apply every decay step whose epoch threshold has been reached.
        while lr_steps and epoch >= lr_steps[0]:
            new_lr = trainer.learning_rate * lr_decay
            lr_steps.pop(0)
            trainer.set_learning_rate(new_lr)
            logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))

        # NOTE(review): only the loss metrics are reset per epoch; the
        # metrics in `metrics2` keep accumulating across epochs. Confirm
        # whether that is intended.
        for metric in metrics:
            metric.reset()
        tic = time.time()
        btic = time.time()
        num_batches = len(train_data)
        train_data_iter = iter(train_data)
        next_data_batch = next(train_data_iter)
        next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
        for i in range(num_batches):
            batch = next_data_batch

            # Warmup is driven by the global iteration count so that it keeps
            # progressing (instead of restarting) when it spans several
            # epochs. A non-positive `lr_warmup` disables warmup and avoids a
            # division by zero.
            global_iter = i + epoch * num_batches
            if lr_warmup > 0 and global_iter <= lr_warmup:
                # adjust based on real percentage
                new_lr = base_lr * get_lr_at_iter(
                    global_iter / lr_warmup, args.lr_warmup_factor / args.num_gpus
                )
                if new_lr != trainer.learning_rate:
                    # Guard the modulo the same way the progress log below
                    # does, so a log interval of 0 disables logging instead
                    # of raising.
                    if args.log_interval and i % args.log_interval == 0:
                        logger.info(
                            "[Epoch {} Iteration {}] Set learning rate to {}".format(
                                epoch, i, new_lr
                            )
                        )
                    trainer.set_learning_rate(new_lr)

            metric_losses = [[] for _ in metrics]
            add_losses = [[] for _ in metrics2]
            if executor is not None:
                for data in zip(*batch):
                    executor.put(data)
            for _ in range(len(ctx)):
                if executor is not None:
                    result = executor.get()
                else:
                    result = rcnn_task.forward_backward(list(zip(*batch))[0])
                if is_master:
                    # `result` holds the loss entries first, followed by the
                    # entries consumed by `metrics2`.
                    for k in range(len(metric_losses)):
                        metric_losses[k].append(result[k])
                    for k in range(len(add_losses)):
                        add_losses[k].append(result[len(metric_losses) + k])
            try:
                # prefetch next batch
                next_data_batch = next(train_data_iter)
                next_data_batch = split_and_load(next_data_batch, ctx_list=ctx)
            except StopIteration:
                # Last batch of the epoch: nothing left to prefetch.
                pass

            for metric, record in zip(metrics, metric_losses):
                metric.update(0, record)
            for metric, records in zip(metrics2, add_losses):
                for pred in records:
                    metric.update(pred[0], pred[1])
            trainer.step(batch_size)
            if is_master and args.log_interval and not (i + 1) % args.log_interval:
                msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics + metrics2])
                logger.info(
                    "[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}".format(
                        epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg
                    )
                )
                btic = time.time()

        # validate and save params
        if is_master:
            msg = ",".join(["{}={:.3f}".format(*metric.get()) for metric in metrics])
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}".format(epoch, (time.time() - tic), msg)
            )
        if not (epoch + 1) % args.val_interval:
            # consider reduce the frequency of validation to save time
            validate(
                net, val_data, async_eval_processes, ctx, eval_metric, logger, epoch, best_map, args
            )
        elif is_master:
            # No validation this epoch, so a placeholder mAP of 0.0 is passed.
            current_map = 0.0
            save_params(
                net,
                logger,
                best_map,
                current_map,
                epoch,
                args.save_interval,
                os.path.join(args.sm_save, args.save_prefix),
            )

    # Wait for any evaluation workers registered in `async_eval_processes`.
    for thread in async_eval_processes:
        thread.join()