in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py [0:0]
def _setup_model(self, model):
    # Resolve the model config and the flags that control how FSDP is applied.
    use_smp_model = self.use_smp_model
    cfg = self.cfg.model
    predefined_model = model.predefined_model
    if not predefined_model or (cfg.get("multi_modal", False) and cfg.model_type == "llama_v3"):
        # For models that are not predefined, or for multimodal Llama 3.2, rely on
        # HF accelerate to handle FSDP wrapping and activation checkpointing.
        # Map to the HF policy names: https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/constants.py#L37
        if cfg.auto_wrap_policy == "transformer_auto_wrap_policy":
            auto_wrap_policy = "transformer_based_wrap"
        elif cfg.auto_wrap_policy == "size_based_auto_wrap_policy":
            auto_wrap_policy = "size_based_wrap"
        else:
            auto_wrap_policy = "no_wrap"
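        # Let accelerate's FSDP plugin turn the policy name into a concrete wrap policy derived from the wrapped module.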
        fsdp_plugin = FullyShardedDataParallelPlugin(auto_wrap_policy=auto_wrap_policy)
        fsdp_plugin.set_auto_wrap_policy(model.model)
        auto_wrap_policy = fsdp_plugin.auto_wrap_policy
    else:
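        # Predefined models: resolve the transformer block class and build the PyTorch auto-wrap policy directly.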
        transformer_layer = get_transformer_layer(cfg.model_type, use_smp_model, cfg.moe, model.peft_type)
        auto_wrap_policy = get_auto_wrap_policy(cfg.auto_wrap_policy, transformer_layer, model.use_peft)
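    # Mixed-precision recipe; PEFT and multi-modal runs also cast forward inputs, and 4-bit QLoRA gets dedicated handling.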
    mixed_precision_policy = set_mixed_precision_recipe(
        precision=cfg.precision,
        use_smp_model=use_smp_model,
        is_qlora=model.use_peft and cfg.peft.get("peft_type", None) == "qlora_4bit",
        cast_forward_inputs=model.use_peft or cfg.get("multi_modal", False),
    )
    sharding_strategy = get_sharding_strategy(cfg.sharding_strategy)
    backward_prefetch = get_backward_fetch_policy(cfg.backward_fetch_policy)
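    # Delayed parameter initialization (when enabled) defers weight materialization to the FSDP constructor below.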
    param_init_fn, post_param_init_fn, model_context = self._setup_delayed_param(cfg, model)
    with (
        model_context,
        tsm_utils.timeit(True, "FSDP constructor", self.global_rank),
    ):
        if dist.get_rank() == 0:
            logging.info(f"Using FSDP plugin with auto_wrap_policy: {auto_wrap_policy}")
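        # Wrap the root module with FSDP using the policies resolved above.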
        pytorch_model = FSDP(
            module=model.model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=sharding_strategy,
            backward_prefetch=backward_prefetch,
            forward_prefetch=cfg.forward_prefetch,
            limit_all_gathers=cfg.limit_all_gathers,
            device_id=torch.cuda.current_device(),
            use_orig_params=cfg.use_orig_param,
            param_init_fn=param_init_fn,
            post_param_init_fn=post_param_init_fn,
            sync_module_states=model.do_finetune_with_pretrained_weights,
            # ignored_modules=ignored_params,
        )
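    # Keep handles to the FSDP sharding and replication process groups for later use.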
    self._record_fsdp_process_group(pytorch_model)
    self._record_replication_process_group()
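    # Activation checkpointing trades extra recompute in the backward pass for lower peak activation memory.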
    if cfg.activation_checkpointing:
        if not predefined_model:
            # Use the native PyTorch API to apply activation checkpointing.
            apply_activation_checkpointing(
                pytorch_model,
                checkpoint_wrapper_fn=functools.partial(
                    checkpoint_wrapper,
                    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                ),
                auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
            )
        else:
            apply_activation_checkpoint(
                model=pytorch_model,
                model_type=cfg.model_type,
                use_smp_model=use_smp_model,
                fp8=cfg.fp8,
                moe=cfg.moe,
            )
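    # Optionally offload activations to further reduce GPU memory pressure.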
    if cfg.get("offload_activations", None):
        pytorch_model = OffloadWrapper(pytorch_model)
    model.model = pytorch_model
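    # If a frozen reference model is attached (e.g. for preference-tuning recipes such as DPO), shard it with the same FSDP settings.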
    if hasattr(model, "ref_model") and model.ref_model is not None:
        ref_fsdp = FSDP(
            module=model.ref_model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=sharding_strategy,
            backward_prefetch=backward_prefetch,
            forward_prefetch=cfg.forward_prefetch,
            limit_all_gathers=cfg.limit_all_gathers,
            device_id=torch.cuda.current_device(),
            use_orig_params=cfg.use_orig_param,
            param_init_fn=param_init_fn,
            post_param_init_fn=post_param_init_fn,
            sync_module_states=model.do_finetune_with_pretrained_weights,
        )
        model.ref_model = ref_fsdp
        model.ref_model.eval()  # Keep the reference model in evaluation mode.
    return model