def cnn_model_function()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


def cnn_model_function(features, labels, mode, params):
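    """Model function for tf.estimator.Estimator: builds the ResNet forward pass,
    loss, evaluation metrics, and the Horovod-distributed train op depending on
    `mode`; all hyperparameters are read from `params`."""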
    labels = tf.reshape(labels, (-1,))  # Squash unnecessary unary dim
    lr = params["lr"]
    lr_steps = params["lr_steps"]
    steps = params["steps"]
    use_larc = params["use_larc"]
    leta = params["leta"]
    lr_decay_mode = params["lr_decay_mode"]
    decay_steps = params["decay_steps"]
    cdr_first_decay_ratio = params["cdr_first_decay_ratio"]
    cdr_t_mul = params["cdr_t_mul"]
    cdr_m_mul = params["cdr_m_mul"]
    cdr_alpha = params["cdr_alpha"]
    lc_periods = params["lc_periods"]
    lc_alpha = params["lc_alpha"]
    lc_beta = params["lc_beta"]

    model_name = params["model"]
    num_classes = params["n_classes"]
    model_dtype = get_with_default(params, "dtype", tf.float32)
    model_format = get_with_default(params, "format", "channels_first")
    device = get_with_default(params, "device", "/gpu:0")
    model_func = get_model_func(model_name)
    inputs = features  # TODO: Should be using feature columns?
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    momentum = params["mom"]
    weight_decay = params["wdecay"]
    warmup_lr = params["warmup_lr"]
    warmup_it = params["warmup_it"]
    loss_scale = params["loss_scale"]

    adv_bn_init = params["adv_bn_init"]
    conv_init = params["conv_init"]

    if mode == tf.estimator.ModeKeys.TRAIN:
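        # Stage the next batch in host memory so input preprocessing overlaps with compute.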
        with tf.device("/cpu:0"):
            preload_op, (inputs, labels) = stage([inputs, labels])

    with tf.device(device):
        if mode == tf.estimator.ModeKeys.TRAIN:
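            # Copy the staged batch from host memory onto the GPU for the upcoming step.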
            gpucopy_op, (inputs, labels) = stage([inputs, labels])
        inputs = tf.cast(inputs, model_dtype)
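        # Normalize inputs with (approximate) per-channel ImageNet RGB mean and standard deviation.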
        imagenet_mean = np.array([121, 115, 100], dtype=np.float32)
        imagenet_std = np.array([70, 68, 71], dtype=np.float32)
        inputs = tf.subtract(inputs, imagenet_mean)
        inputs = tf.multiply(inputs, 1.0 / imagenet_std)
        if model_format == "channels_first":
            inputs = tf.transpose(inputs, [0, 3, 1, 2])
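        # Keep trainable variables in float32 (mixed-precision master weights) and attach
        # L2 weight decay through the variable regularizer.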
        with fp32_trainable_vars(regularizer=tf.contrib.layers.l2_regularizer(weight_decay)):
            top_layer = model_func(
                inputs,
                data_format=model_format,
                training=is_training,
                conv_initializer=conv_init,
                adv_bn_init=adv_bn_init,
            )
            logits = tf.layers.dense(
                top_layer, num_classes, kernel_initializer=tf.random_normal_initializer(stddev=0.01)
            )
        predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int32)
        logits = tf.cast(logits, tf.float32)
        if mode == tf.estimator.ModeKeys.PREDICT:
            probabilities = tf.nn.softmax(logits)
            predictions = {
                "class_ids": predicted_classes[:, None],
                "probabilities": probabilities,
                "logits": logits,
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)
        loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
        loss = tf.identity(
            loss, name="loss"
        )  # For access by logger (TODO: Better way to access it?)

        if mode == tf.estimator.ModeKeys.EVAL:
            with tf.device(None):  # Allow fallback to CPU if no GPU support for these ops
                accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
                top5acc = tf.metrics.mean(tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32))
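                # Average both metrics across Horovod workers so every rank reports the same value.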
                newaccuracy = (hvd.allreduce(accuracy[0]), accuracy[1])
                newtop5acc = (hvd.allreduce(top5acc[0]), top5acc[1])
                metrics = {"val-top1acc": newaccuracy, "val-top5acc": newtop5acc}
            return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

        assert mode == tf.estimator.ModeKeys.TRAIN
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + reg_losses, name="total_loss")

        batch_size = tf.shape(inputs)[0]

        global_step = tf.train.get_global_step()

        with tf.device("/cpu:0"):  # Allow fallback to CPU if no GPU support for these ops
            learning_rate = tf.cond(
                global_step < warmup_it,
                lambda: warmup_decay(warmup_lr, global_step, warmup_it, lr),
                lambda: get_lr(
                    lr,
                    steps,
                    lr_steps,
                    warmup_it,
                    decay_steps,
                    global_step,
                    lr_decay_mode,
                    cdr_first_decay_ratio,
                    cdr_t_mul,
                    cdr_m_mul,
                    cdr_alpha,
                    lc_periods,
                    lc_alpha,
                    lc_beta,
                ),
            )
            learning_rate = tf.identity(learning_rate, "learning_rate")
            tf.summary.scalar("learning_rate", learning_rate)

        opt = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)
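        # Wrap the base optimizer: Horovod allreduce for gradient averaging, optional LARC
        # layer-wise LR scaling, then loss-scaled mixed precision.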
        opt = hvd.DistributedOptimizer(opt)
        if use_larc:
            opt = LarcOptimizer(opt, learning_rate, leta, clip=True)
        opt = MixedPrecisionOptimizer(opt, scale=loss_scale)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) or []
        with tf.control_dependencies(update_ops):
            gate_gradients = tf.train.Optimizer.GATE_NONE
            train_op = opt.minimize(
                total_loss, global_step=tf.train.get_global_step(), gate_gradients=gate_gradients
            )
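        # Group the staging (prefetch/copy) ops with the train step so the next batch is
        # staged while the current step runs.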
        train_op = tf.group(preload_op, gpucopy_op, train_op)  # , update_ops)

        return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op)
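
Below is a minimal sketch (not taken from the benchmark script) of how a model_fn
like this is typically wired into tf.estimator.Estimator with Horovod. The names
train_input_fn and params are assumptions: params must supply every key read at the
top of cnn_model_function, and the real script assembles both from its command-line
arguments.

import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()

# Pin each worker process to its own GPU.
session_config = tf.ConfigProto()
session_config.gpu_options.visible_device_list = str(hvd.local_rank())

estimator = tf.estimator.Estimator(
    model_fn=cnn_model_function,
    # Only rank 0 writes checkpoints/summaries to avoid clobbering.
    model_dir="/tmp/resnet50" if hvd.rank() == 0 else None,
    config=tf.estimator.RunConfig(session_config=session_config),
    params=params,  # hypothetical dict of the hyperparameters the model_fn reads
)

estimator.train(
    input_fn=train_input_fn,  # hypothetical input_fn yielding (features, labels)
    steps=params["steps"],
    # Broadcast rank 0's initial weights so all workers start in sync.
    hooks=[hvd.BroadcastGlobalVariablesHook(0)],
)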