# _setup_model() in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py


    def _setup_model(self, model):
        """Wrap ``model.model`` (and ``model.ref_model``, if present) with FSDP.

        The method derives everything it needs from ``self.cfg.model``:

        * an auto-wrap policy,
        * a mixed-precision recipe,
        * a sharding strategy,
        * a backward-prefetch policy.

        It then builds the FSDP wrapper inside the context returned by
        ``self._setup_delayed_param``, timed with ``tsm_utils.timeit``. After
        that it optionally applies activation checkpointing and activation
        offloading. When a reference model is attached, it is wrapped with the
        identical FSDP configuration and put into eval mode.

        Args:
            model: Model wrapper. It must expose ``model``, ``predefined_model``,
                ``use_peft``, ``peft_type`` and
                ``do_finetune_with_pretrained_weights``. It may also expose
                ``ref_model``.

        Returns:
            The same ``model`` object. ``model.model`` is replaced by its FSDP
            (optionally offload-wrapped) version. ``model.ref_model`` is
            replaced by its FSDP version when it is not None.
        """
        use_smp_model = self.use_smp_model
        cfg = self.cfg.model
        predefined_model = model.predefined_model
        is_multi_modal = cfg.get("multi_modal", False)

        # Parentheses only make the existing precedence explicit:
        # `not predefined_model or (multi_modal and llama_v3)`.
        if not predefined_model or (is_multi_modal and cfg.model_type == "llama_v3"):
            # Models that are not predefined, and multimodal Llama 3.2, use
            # HF accelerate's FSDP plugin to derive the auto-wrap policy.
            # Map to HF name: https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/constants.py#L37
            if cfg.auto_wrap_policy == "transformer_auto_wrap_policy":
                hf_wrap_policy = "transformer_based_wrap"
            elif cfg.auto_wrap_policy == "size_based_auto_wrap_policy":
                hf_wrap_policy = "size_based_wrap"
            else:
                hf_wrap_policy = "no_wrap"
            fsdp_plugin = FullyShardedDataParallelPlugin(auto_wrap_policy=hf_wrap_policy)
            fsdp_plugin.set_auto_wrap_policy(model.model)
            # Keep a local handle so later code never depends on `fsdp_plugin`,
            # which only exists on this branch.
            auto_wrap_policy = fsdp_plugin.auto_wrap_policy
        else:
            transformer_layer = get_transformer_layer(cfg.model_type, use_smp_model, cfg.moe, model.peft_type)
            auto_wrap_policy = get_auto_wrap_policy(cfg.auto_wrap_policy, transformer_layer, model.use_peft)

        mixed_precision_policy = set_mixed_precision_recipe(
            precision=cfg.precision,
            use_smp_model=use_smp_model,
            is_qlora=model.use_peft and cfg.peft.get("peft_type", None) == "qlora_4bit",
            cast_forward_inputs=model.use_peft or is_multi_modal,
        )
        sharding_strategy = get_sharding_strategy(cfg.sharding_strategy)
        backward_prefetch = get_backward_fetch_policy(cfg.backward_fetch_policy)
        param_init_fn, post_param_init_fn, model_context = self._setup_delayed_param(cfg, model)

        def _wrap_with_fsdp(module):
            # Single source of truth for the FSDP configuration shared by the
            # trained model and the optional reference model.
            # `device_id` is queried on every call, matching the previous
            # per-constructor evaluation.
            return FSDP(
                module=module,
                auto_wrap_policy=auto_wrap_policy,
                mixed_precision=mixed_precision_policy,
                sharding_strategy=sharding_strategy,
                backward_prefetch=backward_prefetch,
                forward_prefetch=cfg.forward_prefetch,
                limit_all_gathers=cfg.limit_all_gathers,
                device_id=torch.cuda.current_device(),
                use_orig_params=cfg.use_orig_param,
                param_init_fn=param_init_fn,
                post_param_init_fn=post_param_init_fn,
                sync_module_states=model.do_finetune_with_pretrained_weights,
            )

        with (
            model_context,
            tsm_utils.timeit(True, "FSDP constructor", self.global_rank),
        ):
            if dist.get_rank() == 0:
                logging.info(f"Using FSDP plugin with auto_wrap_policy: {auto_wrap_policy}")

            pytorch_model = _wrap_with_fsdp(model.model)
            self._record_fsdp_process_group(pytorch_model)
            self._record_replication_process_group()

        if cfg.activation_checkpointing:
            if not predefined_model:
                # Use the native PyTorch API and reuse the accelerate-derived
                # wrap policy to choose which submodules get checkpointed.
                apply_activation_checkpointing(
                    pytorch_model,
                    checkpoint_wrapper_fn=functools.partial(
                        checkpoint_wrapper,
                        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                    ),
                    auto_wrap_policy=auto_wrap_policy,
                )
            else:
                # NOTE(review): predefined multimodal Llama 3.2 takes its wrap
                # policy from accelerate above, but it is checkpointed here
                # with model_type "llama_v3". Confirm this helper covers the
                # multimodal layers.
                apply_activation_checkpoint(
                    model=pytorch_model,
                    model_type=cfg.model_type,
                    use_smp_model=use_smp_model,
                    fp8=cfg.fp8,
                    moe=cfg.moe,
                )

        if cfg.get("offload_activations", None):
            pytorch_model = OffloadWrapper(pytorch_model)
        model.model = pytorch_model

        ref_model = getattr(model, "ref_model", None)
        if ref_model is not None:
            # The reference model is wrapped outside `model_context` and the
            # timer, as before.
            model.ref_model = _wrap_with_fsdp(ref_model)
            model.ref_model.eval()  # Set reference model to evaluation mode

        return model