in train.py [0:0]
def mpi_train():
    with tf.device('/cpu:0'), tf.name_scope('optimizer'):
        # Learning-rate schedule: warmup followed by either linear or cosine decay.
        if H.decay_lr_linearly:
            lr_at_time = H.lr * warmup_linear_decay(H.global_step - H.lr_offset)
        else:
            lr_at_time = H.lr * warmup_cosine(H.global_step - H.lr_offset)
        rcp_mpi_size = tf.constant(1.0 / mpi_size())
        # Reciprocal of the loss scale, handed to the optimizer to undo the loss scaling.
        grad_scale = tf.reciprocal(H.curr_loss_scale)
with tf.device("/gpu:0"):
avg_loss_gen, _ = model(train=True)
H.train_gen_loss = avg_loss_gen
# n_updates_per_epoch H.global_step
loss_to_optimize = avg_loss_gen
params = tf.trainable_variables()
grads = bs.gradients(bs.scale_tensor(loss_to_optimize, H.curr_loss_scale), params)
        if H.merge_layer_allreduce > 0:
            # Group transformer blocks into buckets of `stride` layers so their
            # all-reduces can be interleaved with backprop.
            search_strings = list()
            stride = H.merge_layer_allreduce
            for l in range(H.n_layer - 1, -1, -stride):
                search_strings.append([f"model/h{j}" for j in range(l, l - stride, -1)])
        else:
            logprint('Not interleaving allreduce with backprop! This will be slow.')
            search_strings = None
        if mpi_size() > 1:
            H.train_gen_loss = allreduce(bs.scale_tensor(avg_loss_gen, rcp_mpi_size))
            # Pre-scale the gradients to give the all-reduce some headroom.
            # Since the gradients have already been computed on this device, the
            # scaling here could be more aggressive, but 1/mpi_size is enough.
            grads = [bs.filter_tensor(x, rcp_mpi_size) for x in grads]
            cast_all = tf.float16 if H.fp16_allreduce else None
            grads = group_allreduce(grads, params, search_strings=search_strings, cast_all=cast_all)
            serialize_allreduce_ops([H.train_gen_loss] + grads)
        if H.log_grad_stats and mpi_rank() == 0:
            grads = log_gradient_values(grads, params, H.global_step, model_dir=H.model_dir)
        train_op, global_norm = get_optimizer(H.optimizer)(
            grads, params,
            learning_rate=lr_at_time,
            grad_scale=grad_scale,
            fp16_mean_var=H.fp16_mean_var,
            max_grad_norm=H.max_grad_norm,
            static_loss_scaling=H.float16 and not H.dynamic_loss_scaling,
            beta2=H.beta2)
        if H.l2_loss > 0:
            # Decoupled weight decay (AdamW): shrink the weights directly rather
            # than adding an L2 term to the loss.
            logprint('enabling l2 loss with value', H.l2_loss)
            updates = [train_op]
            l2_updates = []
            for p in params:
                # Only decay matrices; skip biases, gains and other 1-D params.
                if len(shape_list(p)) > 1:
                    l2_updates.append(p.assign(p - lr_at_time * H.l2_loss * p))
            updates.extend(l2_updates)
            train_op = tf.group(*updates)
        if not H.disable_ema_vars:
            # Polyak (EMA) average of the params. Stores an extra copy of the weights.
            # NOTE: this assignment is stateful -- graphs created after this point will
            # use the EMA vars (see the variable getter), so the order of mpi_train and
            # eval model creation cannot be swapped.
            # TODO: remove this constraint
            H.ema = bs.Ema(decay=H.weights_beta)
            with tf.control_dependencies([train_op]):
                train_op = H.ema.apply(params)
    return train_op, lr_at_time, global_norm
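

# ---------------------------------------------------------------------------
# Minimal sketch (plain numpy, made-up numbers) of the scaling arithmetic used
# in mpi_train above: the loss is multiplied by curr_loss_scale before
# backprop, each worker pre-scales its gradients by 1/mpi_size before the
# all-reduce, and the optimizer finally multiplies by
# grad_scale = 1/curr_loss_scale, leaving the plain cross-worker mean
# gradient. Illustrative only, not part of the training graph.
# ---------------------------------------------------------------------------
def _loss_scaling_sketch():
    import numpy as np
    loss_scale, n_workers = 1024.0, 8
    true_grads = np.random.randn(n_workers)      # one scalar gradient per worker
    scaled = true_grads * loss_scale              # gradients of the scaled loss
    prescaled = scaled / n_workers                # bs.filter_tensor(x, rcp_mpi_size)
    allreduced = prescaled.sum()                  # all-reduce sums across workers
    recovered = allreduced * (1.0 / loss_scale)   # optimizer applies grad_scale
    assert np.isclose(recovered, true_grads.mean())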
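

# ---------------------------------------------------------------------------
# Minimal sketch (plain numpy, made-up numbers) of the decoupled weight decay
# in the H.l2_loss branch above: weights with more than one dimension are
# shrunk by lr * l2_loss directly, separately from the Adam step, which is
# what makes it AdamW-style rather than an L2 term added to the loss.
# Illustrative only, not part of the training graph.
# ---------------------------------------------------------------------------
def _decoupled_weight_decay_sketch():
    import numpy as np
    lr, l2_loss = 3e-4, 0.01
    w = np.random.randn(4, 4)            # a weight matrix (ndim > 1, so it is decayed)
    decayed = w - lr * l2_loss * w        # the l2_updates assign above
    assert np.allclose(decayed, w * (1.0 - lr * l2_loss))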