in src/hyperpod_nemo_adapter/collections/model/sagemaker_base_model.py [0:0]
def _transform(self, model):
    moe_config = None
    # When fine-tuning from pretrained weights, transform() is told to load the
    # state dict from rank 0.
    load_state_dict_from_rank0 = self.do_finetune_with_pretrained_weights
    if self._cfg.moe:
        # Build the Mixture-of-Experts config only when MoE is enabled.
        moe_config = MoEConfig(
            smp_moe=self.use_smp_model,
            random_seed=12345,
            moe_load_balancing=self._cfg.moe_load_balancing,
            global_token_shuffle=self._cfg.global_token_shuffle,
            moe_all_to_all_dispatcher=self._cfg.moe_all_to_all_dispatcher,
            moe_aux_loss_coeff=0.001,
            moe_z_loss_coeff=0.001,
            # Initialize weights on CPU only on rank 0, and only when loading
            # pretrained weights for fine-tuning.
            use_cpu_initialization=self.do_finetune_with_pretrained_weights and dist.get_rank() == 0,
        )
    if self._cfg.moe and self._cfg.delayed_param and (not load_state_dict_from_rank0 or dist.get_rank() != 0):
        # Delayed parameter initialization: ranks that do not hold the pretrained
        # state dict build the model with empty weights, so no real parameter
        # memory is allocated before transform() materializes the parameters.
        with init_empty_weights():
            return transform(
                model,
                config=moe_config,
                load_state_dict_from_rank0=load_state_dict_from_rank0,
            )
    else:
        # Note: the current tsm transform() only accepts the config param for MoEConfig.
        return transform(
            model,
            config=moe_config,
            load_state_dict_from_rank0=load_state_dict_from_rank0,
        )
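A minimal, self-contained sketch of the delayed-parameter pattern the first branch relies on, assuming init_empty_weights here is Hugging Face Accelerate's context manager (an assumption; the adapter may import an equivalent). Parameters created inside the context live on the meta device, so non-rank-0 workers allocate no real weight memory before transform() materializes the model.

# Sketch only, not part of the adapter code: illustrates empty-weight construction.
import torch
from accelerate import init_empty_weights

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)

print(layer.weight.device)  # meta: shape and dtype are known, but no storage is allocated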