in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/train_utils.py [0:0]
def get_param_groups_by_weight_decay(module, *, extra_no_decay_types=()):
    """Split a module's parameters into weight-decay and no-weight-decay groups.

    Walks every submodule yielded by ``module.modules()`` and inspects only the
    parameters registered directly on that submodule. Parameters go to the
    no-decay group in two cases:

    - They belong to a norm layer (``LayerNorm``, ``LlamaRMSNorm``, or any class
      in ``extra_no_decay_types``).
    - They are registered under the name ``"bias"``.

    Every other parameter goes to the decay group. Some parameter objects are
    reachable from several submodules (e.g. tied weights). Such a parameter is
    placed only once, in the group chosen the first time it is encountered.

    Args:
        module: Root module whose parameters are grouped.
        extra_no_decay_types: Additional module classes whose parameters should
            never be decayed (e.g. a fused LayerNorm implementation). Defaults
            to an empty tuple, which reproduces the original behavior.

    Returns:
        A ``(weight_decay_params, no_weight_decay_params)`` tuple of dicts
        suitable as optimizer parameter groups. The first has only a
        ``"params"`` key, so the optimizer-level weight decay applies. The
        second sets ``"weight_decay"`` to ``0.0``.
    """
    # NOTE(review): LayerNorm / LlamaRMSNorm are brought in by this file's
    # imports (presumably torch.nn and transformers) -- confirm.
    no_decay_types = (LayerNorm, LlamaRMSNorm) + tuple(extra_no_decay_types)

    weight_decay_params = {"params": []}
    no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    seen = set()  # id()s of parameters already assigned to a group

    for module_ in module.modules():
        is_norm = isinstance(module_, no_decay_types)
        # Direct (non-recursive) parameters only; nested submodules are visited
        # separately by module.modules().
        direct_params = module_._parameters  # pylint: disable=protected-access
        for name, param in direct_params.items():
            # Registered-as-None slots and already-grouped (shared) params are skipped.
            if param is None or id(param) in seen:
                continue
            if is_norm or name == "bias":
                no_weight_decay_params["params"].append(param)
            else:
                weight_decay_params["params"].append(param)
            seen.add(id(param))

    return weight_decay_params, no_weight_decay_params