in modules/SwissArmyTransformer/sat/training/deepspeed_training.py [0:0]
def setup_model_untrainable_params_and_optimizer(args, model, config_params=None):
    """Setup model and optimizer."""
    if hasattr(model, 'disable_untrainable_params'):
        model.disable_untrainable_params()  # mark trainable params
    param_groups = get_optimizer_param_groups(model)
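    # Note: `disable_untrainable_params` is a model-provided hook, typically implemented
    # by setting `requires_grad = False` on frozen parameters; `get_optimizer_param_groups`
    # then only collects the parameters that remain trainable.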
    # Sync initialized parameters across ranks.
    # ZeRO-3 does not need this sync (parameters are partitioned).
    from sat.helpers import check_if_zero3
    if not check_if_zero3(args):
        print_rank0('Syncing initialized parameters...')
        for param_group in param_groups:
            for param in param_group['params']:
                if not param.model_parallel:
                    # We already keep the same random seed on different ranks, but that
                    # is not fully reliable: non-model-parallel parameters can still
                    # differ at initialization, so broadcast them from rank 0.
                    dist.broadcast(
                        param.data,
                        src=0,  # group is the default (world) group
                    )
                else:
                    dist.broadcast(
                        param.data,
                        src=mpu.get_model_parallel_rank(),   # 0 .. mp_size-1
                        group=mpu.get_data_parallel_group()  # ranks with the same model-parallel rank, e.g. 1, mp_size+1, ...
                    )
        print_rank0('Finished syncing initialized parameters.')
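    # Worked example (assuming Megatron-style process groups): with world_size=4 and
    # mp_size=2, the data-parallel groups are {0, 2} and {1, 3}. The lowest global rank
    # in each data-parallel group equals the model-parallel rank (0 or 1), which is why
    # src=mpu.get_model_parallel_rank() is a valid broadcast source for model-parallel
    # parameters.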
    if args.train_data is not None:
        if args.deepspeed:
            from packaging import version
            print_rank0("DeepSpeed is enabled.", level='DEBUG')
            # Check whether the optimizer is provided by sat instead of DeepSpeed.
            optimizer_name = args.deepspeed_config.get('optimizer', {}).get('type', '')
            if optimizer_name.startswith('sat.'):
                from functools import partial
                from importlib import import_module
                # Split "package.module.ClassName" and import the optimizer class.
                module_name, class_name = optimizer_name.rsplit('.', maxsplit=1)
                optimizer_callable = getattr(import_module(module_name), class_name)
                optimizer_callable = partial(optimizer_callable, **args.deepspeed_config.get('optimizer', {}).get('params', {}))
                print_rank0(f'Using optimizer {optimizer_name} from sat.')
                # Remove the entry so it is not passed to DeepSpeed as well.
                del args.deepspeed_config['optimizer']
            else:
                optimizer_callable = None
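            # Example (hypothetical optimizer name): a config entry such as
            #   "optimizer": {"type": "sat.ops.SomeFusedAdam", "params": {"lr": 1e-4}}
            # would be resolved above by importing `SomeFusedAdam` from `sat.ops` and
            # binding the given params via functools.partial, instead of letting
            # DeepSpeed look the optimizer up by name.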
            model, optimizer, _, _ = deepspeed.initialize(
                model=model,
                model_parameters=param_groups,
                optimizer=optimizer_callable,
                args=args,
                mpu=mpu,
                dist_init_required=False,
                config_params=args.deepspeed_config
                    if version.parse(deepspeed.version) < version.parse("0.9.0")
                    else None
            )
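            # Note: the config dict is only passed via `config_params` for DeepSpeed
            # < 0.9.0; presumably newer versions read the configuration from
            # `args.deepspeed_config` directly (here already a dict), so passing it
            # again would conflict.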
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
    else:
        optimizer = None
    return model, optimizer
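
# Illustrative sketch (not part of the original module): how a model can opt into the
# `disable_untrainable_params` hook used above. The idea is to freeze the parameters
# that should not be trained so that they are left out of the optimizer param groups.
# The class name, layer sizes, and attribute names below are made up for the example.
import torch.nn as nn

class ExampleFrozenBackboneModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(1024, 1024)  # pretrained part, to be kept frozen
        self.head = nn.Linear(1024, 10)        # newly added part, to be trained

    def disable_untrainable_params(self):
        # Freeze the backbone; only `self.head` keeps requires_grad=True and thus
        # ends up in the optimizer param groups.
        for p in self.backbone.parameters():
            p.requires_grad_(False)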