def setup_model_untrainable_params_and_optimizer()

in modules/SwissArmyTransformer/sat/training/deepspeed_training.py


def setup_model_untrainable_params_and_optimizer(args, model, config_params=None):
    """Setup model and optimizer."""

    if hasattr(model, 'disable_untrainable_params'):
        model.disable_untrainable_params() # model-specific hook: freeze params that should not be trained

    param_groups = get_optimizer_param_groups(model)
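    # param_groups are torch.optim-style dicts; each param carries a
    # .model_parallel flag, which the sync below relies on.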

    # Sync initialized parameters across ranks.
    # Under ZeRO-3 parameters are partitioned, so no sync is needed.
    from sat.helpers import check_if_zero3
    if not check_if_zero3(args):
        print_rank0('Syncing initialized parameters...')
        for param_group in param_groups:
            for param in param_group['params']:
                if not param.model_parallel:
                    # Ranks share the same random seed, but that alone is not
                    # reliable: non-model-parallel parameters can still differ
                    # after initialization.
                    dist.broadcast(
                        param.data,
                        src=0, # default (global) process group
                    )
                else:
                    dist.broadcast(
                        param.data,
                        src=mpu.get_model_parallel_rank(), # global ranks 0 .. mp_size-1 hold the source replica
                        group=mpu.get_data_parallel_group() # e.g. ranks 1, mp_size + 1, 2 * mp_size + 1, ...
                    )
        print_rank0('Finished syncing initialized parameters.')

    if args.train_data is not None:
        if args.deepspeed:
            from packaging import version
            print_rank0("DeepSpeed is enabled.", level='DEBUG')
            # check whether the config requests a SAT-provided optimizer
            optimizer_name = args.deepspeed_config.get('optimizer', {}).get('type', '')
            if optimizer_name.startswith('sat.'):
                from functools import partial
                from importlib import import_module
                # split the dotted path into module and class name, then import
                module_path, _, class_name = optimizer_name.rpartition('.')
                optimizer_callable = getattr(import_module(module_path), class_name)
                optimizer_callable = partial(optimizer_callable, **args.deepspeed_config.get('optimizer', {}).get('params', {}))
                print_rank0(f'Using optimizer {optimizer_name} from sat.')
                del args.deepspeed_config['optimizer']
            else:
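                # no 'sat.' prefix: let DeepSpeed build the optimizer from its own config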
                optimizer_callable = None

            model, optimizer, _, _ = deepspeed.initialize(
                model=model,
                model_parameters=param_groups,
                optimizer=optimizer_callable,
                args=args,
                mpu=mpu,
                dist_init_required=False,
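                # DeepSpeed >= 0.9.0 reads the config from args.deepspeed_config,
                # so only older releases need the parsed dict passed explicitly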
                config_params=args.deepspeed_config
                    if version.parse(deepspeed.version) < version.parse("0.9.0")
                    else None
            )
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
    else:
        optimizer = None

    return model, optimizer
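
For reference, a minimal sketch of the disable_untrainable_params hook this function looks for. The model class and the choice of frozen layers here are hypothetical, not part of SAT:

import torch.nn as nn

class MyFrozenBackboneModel(nn.Module):
    """Hypothetical model: only the head stays trainable."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(1024, 1024)
        self.head = nn.Linear(1024, 10)

    def disable_untrainable_params(self):
        # Called by setup_model_untrainable_params_and_optimizer (if present)
        # before the optimizer param groups are built.
        for p in self.backbone.parameters():
            p.requires_grad_(False)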
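
And a sketch of the DeepSpeed config fragment that takes the 'sat.'-prefixed optimizer branch above; 'sat.ops.AdamX' is an illustrative dotted path, not a real SAT optimizer:

# Hypothetical config fragment; 'sat.ops.AdamX' is illustrative only.
deepspeed_config = {
    "optimizer": {
        "type": "sat.ops.AdamX",                      # any 'sat.'-prefixed dotted path
        "params": {"lr": 1e-4, "weight_decay": 0.01}  # forwarded via functools.partial
    },
    "train_micro_batch_size_per_gpu": 4,
}

The resolved class, curried with its params, is passed as optimizer= to deepspeed.initialize, and the 'optimizer' entry is deleted from the config so DeepSpeed does not construct a second optimizer on top of it.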