def compute_loss()

in jamba1.5-retriever/scripts/train.py

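A Trainer-style compute_loss override: it encodes the two tokenized sentences with the same model, mean-pools the token embeddings of each side, and scores the pair with a contrastive loss. The print calls are debug tracing for each step.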

    def compute_loss(self, model, inputs, return_outputs=False):
        """Encode both sentences, mean-pool their token embeddings, and return the contrastive loss."""

        print("Inside compute_loss")
        print("inputs: ", inputs)
        print("return_outputs: ", return_outputs)

        # Extract labels
        labels = inputs["labels"]

        # Extract input_ids and attention_mask for sentence1 and sentence2
        sentence1_input_ids, sentence2_input_ids = inputs['input_ids1'], inputs['input_ids2']
        sentence1_attention_mask, sentence2_attention_mask = inputs['attention_mask1'], inputs['attention_mask2']

        # Pass sentence1 and sentence2 through the model to get embeddings
        outputs1 = model(input_ids=sentence1_input_ids, attention_mask=sentence1_attention_mask)
        print("Outputs1: ", outputs1)

        outputs2 = model(input_ids=sentence2_input_ids, attention_mask=sentence2_attention_mask)
        print("Outputs2: ", outputs2)

        # Pool embeddings
        print("Before mean_pooling call for Embeddings1")
        embeddings1 = mean_pooling(outputs1.last_hidden_state, sentence1_attention_mask)
        print("embeddings1: ", embeddings1)

        print("Before mean_pooling call for Embeddings2") 
        embeddings2 = mean_pooling(outputs2.last_hidden_state, sentence2_attention_mask)
        print("embeddings2: ", embeddings2)

        print("Before constrastive_loss call")
        # Calculate contrastive loss
        loss = contrastive_loss(embeddings1, embeddings2, labels)
        print("After constrastive_loss call")
        print("Return for compute_loss: ", (loss, (embeddings1, embeddings2)))

        if return_outputs:
            print("Return loss and embeddings")
            return (loss, (embeddings1, embeddings2))
        else:
            print("Return only loss")
            return loss
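
The mean_pooling helper called above is not shown in this excerpt. Below is a minimal sketch of the standard masked mean pooling it presumably performs over last_hidden_state (an assumption; the actual helper in train.py may differ):

    import torch

    def mean_pooling(last_hidden_state, attention_mask):
        # Expand the attention mask to the hidden dimension: (batch, seq_len, hidden)
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        # Sum embeddings over non-padding tokens and divide by the token count
        summed = torch.sum(last_hidden_state * mask, dim=1)
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / counts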
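
Likewise, contrastive_loss is only referenced here. The sketch below assumes the classic margin-based formulation over paired embeddings with binary labels (1 = similar pair, 0 = dissimilar pair); the real function and its margin value may differ:

    import torch
    import torch.nn.functional as F

    def contrastive_loss(embeddings1, embeddings2, labels, margin=0.5):
        # Cosine distance between the two pooled sentence embeddings
        distance = 1.0 - F.cosine_similarity(embeddings1, embeddings2)
        labels = labels.float()
        # Pull similar pairs together; push dissimilar pairs beyond the margin
        loss = labels * distance.pow(2) + (1.0 - labels) * torch.clamp(margin - distance, min=0.0).pow(2)
        return loss.mean()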
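
compute_loss expects batches carrying paired fields (input_ids1, attention_mask1, input_ids2, attention_mask2, labels). A hypothetical collator that would produce that layout is sketched below; the dataset column names sentence1, sentence2, and label are assumptions, and the actual batching in scripts/train.py may be built differently:

    import torch
    from dataclasses import dataclass
    from transformers import PreTrainedTokenizerBase

    @dataclass
    class PairCollator:
        tokenizer: PreTrainedTokenizerBase
        max_length: int = 512

        def __call__(self, batch):
            # Tokenize each side of the pair separately
            enc1 = self.tokenizer([ex["sentence1"] for ex in batch], padding=True,
                                  truncation=True, max_length=self.max_length, return_tensors="pt")
            enc2 = self.tokenizer([ex["sentence2"] for ex in batch], padding=True,
                                  truncation=True, max_length=self.max_length, return_tensors="pt")
            return {
                "input_ids1": enc1["input_ids"],
                "attention_mask1": enc1["attention_mask"],
                "input_ids2": enc2["input_ids"],
                "attention_mask2": enc2["attention_mask"],
                "labels": torch.tensor([ex["label"] for ex in batch], dtype=torch.float),
            }

An instance such as PairCollator(tokenizer) would then be passed to the Trainer as data_collator so each batch reaches compute_loss in this paired form.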