In models/vision_language_model.py:
def forward(self, input_ids, images, attention_mask=None, targets=None):
    if isinstance(images, list) and isinstance(images[0], list): # If images is a list of lists, flatten it
        images = [img for sublist in images for img in sublist]
        images = torch.stack(images).to(input_ids.device)

    image_embd = self.vision_encoder(images)
    image_embd = self.MP(image_embd) # [num_images, mp_image_token_length, D_lm]

    token_embd = self.decoder.token_embedding(input_ids) # [B, T_sequence, D_lm]

    updated_token_embd = self._replace_img_tokens_with_embd(input_ids, token_embd, image_embd)

    # The updated_token_embd is now the token_embd with image parts replaced.
    # The attention_mask comes from the collator and should already cover the full sequence.
    logits, _ = self.decoder(updated_token_embd, attention_mask=attention_mask)

    loss = None
    if targets is not None:
        logits = self.decoder.head(logits) # Apply LM head
        # Loss is calculated over all tokens, but `targets` (labels) will have -100 for non-answer tokens.
        # No need to slice logits based on image embedding size here, as the target mask handles it.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

    return logits, loss
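
For context, a minimal sketch of what `_replace_img_tokens_with_embd` could look like. The placeholder id name (`self.cfg.image_token_id`) and the assumption that the prompt reserves exactly `num_images * mp_image_token_length` placeholder positions are illustrative, not taken from the repository:

def _replace_img_tokens_with_embd(self, input_ids, token_embd, image_embd):
    # Positions in the prompt that hold the image placeholder token. [B, T_sequence]
    mask = (input_ids == self.cfg.image_token_id)  # attribute name is an assumption
    updated = token_embd.clone()
    # Flatten the projected image embeddings to [num_images * mp_image_token_length, D_lm]
    # so they fill the masked positions in reading order. Assumes the collator reserved
    # exactly that many placeholder tokens.
    updated[mask] = image_embd.reshape(-1, image_embd.size(-1)).to(dtype=updated.dtype)
    return updated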
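
The `ignore_index=-100` convention is what restricts the loss to answer tokens: any position labeled -100 is skipped entirely. A self-contained illustration (shapes and values are made up):

import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 4, vocab_size)        # [B, T, V]
targets = torch.tensor([[-100, -100, 3, 5]])  # prompt/image positions masked with -100
# Only the last two positions contribute to the loss; the -100 entries are ignored.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), ignore_index=-100)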
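
A hypothetical end-to-end call, assuming a constructed `model`, a modality projector that emits 16 tokens per image, a 224x224 vision encoder input, and `IMAGE_TOKEN_ID` standing in for the real placeholder id (all of these are assumptions, not values from the repository):

B, T, n_img_tokens = 2, 32, 16
IMAGE_TOKEN_ID = 999                              # placeholder id, assumption
ids = torch.randint(0, 999, (B, T))
ids[:, 1:1 + n_img_tokens] = IMAGE_TOKEN_ID       # reserve image slots after the first token
images = torch.randn(B, 3, 224, 224)              # one image per sample, tensor (not a list)
attention_mask = torch.ones(B, T, dtype=torch.long)
labels = ids.clone()
labels[:, :1 + n_img_tokens] = -100               # no loss on the prompt's image positions
logits, loss = model(ids, images, attention_mask=attention_mask, targets=labels)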