in src/nli/inspection_tools.py [0:0]
def get_model_prediction(input_ids, attention_mask, token_type_ids, model, model_class_item, with_gradient=False):
    """Run one forward pass of ``model`` and return the first element of its outputs.

    The model is switched to eval mode first (``model.eval()``), so calling this
    function has that side effect on ``model``.

    Args:
        input_ids: Passed positionally as the model's first argument.
        attention_mask: Passed as ``attention_mask=``.
        token_type_ids: Passed as ``token_type_ids=`` unless
            ``model_class_item['model_class_name']`` is ``"distilbert"`` or
            ``"bart-large"``, in which case it is ignored.
        model: Callable model exposing ``eval()``; always called with
            ``labels=None``.
        model_class_item: Mapping whose ``'model_class_name'`` entry selects
            which keyword arguments are passed.
        with_gradient: If False (default), the forward pass runs under
            ``torch.no_grad()``. If True, no grad context is applied, so the
            caller's current grad mode is used unchanged.

    Returns:
        ``outputs[0]`` of the model call. This is presumably the logits for a
        Hugging Face classification head called with ``labels=None``; confirm
        against the models used by callers.
    """
    model.eval()

    # Build the call arguments once so the gradient / no-gradient paths cannot
    # drift apart (they were previously duplicated verbatim).
    model_kwargs = {"attention_mask": attention_mask}
    # NOTE(review): these classes are called without token_type_ids, presumably
    # because their forward() does not accept that argument.
    if model_class_item['model_class_name'] not in ("distilbert", "bart-large"):
        model_kwargs["token_type_ids"] = token_type_ids
    model_kwargs["labels"] = None

    if with_gradient:
        # No context manager: inherit whatever grad mode the caller has set.
        outputs = model(input_ids, **model_kwargs)
    else:
        with torch.no_grad():
            outputs = model(input_ids, **model_kwargs)
    return outputs[0]