def forward()

in models/vision_language_model.py


    def forward(self, input_ids, images, attention_mask=None, targets=None):
        if isinstance(images, list) and isinstance(images[0], list):
            # images arrived as a list of per-sample lists: flatten it, stack into a
            # single batch tensor, and move it to the same device as input_ids.
            images = [img for sublist in images for img in sublist]
            images = torch.stack(images).to(input_ids.device)
        image_embd = self.vision_encoder(images) # [num_images, num_patches, D_vision]
        image_embd = self.MP(image_embd)         # [num_images, mp_image_token_length, D_lm]

        token_embd = self.decoder.token_embedding(input_ids) # [B, T_sequence, D_lm]

        updated_token_embd = self._replace_img_tokens_with_embd(input_ids, token_embd, image_embd)

        # updated_token_embd is token_embd with the image-placeholder positions overwritten
        # by the projected image embeddings.
        # The attention_mask comes from the collator and already covers the full sequence.
        logits, _ = self.decoder(updated_token_embd, attention_mask=attention_mask)

        loss = None
        if targets is not None:
            logits = self.decoder.head(logits) # Apply the LM head (only needed when computing the loss)
            # Loss is computed over every position, but `targets` (labels) carries -100 for
            # non-answer tokens, so cross_entropy ignores them. No need to slice logits
            # around the image embeddings here; the target mask already handles it.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

        return logits, loss
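
The only call not shown above is _replace_img_tokens_with_embd. Below is a minimal sketch of what such a helper typically does, assuming a known image_token_id and that the number of placeholder positions in input_ids equals num_images * mp_image_token_length; these are assumptions for illustration, not the repo's exact implementation.

    import torch

    def replace_img_tokens_with_embd_sketch(input_ids, token_embd, image_embd, image_token_id):
        # Sketch only: scatter the projected image embeddings into the positions that
        # hold the image-placeholder token. Assumes the total number of placeholder
        # positions equals image_embd.shape[0] * image_embd.shape[1].
        updated = token_embd.clone()
        mask = (input_ids == image_token_id)  # [B, T] boolean mask of placeholder positions
        updated[mask] = image_embd.reshape(-1, image_embd.size(-1)).to(updated.dtype)
        return updated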
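
A hypothetical call pattern for this forward pass, with made-up shapes and ids: model stands for an already constructed VisionLanguageModel instance, and image_token_id / mp_image_token_length stand for values that would come from its tokenizer and config.

    import torch

    B, T = 2, 128
    num_images, mp_image_token_length = 2, 64   # one image per sample, 64 placeholder tokens each (assumed)
    vocab_size, image_token_id = 32000, 7       # hypothetical values

    input_ids = torch.randint(0, vocab_size, (B, T))
    input_ids[:, :mp_image_token_length] = image_token_id  # placeholder span the image embeddings will replace
    attention_mask = torch.ones(B, T, dtype=torch.long)    # covers the full sequence, as produced by the collator
    images = torch.randn(num_images, 3, 224, 224)          # already-stacked batch of images
    targets = input_ids.clone()
    targets[:, :mp_image_token_length] = -100              # mask image/prompt positions out of the loss

    logits, loss = model(input_ids, images, attention_mask=attention_mask, targets=targets)
    loss.backward()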