in modules/SwissArmyTransformer/sat/model/base_model.py
@classmethod
def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
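    """Build a model and load pretrained weights (helpers such as
    initialize_model_parallel and mp_split_model_* are imported elsewhere
    in this module).

    If overwrite_args['model_parallel_size'] differs from the checkpoint's
    model_parallel_size, the weights are re-partitioned on the fly: a
    single-partition checkpoint is split across model-parallel ranks, or a
    model-parallel checkpoint is merged back into a single partition.
    Note that overwrite_args may be mutated ('model_parallel_size' is popped).
    """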
    if build_only or 'model_parallel_size' not in overwrite_args:
        # Fast path: no repartitioning requested, defer entirely to the base loader.
        return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
    else:
        new_model_parallel_size = overwrite_args['model_parallel_size']
        if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
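            # Split path: every rank first builds an empty shard of the target
            # model; the full weights are scattered into the shards below.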
            model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
            local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
            world_size = torch.distributed.get_world_size()
            assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
            destroy_model_parallel()
            initialize_model_parallel(1)
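            # Only the first rank of each group loads the full checkpoint, on
            # CPU and with initialization skipped, to bound memory usage.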
            if local_rank == 0:
                args.skip_init = True
                args.use_gpu_initialization = False
                args.device = 'cpu'
                overwrite_args.pop('model_parallel_size')
                model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                if args_.model_parallel_size != 1:
                    raise Exception("We do not support overwriting model_parallel_size when original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx,overwrite_args={'model_parallel_size':1})` first if you still want to change model_parallel_size!")
            if hasattr(args, 'mode') and args.mode == 'inference':  # For multi-node inference, prevent rank 0 from eagerly printing some info.
                torch.distributed.barrier()
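            # Re-create the process groups with the requested parallel size,
            # then scatter the full weights from rank 0 into the shards.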
            destroy_model_parallel()
            initialize_model_parallel(new_model_parallel_size)
            if local_rank == 0:
                mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                del model_full
            else:
                mp_split_model_receive(model, use_node_group=use_node_group)
            reset_random_seed(6)
        else:
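            # Merge path (new_model_parallel_size == 1 but the checkpoint is
            # model-parallel): each rank loads its existing shard, then the
            # shards are gathered into one full model on rank 0.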
            overwrite_args.pop('model_parallel_size')
            model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
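            # Temporarily switch to model_parallel_size 1 so that rank 0 can
            # build an empty container for the merged full model.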
            destroy_model_parallel()
            initialize_model_parallel(1)
            if rank == 0:
                args.use_gpu_initialization = False
                args.device = 'cpu'
                overwrite_args['model_parallel_size'] = 1
                model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
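            # Wait until rank 0 has built the merge target, then restore the
            # original parallel layout for the gather.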
            torch.distributed.barrier()
            destroy_model_parallel()
            initialize_model_parallel(model_args.model_parallel_size)
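            # Gather the shards: non-zero ranks send, rank 0 merges and adopts
            # the full model together with its size-1 args.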
            if rank == 0:
                mp_merge_model_rank0(model, model_full)
                model, model_args = model_full, args_
            else:
                mp_merge_model_send(model)
                model_args.model_parallel_size = 1
            destroy_model_parallel()
            initialize_model_parallel(1)
        return model, model_args
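
# Usage sketch (hypothetical checkpoint names; assumes a torch.distributed job
# is already launched, e.g. via `torchrun`):
#
#   from sat import AutoModel
#
#   # split a single-partition checkpoint across 2 model-parallel ranks
#   model, args = AutoModel.from_pretrained(
#       'some-checkpoint', args, overwrite_args={'model_parallel_size': 2})
#
#   # merge a model-parallel checkpoint back into a single partition
#   model, args = AutoModel.from_pretrained(
#       'some-mp-checkpoint', args, overwrite_args={'model_parallel_size': 1})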