in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py
def setup(self, *args, **kwargs):
    if self.do_patch_mllama:
        patch_mllama_dtype.apply_patch(
            dtype=torch.bfloat16 if self._cfg.precision == "bf16" else torch.float32
        )
    if self.do_patch_attn_context_parallel:
        patch_llama_flash_attn_cp.apply_patch()
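
    # Resolve the model config: predefined models build it with get_model_config();
    # anything else loads it from the Hugging Face checkpoint named by hf_model_name_or_path.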
    if not self.predefined_model:
        assert not self.use_smp_model, "a model that is not predefined cannot be used with use_smp_model=True"
        assert (
            self._cfg.get("hf_model_name_or_path", None) is not None
        ), "hf_model_name_or_path is required when the model is not predefined"
        _logger.info(
            f"{self._cfg.hf_model_name_or_path} is not a predefined model; most SMP features (e.g. TP, fp8) "
            "will be ignored, and only FSDP and activation checkpointing can be applied."
        )
        # Use the config from the pretrained model
        self.model_config = get_hf_config_from_name_or_path(self._cfg)
    else:
        self.model_config = self.get_model_config()

    # Disable the KV cache for HF models
    if hasattr(self.model_config, "use_cache"):
        self.model_config.use_cache = False

    # Data-parallel size is whatever remains of the world size after context and tensor parallelism,
    # e.g. 16 ranks with context_parallel_degree=2 and tensor_model_parallel_degree=2 give dp_size=4.
    self.dp_size = dist.get_world_size() // (
        self._cfg.get("context_parallel_degree", 1) * self._cfg.get("tensor_model_parallel_degree", 1)
    )
    # Add the delayed_param flag to the HF model config
    self.model_config.delayed_param = self._cfg.delayed_param

    model = self._initialize_model(self.model_config)
    if self.do_patch_attn_context_parallel:
        # Check that we are using the patched attention for context parallel
        assert any(
            submodule.__module__ == "transformer_engine.pytorch.attention" for submodule in model.modules()
        ), "This model does not support context parallel with use_smp_model=False."
        # Set up the TransformerEngine CP groups
        setup_transformer_engine_cp_groups(
            model, get_global_ranks(tsm.state.cp_process_group), tsm.state.cp_process_group
        )
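
    # When fine-tuning from pretrained weights, block here until every rank has finished initializing its model.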
    if self.do_finetune_with_pretrained_weights:
        dist.barrier()

    if self.use_smp_model:
        self.model = self._transform(model)
    else:
        self.model = model
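
    # DPO needs a separate frozen reference model; with PEFT the untouched base weights
    # can serve as the reference, so no extra copy is created in that case.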
    if self._cfg.dpo.get("enabled", False) and not self.use_peft:
        ref_model = self._initialize_model(self.model_config)
        if self.do_patch_attn_context_parallel:
            setup_transformer_engine_cp_groups(
                ref_model, get_global_ranks(tsm.state.cp_process_group), tsm.state.cp_process_group
            )
        if self.use_smp_model:
            self.ref_model = self._transform(ref_model)
        else:
            self.ref_model = ref_model
        self.ref_model.eval()
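
    # Pre-build the fp8 delayed-scaling recipe (assumed to be a TransformerEngine-style recipe, only used when fp8 training is enabled).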
    self.fp8_recipe = self._fp8_delayed_scaling()