def main()

in models/vision/classification/train_backbone.py [0:0]


def main():
    """Train an image-classification backbone with Horovod data parallelism.

    The function configures TensorFlow for the local GPU, parses the command
    line, builds the requested model, learning-rate schedule and optimizer,
    and then runs ``FLAGS.num_epochs`` epochs of training followed by
    validation. Rank 0 additionally writes a log file into a timestamped
    directory and saves the model whenever the averaged validation accuracy
    improves on the best value seen so far (starting from 0.7).

    Raises:
        NotImplementedError: if ``--fine_tune`` is set, or if the requested
            model or learning-rate schedule is not supported.
    """
    hvd.init()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        # Pin this process to the GPU matching its local Horovod rank.
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
    # These must be *called*; assigning to attributes of the
    # tf.config.threading module is a silent no-op.
    tf.config.threading.set_intra_op_parallelism_threads(1)  # Avoid pool of Eigen threads
    tf.config.threading.set_inter_op_parallelism_threads(max(2, 40 // hvd.size() - 2))
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

    cmdline = add_cli_args()
    FLAGS, unknown_args = cmdline.parse_known_args()

    if FLAGS.fine_tune:
        raise NotImplementedError('fine tuning functionality not available')

    if not FLAGS.xla_off:
        tf.config.optimizer.set_jit(True)
    if not FLAGS.fp32:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    # Maps the --model value to the constructor name in the `resnet` module.
    # The constructor is looked up lazily so only the selected one is touched.
    resnet_builders = {
        'resnet50v1_b': 'ResNet50V1_b',
        'resnet50v1_c': 'ResNet50V1_c',
        'resnet50v1_d': 'ResNet50V1_d',
        'resnet101v1_b': 'ResNet101V1_b',
        'resnet101v1_c': 'ResNet101V1_c',
        'resnet101v1_d': 'ResNet101V1_d',
    }

    preprocessing_type = 'resnet'
    if FLAGS.model in resnet_builders:
        build_resnet = getattr(resnet, resnet_builders[FLAGS.model])
        model = build_resnet(weights=None, weight_decay=FLAGS.l2_weight_decay, classes=FLAGS.num_classes)
    elif FLAGS.model == 'darknet53':
        model = darknet.Darknet(weight_decay=FLAGS.l2_weight_decay)
    elif FLAGS.model in ['hrnet_w18c', 'hrnet_w32c']:
        preprocessing_type = 'imagenet'
        model = hrnet.build_hrnet(FLAGS.model)
        model._set_inputs(tf.keras.Input(shape=(None, None, 3)))
    else:
        raise NotImplementedError('Model {} not implemented'.format(FLAGS.model))

    model.summary()

    # scale learning rate linearly, base learning rate for batch size of 256 is specified through args
    BASE_LR = FLAGS.learning_rate
    learning_rate = (BASE_LR * hvd.size() * FLAGS.batch_size) / 256
    steps_per_epoch = int(FLAGS.train_dataset_size / (FLAGS.batch_size * hvd.size()))

    if FLAGS.schedule == 'piecewise_short':
        scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[steps_per_epoch * 25, steps_per_epoch * 55, steps_per_epoch * 75, steps_per_epoch * 100],
            values=[learning_rate, learning_rate * 0.1, learning_rate * 0.01, learning_rate * 0.001, learning_rate * 0.0001])
    elif FLAGS.schedule == 'piecewise_long':
        scheduler = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=[steps_per_epoch * 55, steps_per_epoch * 115, steps_per_epoch * 175],
            values=[learning_rate, learning_rate * 0.1, learning_rate * 0.01, learning_rate * 0.001])
    elif FLAGS.schedule == 'cosine':
        scheduler = tf.keras.experimental.CosineDecayRestarts(
            initial_learning_rate=learning_rate,
            first_decay_steps=FLAGS.num_epochs * steps_per_epoch, t_mul=1, m_mul=1)
    else:
        # Fail fast: without a base schedule the warmup wrapper below cannot
        # be built, and continuing would only produce an unbound-local error.
        raise NotImplementedError('Schedule {} not implemented'.format(FLAGS.schedule))

    # 5 epochs are for warmup
    scheduler = WarmupScheduler(optimizer=scheduler, initial_learning_rate=learning_rate / hvd.size(), warmup_steps=steps_per_epoch * 5)

    # TODO support optimizers choice via config
    # NOTE: tf.keras.optimizers.SGD(learning_rate=scheduler, momentum=FLAGS.momentum, nesterov=True)
    # would need a momentum correction term (V2 SGD absorbs LR into the update
    # term), so the project's MomentumOptimizer is used instead.
    opt = MomentumOptimizer(learning_rate=scheduler, momentum=FLAGS.momentum, nesterov=True)

    if not FLAGS.fp32:
        opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale=128.)

    loss_func = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True,
        label_smoothing=FLAGS.label_smoothing,
        reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)

    is_root = hvd.rank() == 0

    if is_root:
        if FLAGS.resume_from:
            model = tf.keras.models.load_model(FLAGS.resume_from)
            print('loaded model from', FLAGS.resume_from)
        # Logs go into a fresh "<model>_<timestamp>" directory under the CWD.
        model_dir = FLAGS.model + datetime.datetime.now().strftime("_%Y-%m-%d_%H-%M-%S")
        path_logs = os.path.join(os.getcwd(), model_dir, 'log.csv')
        os.mkdir(model_dir)
        logging.basicConfig(filename=path_logs,
                            filemode='a',
                            format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
                            datefmt='%H:%M:%S',
                            level=logging.DEBUG)
        logging.info("Training Logs")
        logger = logging.getLogger('logger')
        logger.info('Training options: %s', FLAGS)

    # barrier
    hvd.allreduce(tf.constant(0))

    start_time = time()
    curr_step = tf.Variable(initial_value=0, dtype=tf.int32)
    best_validation_accuracy = 0.7  # only save 0.7 or higher checkpoints

    data = create_dataset(FLAGS.train_data_dir, FLAGS.batch_size, preprocessing=preprocessing_type, validation=False)
    validation_data = create_dataset(FLAGS.validation_data_dir, FLAGS.batch_size, preprocessing=preprocessing_type, validation=True)

    for epoch in range(FLAGS.num_epochs):
        if is_root:
            print('Starting training Epoch %d/%d' % (epoch, FLAGS.num_epochs))
        training_score = 0
        for batch, (images, labels) in enumerate(tqdm(data)):
            # The sixth argument is True only for the very first batch of the run.
            first_batch = batch == 0 and epoch == 0
            loss, score = train_step(model, opt, loss_func, images, labels, first_batch,
                                     batch_size=FLAGS.batch_size, mixup_alpha=FLAGS.mixup_alpha, fp32=FLAGS.fp32)
            training_score += score.numpy()
            curr_step.assign_add(1)
        # Denominator assumes every batch holds FLAGS.batch_size samples.
        training_accuracy = training_score / (FLAGS.batch_size * (batch + 1))
        average_training_accuracy = hvd.allreduce(tf.constant(training_accuracy))
        # `loss` is the value from the last training batch of this epoch.
        average_training_loss = hvd.allreduce(tf.constant(loss))

        if is_root:
            print('Starting validation Epoch %d/%d' % (epoch, FLAGS.num_epochs))
        validation_score = 0
        counter = 0
        for images, labels in tqdm(validation_data):
            loss, score = validation_step(images, labels, model, loss_func)
            validation_score += score.numpy()
            counter += 1
        validation_accuracy = validation_score / (FLAGS.batch_size * counter)
        average_validation_accuracy = hvd.allreduce(tf.constant(validation_accuracy))
        # `loss` is the value from the last validation batch of this epoch.
        average_validation_loss = hvd.allreduce(tf.constant(loss))

        if is_root:
            info_str = 'Epoch: %d, Train Accuracy: %f, Train Loss: %f, Validation Accuracy: %f, Validation Loss: %f LR:%f' % (
                epoch, average_training_accuracy, average_training_loss,
                average_validation_accuracy, average_validation_loss, scheduler(curr_step))
            print(info_str)
            logger.info(info_str)
            if average_validation_accuracy > best_validation_accuracy:
                logger.info("Found new best accuracy, saving checkpoint ...")
                best_validation_accuracy = average_validation_accuracy
                # NOTE(review): checkpoints go to FLAGS.model_dir, not the local
                # timestamped `model_dir` used for logs - confirm this is intended.
                model.save('{}/{}'.format(FLAGS.model_dir, FLAGS.model))
    if is_root:
        logger.info('Total Training Time: %f', time() - start_time)