# optimizer.py
import tensorflow as tf
import blocksparse as bs
# mpi_rank() is assumed to be provided by the surrounding project's MPI helpers.


def bs_adafactor(grads, variables, learning_rate, grad_scale=1.0,
                 beta2=0.999, max_grad_norm=1.0, norm_scale=1.0,
                 static_loss_scaling=False, **kwargs):
    # Set clip_norm to a large value to disable clipping while still collecting
    # the global norm; the norm is also used for dynamic loss scaling.
    if not max_grad_norm:
        max_grad_norm = 9e9

    # With static loss scaling, saturate overflows at fp16 max (65504) and zero out NaNs.
    fp16_args = dict(saturate=65504.0, zero_nans=True) if static_loss_scaling else dict()

    global_norm, norm_scale = bs.clip_by_global_norm(grads,
                                                     grad_scale=grad_scale,
                                                     clip_norm=max_grad_norm,
                                                     **fp16_args)

    # Use Adam for gains/biases (1-D params); Adafactor handles the weight matrices.
    # Non-root MPI ranks zero-init optimizer state.
    adam = bs.AdamOptimizer(
        learning_rate=learning_rate,
        beta2=beta2,
        norm_scale=norm_scale,
        grad_scale=grad_scale,
        zero_init_variables=mpi_rank() != 0, **fp16_args)
    fact = bs.AdafactorOptimizer(
        learning_rate=learning_rate,
        beta2=beta2,
        norm_scale=norm_scale,
        grad_scale=grad_scale,
        zero_init_variables=mpi_rank() != 0, **fp16_args)

    # Route each (grad, var) pair to the appropriate optimizer by rank of the variable.
    adam_pairs = []
    fact_pairs = []
    for g, v in zip(grads, variables):
        if len(v.shape) < 2:
            adam_pairs.append((g, v))
        else:
            fact_pairs.append((g, v))

    adam_op = adam.apply_gradients(adam_pairs)
    fact_op = fact.apply_gradients(fact_pairs)

    return tf.group(adam_op, fact_op), global_norm
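

# Usage sketch: how a training graph might wire gradients into bs_adafactor.
# This helper and its arguments (`loss`, `learning_rate`) are hypothetical and
# not taken from this file; it only illustrates the intended call pattern.
def build_train_op(loss, learning_rate):
    params = tf.trainable_variables()
    grads = tf.gradients(loss, params)
    # train_op applies Adam to 1-D params (gains/biases) and Adafactor to
    # matrices; global_norm can feed logging or a dynamic loss-scaling policy.
    train_op, global_norm = bs_adafactor(grads, params,
                                         learning_rate=learning_rate,
                                         max_grad_norm=1.0)
    return train_op, global_norm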