def load_model()

in 2-dl-container/Container-Root/job/bert/direct_benchmark-gpu.py [0:0]


def load_model(file_name, torchscript):
    """Load the benchmark model, switch it to inference mode and move it to the GPU.

    Args:
        file_name: Path to a TorchScript archive. Only used when
            ``torchscript`` is true.
        torchscript: If true, load ``file_name`` with ``torch.jit.load``;
            otherwise build the model with
            ``AutoModelForSequenceClassification.from_pretrained`` using the
            module-level ``model_name`` (``return_dict=False``).

    Returns:
        The model in eval mode on the current CUDA device.
    """
    # Loading is deliberately NOT wrapped in autocast: autocast only changes the
    # dtype of autocast-eligible CUDA ops executed inside its context, and
    # deserialising weights / ``.cuda()`` run no such ops, so it would have no
    # effect on the returned model. If mixed precision is wanted, the autocast
    # context belongs around the inference calls, not here.
    if torchscript:
        model = torch.jit.load(file_name)
    else:
        # NOTE(review): ``file_name`` is ignored on this path; the checkpoint
        # name comes from the module-level ``model_name``. Confirm this is
        # intended by the callers.
        model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=False)

    # Shared tail for both branches: disable dropout etc., then move to GPU.
    model.eval()
    model = model.cuda()

    return model