def main()

in cli/jobs/pipelines-with-components/image_classification_with_densenet/image_cnn_train/main.py


def main(gpu_index, args):
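    # Optionally redirect this process's stdout to a per-rank, timestamped log file.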
    if args.log_redirect:
        sys.stdout = open(
            "./outputs_"
            + str(args.rank * args.ngpus_per_node + gpu_index)
            + str(time.time()),
            "w",
        )

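    # exp_start_time times the whole run; best_prec1 tracks the best top-1 validation score seen so far.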
    exp_start_time = time.time()
    global best_prec1
    best_prec1 = 0

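    # Distributed setup: one process per GPU with the NCCL backend; enabled when world_size > 1.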
    args.distributed = False

    args.gpu = 0

    args.local_rank = gpu_index
    args.distributed = args.world_size > 1
    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        print("using gpu ", args.gpu)
        torch.cuda.set_device(args.gpu)

        args.rank = args.rank * args.ngpus_per_node + gpu_index
        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    if args.amp and args.fp16:
        print("Please use only one of the --fp16/--amp flags")
        sys.exit(1)

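    # Seed Python, NumPy, and torch (CPU and CUDA) RNGs per local rank so runs are reproducible
    # while each process still gets a distinct random stream.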
    if args.seed is not None:
        print("Using seed = {}".format(args.seed))
        torch.manual_seed(args.seed + args.local_rank)
        torch.cuda.manual_seed(args.seed + args.local_rank)
        np.random.seed(seed=args.seed + args.local_rank)
        random.seed(args.seed + args.local_rank)

        def _worker_init_fn(id):
            np.random.seed(seed=args.seed + args.local_rank + id)
            random.seed(args.seed + args.local_rank + id)

    else:

        def _worker_init_fn(id):
            pass

    if args.fp16:
        assert (
            torch.backends.cudnn.enabled
        ), "fp16 mode requires cudnn backend to be enabled."

    if args.static_loss_scale != 1.0:
        if not args.fp16:
            print("Warning:  if --fp16 is not used, static_loss_scale will be ignored.")

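    # Work out how many mini-batches are accumulated per optimizer step so the simulated
    # (optimizer) batch size matches --optimizer-batch-size.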
    if args.optimizer_batch_size < 0:
        batch_size_multiplier = 1
    else:
        tbs = args.world_size * args.batch_size
        if args.optimizer_batch_size % tbs != 0:
            print(
                "Warning: simulated batch size {} is not divisible by actual batch size {}".format(
                    args.optimizer_batch_size, tbs
                )
            )
        batch_size_multiplier = int(round(args.optimizer_batch_size / tbs))
        print("BSM: {}".format(batch_size_multiplier))
        print("Real effective batch size is: ", batch_size_multiplier * tbs)

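    # Optionally load pretrained weights from a local file.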
    pretrained_weights = None
    if args.pretrained_weights:
        if os.path.isfile(args.pretrained_weights):
            print(
                "=> loading pretrained weights from '{}'".format(
                    args.pretrained_weights
                )
            )
            pretrained_weights = torch.load(args.pretrained_weights)
        else:
            print("=> no pretrained weights found at '{}'".format(args.resume))

    start_epoch = 0
    args.total_train_step = 0
    # check previous saved checkpoint first
    # if there is none, then resume from user specified checkpoint if there is
    target_ckpt_path = args.workspace + "/" + checkpoint_file_name
    ckpt_path = target_ckpt_path
    if not os.path.isfile(ckpt_path):
        print("=> no checkpoint found at '{}'".format(ckpt_path))
        ckpt_path = args.resume

    # optionally resume from a checkpoint
    if ckpt_path:
        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(
                ckpt_path, map_location=lambda storage, loc: storage.cuda(args.gpu)
            )
            start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model_state = checkpoint["state_dict"]
            optimizer_state = checkpoint["optimizer"]
            args.total_train_step = checkpoint["total_train_step"]
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    ckpt_path, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(ckpt_path))
            model_state = None
            optimizer_state = None
    else:
        model_state = None
        optimizer_state = None

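    # Choose the loss: plain cross-entropy by default, or a smoothed variant when mixup or
    # label smoothing is enabled.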
    loss = nn.CrossEntropyLoss
    if args.mixup > 0.0:
        loss = lambda: NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        loss = lambda: LabelSmoothing(args.label_smoothing)

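    # Wrap the model and loss together (placed on CUDA, optionally converted to fp16).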
    model_and_loss = ModelAndLoss(
        (args.arch, args.model_config),
        loss,
        pretrained_weights=pretrained_weights,
        cuda=True,
        fp16=args.fp16,
    )

    # Create data loaders and optimizers as needed
    if args.data_backend == "pytorch":
        get_train_loader = get_pytorch_train_loader
        get_val_loader = get_pytorch_val_loader
    elif args.data_backend == "dali-gpu":
        get_train_loader = get_dali_train_loader(dali_cpu=False)
        get_val_loader = get_dali_val_loader()
    elif args.data_backend == "dali-cpu":
        get_train_loader = get_dali_train_loader(dali_cpu=True)
        get_val_loader = get_dali_val_loader()
    elif args.data_backend == "syntetic":
        get_val_loader = get_syntetic_loader
        get_train_loader = get_syntetic_loader

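    # Build the training and validation loaders; the 1000 passed here is the class count
    # (presumably the 1000 ImageNet classes). Mixup wraps the train loader when enabled.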
    train_loader, train_loader_len = get_train_loader(
        args.data,
        args.batch_size,
        1000,
        args.mixup > 0.0,
        workers=args.workers,
        fp16=args.fp16,
    )
    if args.mixup != 0.0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader)

    val_loader, val_loader_len = get_val_loader(
        args.data, args.batch_size, 1000, False, workers=args.workers, fp16=args.fp16
    )

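    # Build the optimizer (restoring its state when resuming), with optional Nesterov momentum,
    # separate BN weight-decay handling, and fp16 loss scaling.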
    optimizer = get_optimizer(
        list(model_and_loss.model.named_parameters()),
        args.fp16,
        args.lr,
        args.momentum,
        args.weight_decay,
        nesterov=args.nesterov,
        bn_weight_decay=args.bn_weight_decay,
        state=optimizer_state,
        static_loss_scale=args.static_loss_scale,
        dynamic_loss_scale=args.dynamic_loss_scale,
    )

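    # Learning-rate policy: step, cosine, or linear decay, each with a warmup phase.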
    if args.lr_schedule == "step":
        lr_policy = lr_step_policy(args.lr, [30, 60, 80], 0.1, args.warmup, logger=None)
    elif args.lr_schedule == "cosine":
        lr_policy = lr_cosine_policy(args.lr, args.warmup, args.epochs, logger=None)
    elif args.lr_schedule == "linear":
        lr_policy = lr_linear_policy(args.lr, args.warmup, args.epochs, logger=None)

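    # Mixed-precision training via the (Apex-style) amp.initialize API at opt_level O2.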
    if args.amp:
        model_and_loss, optimizer = amp.initialize(
            model_and_loss,
            optimizer,
            opt_level="O2",
            loss_scale="dynamic" if args.dynamic_loss_scale else args.static_loss_scale,
        )

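    # Switch the wrapped model to distributed data-parallel mode, then load any resumed weights.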
    if args.distributed:
        model_and_loss.distributed()

    model_and_loss.load_model_state(model_state)

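    # Run the training/validation loop, checkpointing as configured;
    # --evaluate skips training and --training-only skips validation.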
    train_loop(
        model_and_loss,
        optimizer,
        lr_policy,
        train_loader,
        val_loader,
        args.epochs,
        args.fp16,
        None,
        should_backup_checkpoint(args),
        args.save_checkpoint_epochs,
        use_amp=args.amp,
        batch_size_multiplier=batch_size_multiplier,
        start_epoch=start_epoch,
        best_prec1=best_prec1,
        prof=args.prof,
        skip_training=args.evaluate,
        skip_validation=args.training_only,
        save_checkpoints=args.save_checkpoints and not args.evaluate,
        checkpoint_dir=args.workspace,
        total_train_step=args.total_train_step,
    )
    exp_duration = time.time() - exp_start_time

    print("Experiment ended, took {:.2f}s".format(exp_duration))

    sys.stdout.flush()