in modules/SwissArmyTransformer/sat/training/deepspeed_training.py [0:0]
def get_params_for_weight_decay_optimization(module):
    # Parameter groups keyed by (custom) learning rate; the None key holds
    # parameters that use the default lr scale of 1.
    weight_decay_params = {None: {'params': [], 'lr': 1.}}
    no_weight_decay_params = {None: {'params': [], 'weight_decay': 0.0, 'lr': 1.}}
    print_rank0(f"{NO_WD_MODULES} is set to no_weight_decay")
    for module_ in module.modules():
        if isinstance(module_, tuple(NO_WD_MODULES)):
            # All parameters of a no-weight-decay module type go into the
            # no-decay group.
            for p in module_._parameters.values():
                if p is not None and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
        else:
            # Non-bias parameters: apply weight decay unless the parameter
            # itself carries a truthy `no_weight_decay` attribute.
            for n, p in module_._parameters.items():
                if p is not None and n != 'bias' and p.requires_grad:
                    flag = hasattr(p, 'no_weight_decay') and p.no_weight_decay
                    if flag:
                        print_rank0(f"{n} is set to no_weight_decay")
                        add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
                    else:
                        add_param_by_lr(weight_decay_params, p, no_weight_decay=False)
            # Bias parameters never receive weight decay.
            for n, p in module_._parameters.items():
                if p is not None and n == 'bias' and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
    # Return only the non-empty groups.
    ret = []
    for v in weight_decay_params.values():
        if len(v['params']) != 0:
            ret.append(v)
    for v in no_weight_decay_params.values():
        if len(v['params']) != 0:
            ret.append(v)
    return ret
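
The returned list is a set of optimizer parameter groups: one group per learning-rate scale with weight decay, and one per scale without it. The helpers `NO_WD_MODULES`, `print_rank0`, and `add_param_by_lr` are defined elsewhere in deepspeed_training.py and are not reproduced here, so the following is a minimal standalone sketch of the same grouping idea (bias, LayerNorm/Embedding, and per-parameter `no_weight_decay` flags go into a zero-decay group) fed to a plain PyTorch optimizer; the module types, defaults, and `simple_param_groups` name are assumptions, not the SAT implementation.

    import torch
    import torch.nn as nn

    # Assumed stand-in for NO_WD_MODULES.
    NO_WD_MODULE_TYPES = (nn.LayerNorm, nn.Embedding)

    def simple_param_groups(model: nn.Module, lr: float = 1e-4, weight_decay: float = 0.01):
        decay, no_decay = [], []
        for mod in model.modules():
            for name, p in mod._parameters.items():
                if p is None or not p.requires_grad:
                    continue
                # Same decision rule as above: module type, bias name, or a
                # per-parameter no_weight_decay flag routes to the no-decay group.
                if isinstance(mod, NO_WD_MODULE_TYPES) or name == 'bias' \
                        or getattr(p, 'no_weight_decay', False):
                    no_decay.append(p)
                else:
                    decay.append(p)
        return [
            {'params': decay, 'weight_decay': weight_decay, 'lr': lr},
            {'params': no_decay, 'weight_decay': 0.0, 'lr': lr},
        ]

    model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
    optimizer = torch.optim.AdamW(simple_param_groups(model))

Note that in the SAT version the `'lr': 1.` entries are relative scales that the training loop multiplies by the scheduled learning rate, rather than absolute learning rates as in this sketch.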