in jamba1.5-retriever/scripts/train.py [0:0]
def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
    """Run one evaluation step and return ``(loss, embeddings, labels)``.

    Instead of model logits, the "predictions" slot carries the pair of
    embeddings returned by ``self.compute_loss(..., return_outputs=True)``.

    Args:
        model: Model forwarded unchanged to ``self.compute_loss``.
        inputs: Batch mapping. ``inputs["labels"]`` is returned as the labels
            when the key is present; otherwise labels are ``None``.
        prediction_loss_only: When True, only the loss is returned and the
            embeddings/labels slots are ``None``.
        ignore_keys: Unused here.

    Returns:
        Always a 3-tuple: ``(loss, None, None)`` when ``prediction_loss_only``
        is True, otherwise ``(loss, (embeddings1, embeddings2), labels)``.
    """
    import logging

    logger = logging.getLogger(__name__)
    # Lazy %-style args: nothing (in particular no tensor repr, which would
    # force a device->host copy) is formatted unless DEBUG is enabled.
    logger.debug(
        "prediction_step: input keys=%s, prediction_loss_only=%s, ignore_keys=%s",
        list(inputs.keys()),
        prediction_loss_only,
        ignore_keys,
    )

    # Capture labels before compute_loss runs, in case it mutates `inputs`;
    # .get() keeps label-less batches from raising KeyError.
    labels = inputs.get("labels")

    # Evaluation only: no autograd graph is needed for the loss or embeddings.
    with torch.no_grad():
        loss, (embeddings1, embeddings2) = self.compute_loss(model, inputs, return_outputs=True)

    if isinstance(loss, torch.Tensor):
        # Mirror the default Trainer.prediction_step: reduce a per-device loss
        # vector to a scalar and detach it before it is gathered.
        loss = loss.mean().detach()

    logger.debug(
        "prediction_step: loss=%s, embeddings1 shape=%s, embeddings2 shape=%s",
        loss,
        getattr(embeddings1, "shape", None),
        getattr(embeddings2, "shape", None),
    )

    # The Trainer evaluation loop unpacks three values from this method, so
    # the loss-only path must still return a 3-tuple (not a bare loss).
    if prediction_loss_only:
        return (loss, None, None)
    return (loss, (embeddings1, embeddings2), labels)