def main_ec2()

in models/vision/detection/tools/train.py

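Applies CLI overrides to the config (work_dir, checkpoint resumption, LR autoscaling), builds the detector, warm-starts the backbone from pretrained weights, and hands off to train_detector.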

def main_ec2(args, cfg):
    """
    Main training entry point for jobs launched directly on EC2 instances
    """
    # apply the work_dir override and create the directory up front so the
    # log file below lands in the right place
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    mkdir_or_exist(osp.abspath(cfg.work_dir))
    # start logger
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, '{}.log'.format(timestamp))
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
    # enumerate visible GPUs (assumption: the original code obtains `gpus` this way)
    gpus = tf.config.list_physical_devices('GPU')
    num_gpus = len(gpus)
    # update the remaining configs according to CLI args
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.resume_dir is not None:
        if os.path.exists(args.resume_dir):
            logger.info("RESUMING TRAINING")
            # get the latest checkpoint
            all_chkpt = [
                os.path.join(args.resume_dir, d)
                for d in os.listdir(args.resume_dir)
                if os.path.isdir(os.path.join(args.resume_dir, d))
            ]
            if not all_chkpt:
                cfg.resume_from = None
            else:
                # resume from the most recently modified checkpoint directory
                latest_chkpt = max(all_chkpt, key=os.path.getmtime)
                cfg.resume_from = latest_chkpt
        else:
            logger.info("CHECKPOINT NOT FOUND, RESTARTING TRAINING")
            cfg.resume_from = None

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
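        # e.g. 8 workers x 2 imgs_per_gpu -> total_bs = 16, so the base LR is scaled by 16/8 = 2x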
        total_bs = get_dist_info()[2] * cfg.data.imgs_per_gpu
        cfg.optimizer['learning_rate'] = cfg.optimizer['learning_rate'] * total_bs / 8

    # any visible GPU triggers the distributed training path; otherwise run
    # single-process (single node, single device)
    distributed = bool(gpus)

    # log some basic info
    logger.info('Distributed training: {}'.format(distributed))
    logger.info('AWSDet Version: {}'.format(__version__))
    logger.info('Config:\n{}'.format(cfg.text))
    logger.info('Tensorflow version: {}'.format(tf.version.VERSION))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}, deterministic: {}'.format(
            args.seed, args.deterministic))
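        # offset by worker rank so each process gets a distinct but reproducible stream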
        set_random_seed(args.seed + get_dist_info()[0], deterministic=args.deterministic)

    model = build_detector(cfg.model,
                           train_cfg=cfg.train_cfg,
                           test_cfg=cfg.test_cfg)

    # dummy forward pass to build the network variables (subclassed Keras
    # models create their weights lazily on the first call)
    padded_img_side = max(cfg.data.train['scale'])
    img = tf.random.uniform(shape=[padded_img_side, padded_img_side, 3], dtype=tf.float32)
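    # img_meta layout: original shape (3), resized shape (3), padded shape (3),
    # scale factor (here 800/465 ≈ 1.7204), horizontal-flip flag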
    img_meta = tf.constant(
        [465., 640., 3., 800., 1101., 3., float(padded_img_side), float(padded_img_side), 3., 1.7204301, 0.],
        dtype=tf.float32)
    _ = model((tf.expand_dims(img, axis=0), tf.expand_dims(img_meta, axis=0)), training=False)

    # print('BEFORE:', model.layers[0].layers[0].get_weights()[0][0,0,0,:])
    weights_path = cfg.model['backbone']['weights_path']
    logger.info('Loading weights from: {}'.format(weights_path))
    if osp.splitext(weights_path)[1] == '.h5':  # legacy HDF5 weights from the Keras model zoo
        model.layers[0].layers[0].load_weights(weights_path, by_name=True, skip_mismatch=True)
    else: # SavedModel format assumed - extract weights
        backbone_model = tf.keras.models.load_model(weights_path)
        target_backbone_model = model.layers[0].layers[0]
        # copy weights layer-by-layer wherever names match
        for layer in backbone_model.layers:
            for target_layer in target_backbone_model.layers:
                if layer.name == target_layer.name:
                    target_layer.set_weights(layer.get_weights())
                    # print('Loaded weights for:', layer.name)
                    break
        del backbone_model
    # print('AFTER:',model.layers[0].layers[0].get_weights()[0][0,0,0,:])

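    # optionally freeze layers whose names match the configured patterns
    # (e.g. early backbone stages)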
    patterns = cfg.train_cfg.get('freeze_patterns', None)
    if patterns:
        freeze_model_layers(model, patterns)

    print_model_info(model, logger)

    datasets = [build_dataset(cfg.data.train)]

    # multi-stage workflows (e.g. interleaved train/val epochs) are not supported yet
    if len(cfg.workflow) > 1:
        raise NotImplementedError

    train_detector(model,
                   datasets,
                   cfg,
                   num_gpus=num_gpus,
                   distributed=distributed,
                   mixed_precision=args.amp,
                   validate=args.validate,
                   timestamp=timestamp)
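
Below is a minimal caller sketch for reference. It assumes an mmcv-style
Config.fromfile loader and an argparse front end exposing the flags that
main_ec2 reads (the real parser in train.py may differ):

import argparse

from mmcv import Config  # assumed config loader


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector (EC2 entry point)')
    parser.add_argument('config', help='path to the training config file')
    parser.add_argument('--work_dir', default=None, help='override cfg.work_dir')
    parser.add_argument('--resume_from', default=None, help='checkpoint to resume from')
    parser.add_argument('--resume_dir', default=None, help='directory scanned for the latest checkpoint')
    parser.add_argument('--autoscale_lr', action='store_true', help='scale LR with total batch size')
    parser.add_argument('--seed', type=int, default=None, help='base random seed')
    parser.add_argument('--deterministic', action='store_true', help='deterministic ops where available')
    parser.add_argument('--amp', action='store_true', help='mixed precision training')
    parser.add_argument('--validate', action='store_true', help='run validation during training')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    cfg = Config.fromfile(args.config)
    main_ec2(args, cfg)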