def train()

in archived/gluoncv_yolo_neo/train_yolo.py [0:0]


def train(net, train_data, val_data, eval_metric, ctx, args):

    import gluoncv as gcv

    gcv.utils.check_version("0.6.0")
    from gluoncv import data as gdata
    from gluoncv import utils as gutils
    from gluoncv.data.batchify import Pad, Stack, Tuple
    from gluoncv.data.dataloader import RandomTransformDataLoader
    from gluoncv.data.transforms.presets.yolo import (
        YOLO3DefaultTrainTransform,
        YOLO3DefaultValTransform,
    )
    from gluoncv.model_zoo import get_model
    from gluoncv.utils import LRScheduler, LRSequential
    from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
    from gluoncv.utils.metrics.voc_detection import VOC07MApMetric

    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params(".*beta|.*gamma|.*bias").items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(",")]
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_decay_epoch]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential(
        [
            LRScheduler(
                "linear",
                base_lr=0,
                target_lr=args.lr,
                nepochs=args.warmup_epochs,
                iters_per_epoch=num_batches,
            ),
            LRScheduler(
                args.lr_mode,
                base_lr=args.lr,
                nepochs=args.epochs - args.warmup_epochs,
                iters_per_epoch=num_batches,
                step_epoch=lr_decay_epoch,
                step_factor=args.lr_decay,
                power=2,
            ),
        ]
    )

    if args.horovod:
        hvd.broadcast_parameters(net.collect_params(), root_rank=0)
        trainer = hvd.DistributedTrainer(
            net.collect_params(),
            "sgd",
            {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler},
        )
    else:
        trainer = gluon.Trainer(
            net.collect_params(),
            "sgd",
            {"wd": args.wd, "momentum": args.momentum, "lr_scheduler": lr_scheduler},
            kvstore="local",
            update_on_kvstore=(False if args.amp else None),
        )

    if args.amp:
        amp.init_trainer(trainer)

    # targets
    sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    l1_loss = gluon.loss.L1Loss()

    # metrics
    obj_metrics = mx.metric.Loss("ObjLoss")
    center_metrics = mx.metric.Loss("BoxCenterLoss")
    scale_metrics = mx.metric.Loss("BoxScaleLoss")
    cls_metrics = mx.metric.Loss("ClassLoss")

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + "_train.log"
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info("Start training from [Epoch {}]".format(args.start_epoch))
    best_map = [0]
    for epoch in range(args.start_epoch, args.num_epochs):
        if args.mixup:
            # TODO(zhreshold): more elegant way to control mixup during runtime
            try:
                train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
            except AttributeError:
                train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
            if epoch >= args.num_epochs - args.no_mixup_epochs:
                try:
                    train_data._dataset.set_mixup(None)
                except AttributeError:
                    train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it], ctx_list=ctx, batch_axis=0)
                for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6], ctx_list=ctx, batch_axis=0)
            sum_losses = []
            obj_losses = []
            center_losses = []
            scale_losses = []
            cls_losses = []
            with autograd.record():
                for ix, x in enumerate(data):
                    obj_loss, center_loss, scale_loss, cls_loss = net(
                        x, gt_boxes[ix], *[ft[ix] for ft in fixed_targets]
                    )
                    sum_losses.append(obj_loss + center_loss + scale_loss + cls_loss)
                    obj_losses.append(obj_loss)
                    center_losses.append(center_loss)
                    scale_losses.append(scale_loss)
                    cls_losses.append(cls_loss)
                if args.amp:
                    with amp.scale_loss(sum_losses, trainer) as scaled_loss:
                        autograd.backward(scaled_loss)
                else:
                    autograd.backward(sum_losses)
            trainer.step(batch_size)
            if not args.horovod or hvd.rank() == 0:
                obj_metrics.update(0, obj_losses)
                center_metrics.update(0, center_losses)
                scale_metrics.update(0, scale_losses)
                cls_metrics.update(0, cls_losses)
                if args.log_interval and not (i + 1) % args.log_interval:
                    name1, loss1 = obj_metrics.get()
                    name2, loss2 = center_metrics.get()
                    name3, loss3 = scale_metrics.get()
                    name4, loss4 = cls_metrics.get()
                    logger.info(
                        "[Epoch {}][Batch {}], LR: {:.2E}, Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format(
                            epoch,
                            i,
                            trainer.learning_rate,
                            args.batch_size / (time.time() - btic),
                            name1,
                            loss1,
                            name2,
                            loss2,
                            name3,
                            loss3,
                            name4,
                            loss4,
                        )
                    )
                btic = time.time()

        if not args.horovod or hvd.rank() == 0:
            name1, loss1 = obj_metrics.get()
            name2, loss2 = center_metrics.get()
            name3, loss3 = scale_metrics.get()
            name4, loss4 = cls_metrics.get()
            logger.info(
                "[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}".format(
                    epoch,
                    (time.time() - tic),
                    name1,
                    loss1,
                    name2,
                    loss2,
                    name3,
                    loss3,
                    name4,
                    loss4,
                )
            )
            if not (epoch + 1) % args.val_interval:
                # consider reduce the frequency of validation to save time
                map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
                val_msg = "\n".join(["{}={}".format(k, v) for k, v in zip(map_name, mean_ap)])
                logger.info("[Epoch {}] Validation: \n{}".format(epoch, val_msg))
                current_map = float(mean_ap[-1])
            else:
                current_map = 0.0
            save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

    # save model
    net.set_nms(nms_thresh=0.45, nms_topk=400, post_nms=100)
    net(mx.nd.ones((1, 3, args.data_shape, args.data_shape), ctx=ctx[0]))
    net.export("%s/model" % os.environ["SM_MODEL_DIR"])