def get_model_prediction()

in src/nli/inspection_tools.py [0:0]


def get_model_prediction(input_ids, attention_mask, token_type_ids, model, model_class_item, with_gradient=False):
    """Run a forward pass of ``model`` in eval mode and return its first output.

    Args:
        input_ids: Passed positionally as the first argument to ``model``.
        attention_mask: Forwarded as the ``attention_mask`` keyword.
        token_type_ids: Forwarded as the ``token_type_ids`` keyword, except
            for the model classes listed below, which are called without it.
        model: Callable model exposing ``eval()``. Note that this function
            switches it to eval mode and does not restore the previous mode.
        model_class_item: Mapping that must contain the key
            ``'model_class_name'``; it selects which keywords are sent.
        with_gradient: When False (default) the forward pass runs inside
            ``torch.no_grad()``. When True the ambient autograd mode is left
            untouched (it is *not* forcibly enabled).

    Returns:
        ``outputs[0]`` of the model call (presumably the logits, since
        ``labels=None`` is passed -- confirm against the model class in use).

    Raises:
        KeyError: If ``model_class_item`` has no ``'model_class_name'`` key.
    """
    model.eval()

    # Build the keyword arguments once so the grad / no-grad paths cannot
    # drift apart. Insertion order mirrors the original call sites.
    forward_kwargs = {'attention_mask': attention_mask}
    if model_class_item['model_class_name'] not in ("distilbert", "bart-large"):
        # These two model classes are invoked without token_type_ids.
        forward_kwargs['token_type_ids'] = token_type_ids
    forward_kwargs['labels'] = None

    if with_gradient:
        # Deliberately no torch.enable_grad(): respect the caller's context.
        outputs = model(input_ids, **forward_kwargs)
    else:
        with torch.no_grad():
            outputs = model(input_ids, **forward_kwargs)

    return outputs[0]