modules/SwissArmyTransformer/sat/model/base_model.py
def get_model(args, model_cls, **kwargs):
    """Build the model."""
    import torch
    from sat.helpers import print_rank0, print_all
    from sat import mpu
    print_rank0(f'building {model_cls.__name__} model ...')

    # Resolve the parameter dtype from args unless the caller passed one explicitly.
    if 'params_dtype' not in kwargs:
        if hasattr(args, 'fp16') and args.fp16:
            params_dtype = torch.half
        elif hasattr(args, 'bf16') and args.bf16:
            params_dtype = torch.bfloat16
        else:
            params_dtype = torch.float32
    else:
        # pop params_dtype from kwargs
        params_dtype = kwargs.pop('params_dtype')

    from sat.helpers import check_if_zero3
    if check_if_zero3(args):
        # Under ZeRO stage 3, construct the model inside deepspeed.zero.Init() so that
        # parameters are partitioned across ranks as they are created.
        import deepspeed
        with deepspeed.zero.Init():
            model = model_cls(args, params_dtype=params_dtype, **kwargs)
    else:
        model = model_cls(args, params_dtype=params_dtype, **kwargs)

    # Report the parameter count once per data-parallel group.
    if mpu.get_data_parallel_rank() == 0:
        print_all(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # Cast weights to the requested low-precision dtype.
    if hasattr(args, 'fp16') and args.fp16:
        model.half()
    elif hasattr(args, 'bf16') and args.bf16:
        model.bfloat16()

    # Move the model to the target device; print the error and continue if this fails.
    try:  # TODO: is this useful?
        if not hasattr(args, 'device'):
            args.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        model = model.to(args.device)
    except Exception as e:
        print_all(e)
    return model
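
For reference, a hedged usage sketch (not part of the source file): it assumes torch.distributed and SAT's model-parallel groups have already been initialized by SAT's training entry points, and that `args` carries whatever hyperparameters the chosen model class expects; in practice the namespace usually comes from sat.get_args(), and the field values below are illustrative only.

# Illustrative only; MyModel and the args fields are assumptions, not taken from base_model.py.
import argparse
from sat.model.base_model import BaseModel, get_model

class MyModel(BaseModel):
    # Customize mixins / forward as needed; BaseModel's constructor consumes args.
    pass

args = argparse.Namespace(
    fp16=True, bf16=False,   # dtype flags that get_model inspects
    # ...plus the model hyperparameters your model class requires
    # (e.g. num_layers, hidden_size, num_attention_heads, vocab_size),
    # which vary with the SAT version and model class.
)

model = get_model(args, MyModel)   # builds MyModel with params_dtype=torch.half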