in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py [0:0]
def _setup_smp_delayed_param(self, cfg, model):
    """Build the SMP delayed-parameter-initialization hooks for FSDP wrapping.

    Returns a 3-tuple ``(param_init_fn, post_param_init_fn, context)``:
    ``(None, None, nullcontext())`` when this rank does not use delayed init,
    otherwise the initer's two callbacks plus a validation context (replaced by
    ``nullcontext()`` on the pretrained-finetune path, where rank 0 holds the
    real weights and validation does not apply).
    """
    # Import must happen after tsm.init() so the monkey patch is in effect,
    # e.g. RotaryPositionEmbedding resolves to PatchedRotaryPositionEmbedding.
    from torch.sagemaker.delayed_param import DelayedParamIniter

    finetune_pretrained = model.do_finetune_with_pretrained_weights

    # When finetuning from pretrained weights, rank 0 loads the real weights
    # and skips delayed init; every other rank delays. When training from
    # scratch, every rank delays.
    initer = None
    if not finetune_pretrained or self.global_rank != 0:
        initer = DelayedParamIniter(model.model)

    if not initer:
        return None, None, nullcontext()

    validation_ctx = (
        nullcontext()
        if finetune_pretrained
        else initer.validate_params_and_buffers_inited()
    )
    return initer.get_param_init_fn(), initer.get_post_param_init_fn(), validation_ctx