in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def cnn_model_function(features, labels, mode, params):
    """tf.estimator model_fn for ImageNet CNN training with Horovod.

    Builds the forward pass of the model named in ``params["model"]`` on the
    configured device, normalizing inputs with fixed ImageNet channel
    statistics, and returns an EstimatorSpec for the requested mode:

    - PREDICT: class ids, softmax probabilities, and raw logits.
    - EVAL: loss plus Horovod-allreduced top-1/top-5 accuracy metrics.
    - TRAIN: loss-scaled, Horovod-distributed momentum SGD (optionally LARC)
      with a warmup + decay learning-rate schedule and CPU->GPU staging.

    Args:
        features: Input image batch, NHWC layout; cast to ``params["dtype"]``.
        labels: Integer class labels; a trailing unary dim is squashed.
        mode: One of ``tf.estimator.ModeKeys.{TRAIN, EVAL, PREDICT}``.
        params: Hyperparameter dict (learning-rate schedule, model name,
            dtype/format/device, optimizer and loss-scaling options, ...).

    Returns:
        A ``tf.estimator.EstimatorSpec`` appropriate for ``mode``.
    """
    labels = tf.reshape(labels, (-1,))  # Squash unnecessary unary dim

    # --- Unpack hyperparameters ------------------------------------------
    lr = params["lr"]
    lr_steps = params["lr_steps"]
    steps = params["steps"]
    use_larc = params["use_larc"]
    leta = params["leta"]
    lr_decay_mode = params["lr_decay_mode"]
    decay_steps = params["decay_steps"]
    cdr_first_decay_ratio = params["cdr_first_decay_ratio"]
    cdr_t_mul = params["cdr_t_mul"]
    cdr_m_mul = params["cdr_m_mul"]
    cdr_alpha = params["cdr_alpha"]
    lc_periods = params["lc_periods"]
    lc_alpha = params["lc_alpha"]
    lc_beta = params["lc_beta"]
    model_name = params["model"]
    num_classes = params["n_classes"]
    model_dtype = get_with_default(params, "dtype", tf.float32)
    model_format = get_with_default(params, "format", "channels_first")
    device = get_with_default(params, "device", "/gpu:0")
    model_func = get_model_func(model_name)
    inputs = features  # TODO: Should be using feature columns?
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    momentum = params["mom"]
    weight_decay = params["wdecay"]
    warmup_lr = params["warmup_lr"]
    warmup_it = params["warmup_it"]
    loss_scale = params["loss_scale"]
    adv_bn_init = params["adv_bn_init"]
    conv_init = params["conv_init"]

    # --- Input staging: overlap host preprocessing with device compute ---
    if mode == tf.estimator.ModeKeys.TRAIN:
        with tf.device("/cpu:0"):
            preload_op, (inputs, labels) = stage([inputs, labels])

    with tf.device(device):
        if mode == tf.estimator.ModeKeys.TRAIN:
            gpucopy_op, (inputs, labels) = stage([inputs, labels])

        # Per-channel normalization with fixed (presumably ImageNet RGB)
        # statistics -- NOTE(review): values look rounded; confirm against
        # the dataset pipeline.
        inputs = tf.cast(inputs, model_dtype)
        imagenet_mean = np.array([121, 115, 100], dtype=np.float32)
        imagenet_std = np.array([70, 68, 71], dtype=np.float32)
        inputs = tf.subtract(inputs, imagenet_mean)
        inputs = tf.multiply(inputs, 1.0 / imagenet_std)
        if model_format == "channels_first":
            inputs = tf.transpose(inputs, [0, 3, 1, 2])  # NHWC -> NCHW

        # Trainable vars kept in fp32 even under mixed precision; L2
        # regularization contributes to REGULARIZATION_LOSSES below.
        with fp32_trainable_vars(regularizer=tf.contrib.layers.l2_regularizer(weight_decay)):
            top_layer = model_func(
                inputs,
                data_format=model_format,
                training=is_training,
                conv_initializer=conv_init,
                adv_bn_init=adv_bn_init,
            )
            logits = tf.layers.dense(
                top_layer, num_classes, kernel_initializer=tf.random_normal_initializer(stddev=0.01)
            )
        predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int32)
        logits = tf.cast(logits, tf.float32)  # loss/softmax in fp32 for stability

        if mode == tf.estimator.ModeKeys.PREDICT:
            # BUG FIX: tf.softmax does not exist in TF 1.x; softmax is tf.nn.softmax.
            probabilities = tf.nn.softmax(logits)
            predictions = {
                "class_ids": predicted_classes[:, None],
                "probabilities": probabilities,
                "logits": logits,
            }
            return tf.estimator.EstimatorSpec(mode, predictions=predictions)

        loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)
        loss = tf.identity(
            loss, name="loss"
        )  # For access by logger (TODO: Better way to access it?)

        if mode == tf.estimator.ModeKeys.EVAL:
            with tf.device(None):  # Allow fallback to CPU if no GPU support for these ops
                accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
                top5acc = tf.metrics.mean(tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32))
                # Average metric values across Horovod ranks; keep each
                # rank's local update op.
                newaccuracy = (hvd.allreduce(accuracy[0]), accuracy[1])
                newtop5acc = (hvd.allreduce(top5acc[0]), top5acc[1])
                metrics = {"val-top1acc": newaccuracy, "val-top5acc": newtop5acc}
            return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

        assert mode == tf.estimator.ModeKeys.TRAIN

        # Total loss = data loss + L2 regularization terms.
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        total_loss = tf.add_n([loss] + reg_losses, name="total_loss")

        batch_size = tf.shape(inputs)[0]
        global_step = tf.train.get_global_step()
        with tf.device("/cpu:0"):  # Allow fallback to CPU if no GPU support for these ops
            # Linear warmup for the first `warmup_it` steps, then the
            # configured decay schedule.
            learning_rate = tf.cond(
                global_step < warmup_it,
                lambda: warmup_decay(warmup_lr, global_step, warmup_it, lr),
                lambda: get_lr(
                    lr,
                    steps,
                    lr_steps,
                    warmup_it,
                    decay_steps,
                    global_step,
                    lr_decay_mode,
                    cdr_first_decay_ratio,
                    cdr_t_mul,
                    cdr_m_mul,
                    cdr_alpha,
                    lc_periods,
                    lc_alpha,
                    lc_beta,
                ),
            )
            learning_rate = tf.identity(learning_rate, "learning_rate")
            tf.summary.scalar("learning_rate", learning_rate)

        # Optimizer stack: momentum SGD -> Horovod allreduce -> optional
        # LARC layer-wise scaling -> loss-scaled mixed precision.
        opt = tf.train.MomentumOptimizer(learning_rate, momentum, use_nesterov=True)
        opt = hvd.DistributedOptimizer(opt)
        if use_larc:
            opt = LarcOptimizer(opt, learning_rate, leta, clip=True)
        opt = MixedPrecisionOptimizer(opt, scale=loss_scale)

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) or []
        with tf.control_dependencies(update_ops):
            gate_gradients = tf.train.Optimizer.GATE_NONE
            # Reuse the `global_step` fetched above instead of a second
            # tf.train.get_global_step() call.
            train_op = opt.minimize(
                total_loss, global_step=global_step, gate_gradients=gate_gradients
            )
        # Tie the staging transfer ops into the train op so both pipeline
        # stages advance every step.
        train_op = tf.group(preload_op, gpucopy_op, train_op)  # , update_ops)
        return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op)