in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]
def __init__(self, cfg: DictConfig, trainer: Trainer, use_smp_model=True):
    """Set up base-model state and Transformer Engine environment, then delegate to the parent.

    Args:
        cfg: Recipe/model configuration (OmegaConf ``DictConfig``).
        trainer: The Lightning ``Trainer`` instance forwarded to the parent constructor.
        use_smp_model: Whether the SageMaker model-parallel (SMP) model path is used.
    """
    # Mutable training state; the heavyweight members are populated later in setup.
    self.grad_norm = None
    self._cfg = cfg
    self.model = None
    self.ref_model = None
    self.use_smp_model = use_smp_model
    self.model_config = None
    self.val_loss = 0
    self._config_mapping_hf_to_recipe_aliases = None
    self.set_config_mapping_hf_to_recipe_aliases()

    # Transformer Engine: keep torch.compile out of the picture.
    os.environ.update(
        {
            "NVTE_TORCH_COMPILE": "0",
            "TORCH_COMPILE_DISABLE": "1",
        }
    )

    if self.do_patch_attn_context_parallel:
        # Avoid an error from touching a non-existent attribute in the TE
        # extra state during `from_pretrained`.
        os.environ["ACCELERATE_USE_FSDP"] = "True"
        os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "True"

    # Optional explicit selection of the TE attention backend; any value
    # other than "fused"/"flash" is silently ignored, matching prior behavior.
    backend = self._cfg.get("nvte_attn_backend", None)
    if backend is not None:
        backend_env = {
            "fused": {"NVTE_FUSED_ATTN": "1", "NVTE_FLASH_ATTN": "0"},
            "flash": {"NVTE_FUSED_ATTN": "0", "NVTE_FLASH_ATTN": "1"},
        }.get(backend)
        if backend_env is not None:
            os.environ.update(backend_env)

    super().__init__(cfg, trainer)