def prediction_step()

in jamba1.5-retriever/scripts/train.py [0:0]


    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Run one evaluation step and return a ``(loss, outputs, labels)`` triple.

        NOTE(review): this presumably overrides
        ``transformers.Trainer.prediction_step``. The Hugging Face evaluation
        loop unpacks the return value into exactly three items, so a 3-tuple
        is returned on every path, with ``None`` in the slots that were not
        requested.

        Args:
            model: Model passed unchanged to ``self.compute_loss``.
            inputs: Mapping-like batch. Its ``"labels"`` entry, if present, is
                returned as the third element (``None`` when absent).
            prediction_loss_only: If true, only the loss is returned. The
                outputs and labels slots are then ``None``.
            ignore_keys: Accepted for interface compatibility; unused here.

        Returns:
            ``(loss, None, None)`` when ``prediction_loss_only`` is true.
            Otherwise ``(loss, (embeddings1, embeddings2), labels)``, where the
            embedding pair is the second item returned by
            ``self.compute_loss(model, inputs, return_outputs=True)``.
        """
        import logging

        logger = logging.getLogger(__name__)
        logger.debug(
            "prediction_step: prediction_loss_only=%s, ignore_keys=%s",
            prediction_loss_only,
            ignore_keys,
        )

        # Read labels before calling compute_loss, in case it consumes or pops
        # them from `inputs` (not visible from here). `.get` tolerates
        # label-free batches instead of raising KeyError.
        labels = inputs.get("labels")

        # Evaluation only: no autograd graph is needed for the loss or embeddings.
        with torch.no_grad():
            loss, (embeddings1, embeddings2) = self.compute_loss(
                model, inputs, return_outputs=True
            )
        # Lazy %-formatting: the tensor is only stringified when DEBUG is enabled.
        logger.debug("prediction_step: loss=%s", loss)

        if prediction_loss_only:
            # Callers unpack three values, so keep the tuple shape.
            return (loss, None, None)
        return (loss, (embeddings1, embeddings2), labels)