def mpi_train()

in train.py


def mpi_train():

    with tf.device('/cpu:0'), tf.name_scope('optimizer'):
        if H.decay_lr_linearly:
            lr_at_time = H.lr * warmup_linear_decay(H.global_step - H.lr_offset)
        else:
            lr_at_time = H.lr * warmup_cosine(H.global_step - H.lr_offset)
        rcp_mpi_size = tf.constant(1.0 / mpi_size())
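        # The loss is multiplied by curr_loss_scale before the backward pass (see
        # bs.scale_tensor below); the optimizer is given grad_scale to undo that scaling.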
        grad_scale = tf.reciprocal(H.curr_loss_scale)

    with tf.device("/gpu:0"):
        avg_loss_gen, _ = model(train=True)
        H.train_gen_loss = avg_loss_gen
        loss_to_optimize = avg_loss_gen
        params = tf.trainable_variables()
        grads = bs.gradients(bs.scale_tensor(loss_to_optimize, H.curr_loss_scale), params)

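        # Interleave the gradient all-reduce with backprop: group every
        # `merge_layer_allreduce` transformer blocks (matched by scope name,
        # e.g. "model/h11", "model/h10", ...) into one reduction bucket.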
        if H.merge_layer_allreduce > 0:
            search_strings = list()
            stride = H.merge_layer_allreduce
            for l in range(H.n_layer - 1, -1, -stride):
                search_strings.append([f"model/h{j}" for j in range(l, l - stride, -1)])
        else:
            logprint('Not interleaving allreduce with backprop; this will be slow.')
            search_strings = None

        if mpi_size() > 1:

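            # Average the reported training loss across ranks: scale by 1/mpi_size, then sum-allreduce.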
            H.train_gen_loss = allreduce(bs.scale_tensor(avg_loss_gen, rcp_mpi_size))

            # Pre-scale the gradients to give the all-reduce some headroom.
            # The gradients have already been computed on this device, so the scaling
            # here can be fairly aggressive, but 1/mpi_size should be enough
            # (see the mpi4py sketch after this function).
            grads = [bs.filter_tensor(x, rcp_mpi_size) for x in grads]

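            # Optionally run the all-reduce in fp16 to halve communication volume.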
            cast_all = tf.float16 if H.fp16_allreduce else None
            grads = group_allreduce(grads, params, search_strings=search_strings, cast_all=cast_all)

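            # Force a fixed execution order for the all-reduce ops so the
            # collectives are issued identically on every rank.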
            serialize_allreduce_ops([H.train_gen_loss] + grads)

        if H.log_grad_stats and mpi_rank() == 0:
            grads = log_gradient_values(grads, params, H.global_step, model_dir=H.model_dir)

        train_op, global_norm = get_optimizer(H.optimizer)(
            grads, params,
            learning_rate=lr_at_time,
            grad_scale=grad_scale,
            fp16_mean_var=H.fp16_mean_var,
            max_grad_norm=H.max_grad_norm,
            static_loss_scaling=H.float16 and not H.dynamic_loss_scaling,
            beta2=H.beta2)

        if H.l2_loss > 0:
            # Decoupled weight decay (AdamW-style): shrink the weight matrices directly
            # rather than adding an L2 term to the loss; only params with rank > 1
            # (i.e. not biases or gains) are decayed. A sketch follows the function.
            logprint('enabling l2 loss with value', H.l2_loss)
            updates = [train_op]
            l2_updates = []
            for p in params:
                if len(shape_list(p)) > 1:
                    l2_updates.append(p.assign(p - lr_at_time * H.l2_loss * p))
            updates.extend(l2_updates)
            train_op = tf.group(*updates)

        if not H.disable_ema_vars:
            # Polyak (exponential moving) average of the params; stores one extra copy.
            # A minimal sketch of the update follows the function.
            # NOTE: this assignment is stateful -- graphs created after this point will use
            # the EMA variables (see the variable getter), so the order of mpi_train and
            # eval-model creation cannot be swapped.
            # TODO: remove this constraint
            H.ema = bs.Ema(decay=H.weights_beta)
            with tf.control_dependencies([train_op]):
                train_op = H.ema.apply(params)

    return train_op, lr_at_time, global_norm
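
A note on the schedule helpers: warmup_linear_decay and warmup_cosine are defined elsewhere in the repo and are not shown here. As a rough illustration only, here is a minimal NumPy sketch of what such warmup-then-decay multipliers typically compute; warmup_steps and total_steps are assumed arguments, not hyperparameters taken from train.py:

import numpy as np

def warmup_cosine(step, warmup_steps=2000, total_steps=1_000_000):
    # Hypothetical shape: linear warmup to 1.0, then cosine decay toward 0.0.
    warm = np.clip(step / warmup_steps, 0.0, 1.0)
    prog = np.clip(step / total_steps, 0.0, 1.0)
    return warm * 0.5 * (1.0 + np.cos(np.pi * prog))

def warmup_linear_decay(step, warmup_steps=2000, total_steps=1_000_000):
    # Hypothetical shape: linear warmup to 1.0, then linear decay toward 0.0.
    warm = np.clip(step / warmup_steps, 0.0, 1.0)
    prog = np.clip(step / total_steps, 0.0, 1.0)
    return warm * (1.0 - prog)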
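
Why pre-scale by rcp_mpi_size: summing gradients that were each multiplied by 1/mpi_size yields the cross-rank mean directly, and the smaller magnitudes give the (optionally fp16) all-reduce some headroom. A standalone sketch of that arithmetic, using mpi4py in place of the repo's group_allreduce helper:

from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
local_grad = np.random.randn(8).astype(np.float32)   # this rank's gradient
scaled = local_grad * (1.0 / comm.Get_size())        # pre-scale, like rcp_mpi_size
mean_grad = np.empty_like(scaled)
comm.Allreduce(scaled, mean_grad, op=MPI.SUM)        # sum of pre-scaled grads == mean

Launched with e.g. mpirun -n 4, every rank ends up holding the same mean_grad.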
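
The l2_loss branch is decoupled weight decay in the AdamW sense: the penalty never enters the loss or Adam's moment estimates; the weights are simply shrunk by lr * l2_loss after the optimizer step. A toy NumPy sketch of one such update, with a simplified Adam step (no bias correction) standing in for whatever get_optimizer returns:

import numpy as np

def adamw_step(p, g, m, v, lr, wd, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    p = p - lr * m / (np.sqrt(v) + eps)   # Adam step
    p = p - lr * wd * p                   # decoupled decay, as in p.assign(p - lr_at_time * H.l2_loss * p)
    return p, m, v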
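
bs.Ema keeps a Polyak (exponential moving) average of the weights, typically used at eval time. A minimal sketch of the update it performs after each training step, assuming a decay equal to H.weights_beta; bs.Ema's actual interface may differ:

import numpy as np

class Ema:
    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = [np.copy(p) for p in params]   # one extra copy per param

    def apply(self, params):
        # shadow <- decay * shadow + (1 - decay) * param
        for s, p in zip(self.shadow, params):
            s *= self.decay
            s += (1.0 - self.decay) * p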