def get_model()

in modules/SwissArmyTransformer/sat/model/base_model.py [0:0]


def get_model(args, model_cls, **kwargs):
    """Build the model."""
    import torch
    from sat.helpers import print_rank0, print_all
    from sat import mpu

    print_rank0(f'building {model_cls.__name__} model ...')
    # Infer the parameter dtype from the fp16/bf16 flags unless the caller
    # passed params_dtype explicitly.
    if 'params_dtype' not in kwargs:
        if hasattr(args, 'fp16') and args.fp16:
            params_dtype = torch.half
        elif hasattr(args, 'bf16') and args.bf16:
            params_dtype = torch.bfloat16
        else:
            params_dtype = torch.float32
    else:
        # Use the explicitly provided dtype and remove it from kwargs
        # so it is not passed to model_cls twice.
        params_dtype = kwargs.pop('params_dtype')

    from sat.helpers import check_if_zero3
    if check_if_zero3(args):
        # Under DeepSpeed ZeRO-3, build the model inside zero.Init() so that
        # parameters are partitioned across ranks at construction time.
        import deepspeed
        with deepspeed.zero.Init():
            model = model_cls(args, params_dtype=params_dtype, **kwargs)
    else:
        model = model_cls(args, params_dtype=params_dtype, **kwargs)

    # Report the parameter count once per model-parallel rank
    # (only on data-parallel rank 0, to avoid duplicate logs).
    if mpu.get_data_parallel_rank() == 0:
        print_all(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())), flush=True)
    
    # Cast parameters to the requested half-precision dtype.
    if hasattr(args, 'fp16') and args.fp16:
        model.half()
    elif hasattr(args, 'bf16') and args.bf16:
        model.bfloat16()

    try: # TODO: is this useful?
        # Move the model to args.device, defaulting to the current CUDA device
        # when one is available and to the CPU otherwise.
        if not hasattr(args, 'device'):
            args.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        model = model.to(args.device)
    except Exception as e:
        print_all(e)
    
    return model
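

Usage sketch (not from the source file): the snippet below shows roughly how get_model can be called. It assumes sat's distributed and model-parallel state (mpu) has already been initialized by the usual training entry point, and ToyModel with its hidden_size hyperparameter is a hypothetical stand-in for a real sat model class. The args namespace here carries only the flags this function reads; a real call would also include the training/DeepSpeed options produced by sat's argument parser.

import argparse
import torch

from sat.model.base_model import get_model


class ToyModel(torch.nn.Module):
    """Hypothetical model class with the (args, params_dtype, **kwargs)
    signature that get_model expects."""
    def __init__(self, args, params_dtype=torch.float32, **kwargs):
        super().__init__()
        self.linear = torch.nn.Linear(args.hidden_size, args.hidden_size,
                                      dtype=params_dtype)


args = argparse.Namespace(
    hidden_size=1024,
    fp16=True,   # get_model picks torch.half and later calls model.half()
    bf16=False,
)
model = get_model(args, ToyModel)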