in src/hyperpod_nemo_adapter/collections/model/nlp/sagemaker_deepseek_model.py [0:0]
def _build_model_from_pretrain(self, model_cfg, torch_dtype=None, quantization_config=None):
    """Load pretrained DeepseekV3 weights from the configured HF path.

    Args:
        model_cfg: Hugging Face model config object passed through to
            ``from_pretrained``.
        torch_dtype: Optional dtype to load the weights in.
        quantization_config: Optional quantization config forwarded to
            ``from_pretrained``.

    Returns:
        A ``DeepseekV3ForCausalLM`` instance with pretrained weights.

    Notes:
        ``attn_implementation="flash_attention_2"`` is requested only when
        both the installed transformers version supports it (>= 4.37.1) and
        the config enables ``use_flash_attention``; otherwise the default
        attention implementation is used.
    """
    path = self._cfg.hf_model_name_or_path
    _logger.info("Loading pretrained weights from %s.", path)

    # Shared kwargs for both code paths; previously the call was duplicated
    # and the two copies could silently drift apart.
    kwargs = dict(
        pretrained_model_name_or_path=path,
        config=model_cfg,
        quantization_config=quantization_config,
        torch_dtype=torch_dtype,
        token=self._cfg.get("hf_access_token", None),
        trust_remote_code=True,
    )

    # flash_attention_2 support landed in transformers 4.37.1; equivalent to
    # the original "< 4.37.1 or not use_flash_attn -> plain call" guard.
    if TF_VERSION >= pversion.parse("4.37.1") and self._cfg.use_flash_attention:
        kwargs["attn_implementation"] = "flash_attention_2"

    return DeepseekV3ForCausalLM.from_pretrained(**kwargs)