in jamba1.5-retriever/scripts/train.py [0:0]
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """Compute the contrastive loss for a batch of sentence pairs.

    Each side of the pair is encoded separately by ``model``. The
    ``last_hidden_state`` of each output is pooled with ``mean_pooling`` using
    that side's attention mask. The two pooled embeddings are then scored
    against ``inputs["labels"]`` by ``contrastive_loss``.

    Args:
        model: Encoder called as ``model(input_ids=..., attention_mask=...)``;
            its output must expose a ``last_hidden_state`` attribute.
        inputs: Batch mapping with keys ``"input_ids1"``, ``"attention_mask1"``,
            ``"input_ids2"``, ``"attention_mask2"`` and ``"labels"``.
        return_outputs: When true, also return the pooled embeddings.
        num_items_in_batch: Accepted because newer ``transformers.Trainer``
            versions pass it as a keyword argument to ``compute_loss``; it is
            not used by this loss.

    Returns:
        The loss, or ``(loss, (embeddings1, embeddings2))`` when
        ``return_outputs`` is true.
    """
    labels = inputs["labels"]

    # Each side of the pair has its own input ids and attention mask, so each
    # is run through the model in its own forward pass.
    outputs1 = model(
        input_ids=inputs["input_ids1"],
        attention_mask=inputs["attention_mask1"],
    )
    outputs2 = model(
        input_ids=inputs["input_ids2"],
        attention_mask=inputs["attention_mask2"],
    )

    # Pool token states into one vector per sequence; the mask is passed so
    # the pooling helper can ignore padding (behaviour defined by mean_pooling).
    embeddings1 = mean_pooling(outputs1.last_hidden_state, inputs["attention_mask1"])
    embeddings2 = mean_pooling(outputs2.last_hidden_state, inputs["attention_mask2"])

    loss = contrastive_loss(embeddings1, embeddings2, labels)

    if return_outputs:
        # NOTE(review): for non-dict outputs, Trainer.prediction_step appears to
        # treat outputs[1:] as the logits. That would drop embeddings1 during
        # evaluation; confirm whether evaluation relies on these embeddings.
        return loss, (embeddings1, embeddings2)
    return loss