def _setup_smp_delayed_param()

in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py [0:0]


    def _setup_smp_delayed_param(self, cfg, model):
        """Build the delayed-parameter-init hooks for SMP (SageMaker Model Parallel).

        Returns a 3-tuple ``(param_init_fn, post_param_init_fn, validation_ctx)``:
        the two init callbacks from ``DelayedParamIniter`` and a context manager
        used to validate that params/buffers were initialized. When fine-tuning
        from pretrained weights on global rank 0, no initer is created and the
        tuple is ``(None, None, nullcontext())``.
        """
        # Imported here, after tsm.init() has applied its monkey patches, so the
        # patched classes are picked up. e.g. RotaryPositionEmbedding resolves to
        # PatchedRotaryPositionEmbedding.
        from torch.sagemaker.delayed_param import DelayedParamIniter

        finetune_pretrained = model.do_finetune_with_pretrained_weights

        # NOTE(review): presumably rank 0 holds the real pretrained weights and
        # must not use delayed init — confirm against the weight-loading path.
        if finetune_pretrained and self.global_rank == 0:
            return None, None, nullcontext()

        initer = DelayedParamIniter(model.model)

        # Skip the init-validation context when fine-tuning from pretrained
        # weights (rank 0 never reaches here in that mode).
        validation_ctx = (
            nullcontext() if finetune_pretrained else initer.validate_params_and_buffers_inited()
        )
        return initer.get_param_init_fn(), initer.get_post_param_init_fn(), validation_ctx