in legacy/models/resnet/tensorflow/train_imagenet_resnet_hvd.py [0:0]
def cnn_model_function(features, labels, mode, params):
    """tf.estimator model_fn for ImageNet CNN training with Horovod.

    Builds the network selected by ``params['model']``, normalizes the input
    images, and returns an ``EstimatorSpec`` for the PREDICT, EVAL, or TRAIN
    mode. Training uses CPU->GPU staging areas, a warmup + decay learning-rate
    schedule, Nesterov momentum SGD wrapped by Horovod's distributed
    optimizer, optional LARC, and loss-scaled mixed precision.

    Args:
        features: Batch of input images (assumed NHWC with RGB channels —
            confirm against the input pipeline; transposed to NCHW when
            ``params['format']`` is ``'channels_first'``).
        labels: Integer class labels; a trailing unary dim is squashed.
        mode: One of ``tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}``.
        params: Hyper-parameter dict; required keys include ``lr``, ``steps``,
            ``lr_steps``, ``model``, ``n_classes``, ``mom``, ``wdecay``,
            ``warmup_lr``, ``warmup_it``, ``loss_scale`` and the LR-schedule
            knobs; ``dtype``, ``format`` and ``device`` are optional.

    Returns:
        A ``tf.estimator.EstimatorSpec`` appropriate for ``mode``.
    """
    labels = tf.reshape(labels, (-1,))  # Squash unnecessary unary dim
    lr = params['lr']
    lr_steps = params['lr_steps']
    steps = params['steps']
    use_larc = params['use_larc']
    leta = params['leta']
    lr_decay_mode = params['lr_decay_mode']
    decay_steps = params['decay_steps']
    cdr_first_decay_ratio = params['cdr_first_decay_ratio']
    cdr_t_mul = params['cdr_t_mul']
    cdr_m_mul = params['cdr_m_mul']
    cdr_alpha = params['cdr_alpha']
    lc_periods = params['lc_periods']
    lc_alpha = params['lc_alpha']
    lc_beta = params['lc_beta']
    model_name = params['model']
    num_classes = params['n_classes']
    model_dtype = get_with_default(params, 'dtype', tf.float32)
    model_format = get_with_default(params, 'format', 'channels_first')
    device = get_with_default(params, 'device', '/gpu:0')
    model_func = get_model_func(model_name)
    inputs = features  # TODO: Should be using feature columns?
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    momentum = params['mom']
    weight_decay = params['wdecay']
    warmup_lr = params['warmup_lr']
    warmup_it = params['warmup_it']
    loss_scale = params['loss_scale']
    adv_bn_init = params['adv_bn_init']
    conv_init = params['conv_init']

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Stage the batch on the host so the next batch's host->device copy
        # overlaps with compute.
        with tf.device('/cpu:0'):
            preload_op, (inputs, labels) = stage([inputs, labels])

    with tf.device(device):
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Second staging area on the device itself.
            gpucopy_op, (inputs, labels) = stage([inputs, labels])
        inputs = tf.cast(inputs, model_dtype)
        # Per-channel (RGB) normalization constants — presumably rounded
        # ImageNet training-set statistics; verify against the data pipeline.
        imagenet_mean = np.array([121, 115, 100], dtype=np.float32)
        imagenet_std = np.array([70, 68, 71], dtype=np.float32)
        inputs = tf.subtract(inputs, imagenet_mean)
        inputs = tf.multiply(inputs, 1. / imagenet_std)
        if model_format == 'channels_first':
            inputs = tf.transpose(inputs, [0, 3, 1, 2])  # NHWC -> NCHW

        # Keep trainable variables in fp32 (with L2 regularization) even when
        # the model computes in a lower precision.
        with fp32_trainable_vars(
                regularizer=tf.contrib.layers.l2_regularizer(weight_decay)):
            top_layer = model_func(
                inputs, data_format=model_format, training=is_training,
                conv_initializer=conv_init, adv_bn_init=adv_bn_init)
            logits = tf.layers.dense(
                top_layer, num_classes,
                kernel_initializer=tf.random_normal_initializer(stddev=0.01))
        predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int32)
        # Compute softmax/loss in full precision regardless of model dtype.
        logits = tf.cast(logits, tf.float32)

        if mode == tf.estimator.ModeKeys.PREDICT:
            # BUG FIX: `tf.softmax` does not exist; the softmax op lives in
            # `tf.nn`. The original raised AttributeError on the PREDICT path.
            probabilities = tf.nn.softmax(logits)
            predictions = {
                'class_ids': predicted_classes[:, None],
                'probabilities': probabilities,
                'logits': logits
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

        loss = tf.losses.sparse_softmax_cross_entropy(
            logits=logits, labels=labels)
        loss = tf.identity(loss, name='loss')  # For access by logger (TODO: Better way to access it?)

        if mode == tf.estimator.ModeKeys.EVAL:
            with tf.device(None):  # Allow fallback to CPU if no GPU support for these ops
                accuracy = tf.metrics.accuracy(
                    labels=labels, predictions=predicted_classes)
                top5acc = tf.metrics.mean(
                    tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32))
                # Average the metric values across all Horovod workers; keep
                # each metric's original update op.
                newaccuracy = (hvd.allreduce(accuracy[0]), accuracy[1])
                newtop5acc = (hvd.allreduce(top5acc[0]), top5acc[1])
                metrics = {'val-top1acc': newaccuracy,
                           'val-top5acc': newtop5acc}
            return tf.estimator.EstimatorSpec(
                mode, loss=loss, eval_metric_ops=metrics)

        assert (mode == tf.estimator.ModeKeys.TRAIN)
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + reg_losses, name='total_loss')
        global_step = tf.train.get_global_step()
        # The LR schedule is cheap scalar math; pin it to the host. (The
        # original comment claimed "fallback", but '/cpu:0' is a hard pin.)
        with tf.device('/cpu:0'):
            learning_rate = tf.cond(
                global_step < warmup_it,
                lambda: warmup_decay(warmup_lr, global_step, warmup_it, lr),
                lambda: get_lr(lr, steps, lr_steps, warmup_it, decay_steps,
                               global_step, lr_decay_mode,
                               cdr_first_decay_ratio, cdr_t_mul, cdr_m_mul,
                               cdr_alpha, lc_periods, lc_alpha, lc_beta))
            learning_rate = tf.identity(learning_rate, 'learning_rate')
            tf.summary.scalar('learning_rate', learning_rate)

        # Optimizer stack (inner to outer): momentum SGD -> Horovod gradient
        # allreduce -> optional LARC -> loss-scaled mixed precision.
        opt = tf.train.MomentumOptimizer(
            learning_rate, momentum, use_nesterov=True)
        opt = hvd.DistributedOptimizer(opt)
        if use_larc:
            opt = LarcOptimizer(opt, learning_rate, leta, clip=True)
        opt = MixedPrecisionOptimizer(opt, scale=loss_scale)

        # Run UPDATE_OPS (e.g. batch-norm moving stats) before the train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) or []
        with tf.control_dependencies(update_ops):
            gate_gradients = tf.train.Optimizer.GATE_NONE
            train_op = opt.minimize(
                total_loss, global_step=tf.train.get_global_step(),
                gate_gradients=gate_gradients)
        # Group the staging ops with the train step so every iteration also
        # advances the CPU and GPU input pipelines.
        train_op = tf.group(preload_op, gpucopy_op, train_op)
        return tf.estimator.EstimatorSpec(
            mode, loss=total_loss, train_op=train_op)