# Excerpt: cnn_model_function()
# From: legacy/models/resnet/tensorflow/train_imagenet_resnet_hvd.py

def cnn_model_function(features, labels, mode, params):
    """Estimator `model_fn` for Horovod-distributed ResNet ImageNet training.

    Builds the forward pass, loss, evaluation metrics, and a mixed-precision,
    optionally LARC-wrapped Momentum training op, returning the EstimatorSpec
    that matches `mode`.

    Args:
        features: Input image batch (NHWC layout; cast to params['dtype'] and
            optionally transposed to NCHW below).
        labels: Integer class-id tensor; reshaped to rank 1 at entry.
        mode: One of tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}.
        params: Hyperparameter dict. Required keys include the learning-rate
            schedule settings ('lr', 'lr_steps', 'steps', 'lr_decay_mode',
            'decay_steps', warmup and cosine/linear-cosine parameters), model
            selection ('model', 'n_classes'), optimizer settings ('mom',
            'wdecay', 'use_larc', 'leta', 'loss_scale'), and initializer
            flags ('adv_bn_init', 'conv_init'). Optional keys 'dtype',
            'format', and 'device' have defaults applied via
            get_with_default.

    Returns:
        tf.estimator.EstimatorSpec for the given mode.
    """
    labels = tf.reshape(labels, (-1,))  # Squash unnecessary unary dim

    # --- Learning-rate schedule hyperparameters ---
    lr = params['lr']
    lr_steps = params['lr_steps']
    steps = params['steps']
    use_larc = params['use_larc']
    leta = params['leta']
    lr_decay_mode = params['lr_decay_mode']
    decay_steps = params['decay_steps']
    cdr_first_decay_ratio = params['cdr_first_decay_ratio']
    cdr_t_mul = params['cdr_t_mul']
    cdr_m_mul = params['cdr_m_mul']
    cdr_alpha = params['cdr_alpha']
    lc_periods = params['lc_periods']
    lc_alpha = params['lc_alpha']
    lc_beta = params['lc_beta']

    # --- Model / device configuration ---
    model_name = params['model']
    num_classes = params['n_classes']
    model_dtype = get_with_default(params, 'dtype', tf.float32)
    model_format = get_with_default(params, 'format', 'channels_first')
    device = get_with_default(params, 'device', '/gpu:0')
    model_func = get_model_func(model_name)
    inputs = features  # TODO: Should be using feature columns?
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # --- Optimizer hyperparameters ---
    momentum = params['mom']
    weight_decay = params['wdecay']
    warmup_lr = params['warmup_lr']
    warmup_it = params['warmup_it']
    loss_scale = params['loss_scale']

    adv_bn_init = params['adv_bn_init']
    conv_init = params['conv_init']

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Stage the batch on the host first so the input pipeline overlaps
        # with compute. NOTE(review): assumes `stage` returns
        # (put_op, get_tensors) for a StagingArea — confirm against helper.
        with tf.device('/cpu:0'):
            preload_op, (inputs, labels) = stage([inputs, labels])

    with tf.device(device):
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Second staging area: host -> GPU copy pipelining.
            gpucopy_op, (inputs, labels) = stage([inputs, labels])

        # Normalize inputs. NOTE(review): these integer mean/std values are
        # approximations of the usual ImageNet per-channel statistics
        # (~[123.68, 116.78, 103.94] / ~[58.4, 57.1, 57.4]) — presumably
        # intentional; confirm against the training recipe.
        inputs = tf.cast(inputs, model_dtype)
        imagenet_mean = np.array([121, 115, 100], dtype=np.float32)
        imagenet_std = np.array([70, 68, 71], dtype=np.float32)
        inputs = tf.subtract(inputs, imagenet_mean)
        inputs = tf.multiply(inputs, 1. / imagenet_std)
        if model_format == 'channels_first':
            # NHWC -> NCHW for faster cuDNN convolutions.
            inputs = tf.transpose(inputs, [0, 3, 1, 2])

        # Variables are kept in fp32 even when compute runs in fp16;
        # L2 weight decay is attached as a regularizer here.
        with fp32_trainable_vars(
                regularizer=tf.contrib.layers.l2_regularizer(weight_decay)):
            top_layer = model_func(
                inputs, data_format=model_format, training=is_training,
                conv_initializer=conv_init, adv_bn_init=adv_bn_init)
            logits = tf.layers.dense(top_layer, num_classes,
                                     kernel_initializer=tf.random_normal_initializer(stddev=0.01))
        predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int32)
        # Loss/softmax are computed in fp32 for numerical stability.
        logits = tf.cast(logits, tf.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # BUG FIX: was `tf.softmax(logits)` — no such symbol in the
            # TF 1.x API; softmax lives under tf.nn. The PREDICT path would
            # have raised AttributeError at graph-construction time.
            probabilities = tf.nn.softmax(logits)
            predictions = {
                'class_ids': predicted_classes[:, None],
                'probabilities': probabilities,
                'logits': logits
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

        loss = tf.losses.sparse_softmax_cross_entropy(
            logits=logits, labels=labels)
        loss = tf.identity(loss, name='loss')  # For access by logger (TODO: Better way to access it?)

        if mode == tf.estimator.ModeKeys.EVAL:
            with tf.device(None):  # Allow fallback to CPU if no GPU support for these ops
                accuracy = tf.metrics.accuracy(
                    labels=labels, predictions=predicted_classes)
                top5acc = tf.metrics.mean(
                    tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32))
                # Average the metric values across Horovod workers so every
                # rank reports the global top-1/top-5 accuracy.
                newaccuracy = (hvd.allreduce(accuracy[0]), accuracy[1])
                newtop5acc = (hvd.allreduce(top5acc[0]), top5acc[1])
                metrics = {'val-top1acc': newaccuracy, 'val-top5acc': newtop5acc}
            return tf.estimator.EstimatorSpec(
                mode, loss=loss, eval_metric_ops=metrics)

        assert (mode == tf.estimator.ModeKeys.TRAIN)
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + reg_losses, name='total_loss')

        batch_size = tf.shape(inputs)[0]

        global_step = tf.train.get_global_step()

        with tf.device('/cpu:0'):  # Allow fallback to CPU if no GPU support for these ops
            # Linear warmup for the first `warmup_it` steps, then the
            # configured decay schedule (step / cosine-restart / linear-cosine).
            learning_rate = tf.cond(global_step < warmup_it,
                                    lambda: warmup_decay(warmup_lr, global_step, warmup_it,
                                                         lr),
                                    lambda: get_lr(lr, steps, lr_steps, warmup_it, decay_steps, global_step,
                                                   lr_decay_mode,
                                                   cdr_first_decay_ratio, cdr_t_mul, cdr_m_mul, cdr_alpha,
                                                   lc_periods, lc_alpha, lc_beta))
            learning_rate = tf.identity(learning_rate, 'learning_rate')
            tf.summary.scalar('learning_rate', learning_rate)

        # Optimizer stack: Momentum -> Horovod allreduce -> (optional) LARC
        # layer-wise rate control -> loss-scaled mixed precision.
        opt = tf.train.MomentumOptimizer(
            learning_rate, momentum, use_nesterov=True)
        opt = hvd.DistributedOptimizer(opt)
        if use_larc:
            opt = LarcOptimizer(opt, learning_rate, leta, clip=True)
        opt = MixedPrecisionOptimizer(opt, scale=loss_scale)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) or []
        with tf.control_dependencies(update_ops):
            gate_gradients = (tf.train.Optimizer.GATE_NONE)
            train_op = opt.minimize(
                total_loss, global_step=tf.train.get_global_step(),
                gate_gradients=gate_gradients)
        # Tie the two staging-area put ops into the train op so each step
        # also advances the input pipeline.
        train_op = tf.group(preload_op, gpucopy_op, train_op)  # , update_ops)

        return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op)