def get_auto_wrap_policy()

in src/hyperpod_nemo_adapter/utils/fsdp_utils.py [0:0]


def get_auto_wrap_policy(policy: str, transformer_layer=None, use_peft=False):
    """Build the FSDP ``auto_wrap_policy`` callable named by ``policy``.

    Args:
        policy: ``"transformer_auto_wrap_policy"`` or ``"size_based_auto_wrap_policy"``.
            With ``use_peft=True`` only ``"transformer_auto_wrap_policy"`` is accepted.
        transformer_layer: Module class marking one transformer block to wrap, or a
            list/tuple/set of such classes. Required by ``"transformer_auto_wrap_policy"``;
            ignored by ``"size_based_auto_wrap_policy"``.
        use_peft: When True, the returned policy additionally wraps trainable leaf
            modules and the PEFT ``PrefixEncoder`` / ``PromptEncoder`` /
            ``PromptEmbedding`` modules in their own FSDP units.

    Returns:
        A ``functools.partial`` over an FSDP wrap-policy function, suitable for passing
        as FSDP's ``auto_wrap_policy``.

    Raises:
        ValueError: If ``use_peft`` is True and ``policy`` is not
            ``"transformer_auto_wrap_policy"``, or if the transformer policy is requested
            without any ``transformer_layer`` class.
        NotImplementedError: If ``policy`` is not a supported policy name.
    """
    if use_peft and policy != "transformer_auto_wrap_policy":
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(f"PEFT requires 'transformer_auto_wrap_policy' but got '{policy}'")

    if policy == "size_based_auto_wrap_policy":
        return functools.partial(size_based_auto_wrap_policy)
    if policy != "transformer_auto_wrap_policy":
        raise NotImplementedError(
            f"{policy} is not a valid auto wrap policy, supported policies are: [transformer_auto_wrap_policy, size_based_auto_wrap_policy]"
        )

    # Accept either a single layer class or a collection of layer classes.
    if isinstance(transformer_layer, (list, tuple, set, frozenset)):
        layer_classes = tuple(transformer_layer)
    else:
        layer_classes = () if transformer_layer is None else (transformer_layer,)
    if not layer_classes:
        # Fail fast here; a ``None`` layer class would otherwise only fail later,
        # when FSDP evaluates the policy against each module.
        raise ValueError("'transformer_auto_wrap_policy' requires a transformer_layer class")

    if not use_peft:
        return functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=layer_classes,
        )

    # To support PEFT, create a policy which wraps transformer layers, but also wraps
    # trainable leaf layers (lambda_policy_fn) and other PEFT layers.
    # When using PEFT, the original model's frozen parameters are low precision,
    # but the PEFT adapter weights are full fp32 precision. Therefore, the PEFT
    # adapter layers must be wrapped separately from frozen layers, to avoid FSDP errors:
    # "ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32"
    def lambda_policy_fn(module):
        # Wrap leaf modules (no child modules) whose ``weight`` requires grad; under
        # PEFT these are expected to be the trainable adapter layers.
        weight = getattr(module, "weight", None)
        return (
            next(module.children(), None) is None
            and weight is not None
            and getattr(weight, "requires_grad", False)
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            *layer_classes,
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
        ),
    )

    # A module is wrapped if either policy selects it.
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])