in src/hyperpod_nemo_adapter/utils/fsdp_utils.py [0:0]
# Imports required by this function (module paths assumed from PyTorch FSDP and Hugging Face PEFT).
import functools

from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)


def get_auto_wrap_policy(policy: str, transformer_layer=None, use_peft=False):
    """Return an FSDP auto-wrap policy callable for the given policy name."""
    if use_peft:
        # To support PEFT, create a policy that wraps the transformer layers and also
        # wraps the trainable leaf modules (lambda_policy_fn) and the PEFT prompt/prefix
        # encoder layers. With PEFT, the base model's frozen parameters are kept in low
        # precision while the PEFT adapter weights stay in full fp32 precision, so the
        # adapter layers must be wrapped separately from the frozen layers to avoid the
        # FSDP error:
        # "ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32"
        assert (
            policy == "transformer_auto_wrap_policy"
        ), f"PEFT requires 'transformer_auto_wrap_policy' but got '{policy}'"

        def lambda_policy_fn(module):
            # Wrap leaf modules that hold trainable weights (e.g., adapter/LoRA layers).
            if (
                not list(module.named_children())
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            ):
                return True
            return False

        lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
        transformer_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=(
                transformer_layer,
                PrefixEncoder,
                PromptEncoder,
                PromptEmbedding,
            ),
        )
        # Wrap a module if either of the two policies matches it.
        return functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    else:
        if policy == "transformer_auto_wrap_policy":
            return functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=(transformer_layer,),
            )
        elif policy == "size_based_auto_wrap_policy":
            return functools.partial(
                size_based_auto_wrap_policy,
            )
        else:
            raise NotImplementedError(
                f"{policy} is not a valid auto wrap policy, supported policies are: [transformer_auto_wrap_policy, size_based_auto_wrap_policy]"
            )
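
A hedged usage sketch follows: it shows how the callable returned by get_auto_wrap_policy might be passed to FSDP's auto_wrap_policy argument. The names MyDecoderLayer and base_model are hypothetical placeholders, and distributed process-group setup is omitted; this is an illustration under those assumptions, not code from the repo.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

# Build the wrap policy. With use_peft=True, the trainable fp32 adapter weights are
# flattened separately from the frozen low-precision base weights, avoiding the
# "Must flatten tensors with uniform dtype" error described above.
auto_wrap_policy = get_auto_wrap_policy(
    "transformer_auto_wrap_policy",
    transformer_layer=MyDecoderLayer,  # hypothetical transformer block class of the model
    use_peft=True,
)

sharded_model = FSDP(
    base_model,  # hypothetical PEFT-wrapped model
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
)

Because the PEFT branch combines the two policies with _or_policy, a module is wrapped if it is either a listed transformer/prompt-encoder class or a trainable leaf module, which keeps each FSDP flat parameter uniform in dtype.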