def get_params_for_weight_decay_optimization(module)

in modules/SwissArmyTransformer/sat/training/deepspeed_training.py


def get_params_for_weight_decay_optimization(module):
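    # Parameter groups bucketed by per-parameter lr (None = the default bucket
    # with lr 1.); add_param_by_lr routes each parameter into the right bucket.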
    weight_decay_params = {None: {'params': [], 'lr': 1.}}
    no_weight_decay_params = {None: {'params': [], 'weight_decay': 0.0, 'lr': 1.}}
    print_rank0(f"{NO_WD_MODULES} is set to no_weight_decay")
    for module_ in module.modules():
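        # Every parameter of a module type listed in NO_WD_MODULES skips weight decay.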
        if isinstance(module_, tuple(NO_WD_MODULES)):
            for p in module_._parameters.values():
                if p is not None and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
        else:
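            # Non-bias parameters decay unless flagged with a no_weight_decay attribute.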
            for n, p in module_._parameters.items():
                if p is not None and n != 'bias' and p.requires_grad:
                    if getattr(p, 'no_weight_decay', False):
                        print_rank0(f"{n} is set to no_weight_decay")
                        add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
                    else:
                        add_param_by_lr(weight_decay_params, p, no_weight_decay=False)
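            # Biases never receive weight decay.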
            for n, p in module_._parameters.items():
                if p is not None and n == 'bias' and p.requires_grad:
                    add_param_by_lr(no_weight_decay_params, p, no_weight_decay=True)
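    # Keep only the groups that actually received parameters.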
    ret = [v for v in weight_decay_params.values() if v['params']]
    ret += [v for v in no_weight_decay_params.values() if v['params']]
    return ret
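
For context, the returned value is an ordinary PyTorch param-groups list, so it can be handed directly to an optimizer; per-group keys such as 'weight_decay': 0.0 take precedence over optimizer-level defaults. Below is a minimal sketch assuming this function and its module globals (NO_WD_MODULES, add_param_by_lr, print_rank0) are in scope; the toy model and the AdamW choice are illustrative only, and the groups' 'lr' of 1. is presumably rescaled by SAT's lr scheduler rather than used verbatim.

import torch

# Toy model: Linear contributes a weight (decays) and a bias (never decays);
# LayerNorm is typically among NO_WD_MODULES, so all its parameters skip decay.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

param_groups = get_params_for_weight_decay_optimization(model)
for g in param_groups:
    print(len(g['params']), g.get('weight_decay', 'default'), g['lr'])

# Per-group keys override the defaults passed here, so the no-decay groups
# keep weight_decay=0.0 despite the 0.01 default.
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)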