backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py [326:540]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        batch.image_inputs = []
        batch.image_positions = []
        batch.encoder_cache = []
        for b in batches:
            if b.image_inputs is not None:
                batch.image_inputs.extend(b.image_inputs)
            else:
                batch.image_inputs.append(None)
            if b.image_positions is not None:
                batch.image_positions.extend(b.image_positions)
            else:
                batch.image_positions.append(None)
            if b.encoder_cache is not None:
                batch.encoder_cache.extend(b.encoder_cache)
            else:
                batch.encoder_cache.append(None)

        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        batch.inputs_embeds = None
        # To be filled in prepare_for_prefill
        batch.has_image_inputs = False
        batch.cache_entries_to_free = []
        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Return a batch restricted to ``request_ids``.

        The per-request multimodal state (image inputs, image positions and
        encoder cache entries) is gathered via ``requests_idx_mapping`` before
        the parent class filters the rest of the batch. Batch-level image
        tensors are reset to ``None`` so they are rebuilt for the new batch.

        Raises:
            ValueError: If ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")

        # One (inputs, positions, cache) triple per kept request, in request
        # order. The inner generator keeps the mapping lookup and the list
        # indexing interleaved per request.
        kept = [
            (self.image_inputs[j], self.image_positions[j], self.encoder_cache[j])
            for j in (self.requests_idx_mapping[rid] for rid in request_ids)
        ]
        # `kept` is non-empty (checked above), so zip yields exactly 3 columns.
        image_inputs, image_positions, encoder_cache = (
            list(column) for column in zip(*kept)
        )

        batch = super().filter(request_ids)

        # Batch-level tensors no longer match the filtered requests.
        batch.pixel_values = batch.pixel_attention_mask = None
        batch.image_sizes = batch.image_grid_thw = batch.inputs_embeds = None

        batch.image_inputs = image_inputs
        batch.image_positions = image_positions
        batch.encoder_cache = encoder_cache

        # To be filled in prepare_for_prefill
        batch.has_image_inputs = False
        batch.cache_entries_to_free = []
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        """Tokenize a batch of multimodal requests and locate their images.

        For each request, the input chunks are joined in order:

        - Text chunks go through ``preprocess_text``.
        - Each image chunk is decoded and passed through ``preprocess_image``
          and ``processor.image_processor``. It is then replaced in the prompt
          by the placeholder text returned by ``image_text_replacement``.

        The joined prompt is fixed up with ``image_text_replacement_fixup`` and
        tokenized, truncated to ``r.truncate``. When the request contains
        images, their spans are located with ``get_image_positions``.

        Note: if ``config`` lacks ``image_token_index``, it is set from
        ``config.image_token_id``. The config object is mutated.

        Returns:
            ``(batch_tokenized_inputs, batch_image_inputs, batch_image_positions)``,
            each with one entry per request. The image entries are ``None`` for
            requests without images.

        Raises:
            RuntimeError: On an unknown input chunk type.
        """
        kwargs = {}
        if (
            hasattr(processor, "image_processor_class")
            and processor.image_processor_class == "Idefics3ImageProcessor"
        ):
            kwargs["return_row_col_info"] = True

        # Fetched lazily: get_vocab() may build a full token->id mapping on each
        # call, and it is only needed for requests that contain images.
        vocab = None

        if not hasattr(config, "image_token_index"):
            config.image_token_index = config.image_token_id

        batch_tokenized_inputs: List[List[int]] = []
        batch_image_inputs: List[Optional[List[dict]]] = []
        batch_image_positions: List[Optional[List[ImagePositions]]] = []

        for r in requests:
            text_parts = []
            image_inputs = []
            image_texts = []

            image_id = 0

            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    text = preprocess_text(config, chunk.text)
                    text_parts.append(text)
                elif chunk_type == "image":
                    img = Image.open(BytesIO(chunk.image.data))
                    img = preprocess_image(config, img)

                    image_input = processor.image_processor(
                        [img], return_tensors="pt", **kwargs
                    )
                    image_inputs.append(image_input)

                    img_text, img_start_token_str = image_text_replacement(
                        processor, image_input, config
                    )
                    text_parts.append(img_text)

                    # A mutable list on purpose: get_image_positions may rewrite
                    # the text of the following image in place (idefics2).
                    image_texts.append([image_id, img_start_token_str, img_text])
                    image_id += 1
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            full_text = image_text_replacement_fixup(config, "".join(text_parts))
            input_ids = tokenizer(
                full_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=(
                    r.add_special_tokens if config.model_type != "paligemma" else False
                ),
            )["input_ids"]

            if image_inputs:
                if vocab is None:
                    vocab = tokenizer.get_vocab()
                # All images of a request are searched with the first image's
                # start token.
                img_start_token = vocab[image_texts[0][1]]
                image_positions = cls.get_image_positions(
                    input_ids, image_texts, img_start_token, config, tokenizer
                )
            else:
                image_inputs = None
                image_positions = None

            batch_tokenized_inputs.append(input_ids)
            batch_image_inputs.append(image_inputs)
            batch_image_positions.append(image_positions)

        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions

    @classmethod
    def get_image_positions(
        cls,
        input_ids: List[int],
        image_texts: List[list],
        img_start_token: int,
        config,
        tokenizer: PreTrainedTokenizerBase,
    ) -> List[ImagePositions]:
        """Locate each image's placeholder span inside a tokenized prompt.

        Args:
            input_ids: Token ids of the full (possibly truncated) prompt.
            image_texts: One ``[image_id, img_start_token_str, img_text]`` entry
                per image, in prompt order. Entries must be mutable lists. For
                idefics2, the text of the next image may be rewritten in place.
            img_start_token: Token id that opens an image span.
            config: Model config. ``model_type`` and ``image_token_index`` are read.
            tokenizer: Tokenizer used to re-tokenize each image's text.

        Returns:
            One ``ImagePositions`` per image. Each has ``offset`` (a plain int
            index into ``input_ids``), ``length``, ``id`` and
            ``num_placeholder_tokens``. Its ``is_embed`` mask is ``None`` when
            every token in the span is ``config.image_token_index``.

        Raises:
            AssertionError: If an image span is missing from, or was cut off in,
                ``input_ids`` (e.g. because the prompt was truncated).
        """
        image_positions = []
        num_images = len(image_texts)

        input_ids_t = torch.as_tensor(input_ids)
        img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
        num_start_tokens = img_start_token_pos.numel()
        num_tokens = input_ids_t.numel()

        last_pos = 0
        for i in range(num_images):
            image_id, img_start_token_str, img_text = image_texts[i]
            img_text = image_text_replacement_fixup(config, img_text)

            if config.model_type == "gemma3":
                img_text = img_text.replace("\n\n", "")

            tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ][0]
            length = tokens.numel()

            assert (
                length <= num_tokens
            ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"

            # First start token at or after the end of the previous image.
            pos = int(torch.searchsorted(img_start_token_pos, last_pos, right=False))
            # Without this check, an image whose start token was truncated away
            # surfaces as an opaque IndexError on the indexing below.
            assert pos < num_start_tokens, (
                f"Start token of image {image_id} not found in input_ids; the image "
                "is truncated, try increasing --max-batch-prefill-tokens"
            )
            # Plain int so ImagePositions.offset / last_pos are not 0-d tensors.
            index = int(img_start_token_pos[pos])
            assert torch.equal(
                input_ids_t[index : index + length], tokens
            ), "Image tokens not found in input_ids"

            is_embed = tokens == config.image_token_index
            num_placeholder_tokens = int(is_embed.sum())
            if num_placeholder_tokens == length:
                # Every token is a placeholder, so no mask is needed.
                is_embed = None

            pos = ImagePositions(
                offset=index,
                length=length,
                id=image_id,
                num_placeholder_tokens=num_placeholder_tokens,
                is_embed=is_embed,
            )

            image_positions.append(pos)
            last_pos = index + length

            # idefics2: if the next image's tokens start right after this span
            # with an image token, the start token at last_pos - 1 is treated as
            # shared (presumably merged by image_text_replacement_fixup, confirm).
            # The next search is redirected to last_pos, and the next image's
            # text drops its leading start token. The bounds check avoids an
            # IndexError when the prompt was truncated exactly here; the next
            # iteration then reports the truncation.
            if (
                config.model_type == "idefics2"
                and i + 1 != num_images
                and last_pos < num_tokens
                and input_ids[last_pos] == config.image_token_index
            ):
                fake_token = last_pos - 1
                fake_token_index = torch.searchsorted(
                    img_start_token_pos, fake_token, right=False
                )
                img_start_token_pos[fake_token_index] = last_pos
                image_texts[i + 1][2] = image_texts[i + 1][2][
                    len(img_start_token_str) :
                ]

        return image_positions

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



server/text_generation_server/models/vlm_causal_lm.py [318:535]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        batch.image_inputs = []
        batch.image_positions = []
        batch.encoder_cache = []
        for b in batches:
            if b.image_inputs is not None:
                batch.image_inputs.extend(b.image_inputs)
            else:
                batch.image_inputs.append(None)
            if b.image_positions is not None:
                batch.image_positions.extend(b.image_positions)
            else:
                batch.image_positions.append(None)
            if b.encoder_cache is not None:
                batch.encoder_cache.extend(b.encoder_cache)
            else:
                batch.encoder_cache.append(None)

        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        batch.image_grid_thw = None
        batch.inputs_embeds = None

        # To be filled in prepare_for_prefill
        batch.has_image_inputs = False
        batch.cache_entries_to_free = []

        return batch

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]):
        """Return a batch restricted to ``request_ids``.

        The per-request multimodal state (image inputs, image positions and
        encoder cache entries) is gathered via ``requests_idx_mapping`` before
        the parent class filters the rest of the batch. Batch-level image
        tensors are reset to ``None`` so they are rebuilt for the new batch.

        Raises:
            ValueError: If ``request_ids`` is empty.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")

        # One (inputs, positions, cache) triple per kept request, in request
        # order. The inner generator keeps the mapping lookup and the list
        # indexing interleaved per request.
        kept = [
            (self.image_inputs[j], self.image_positions[j], self.encoder_cache[j])
            for j in (self.requests_idx_mapping[rid] for rid in request_ids)
        ]
        # `kept` is non-empty (checked above), so zip yields exactly 3 columns.
        image_inputs, image_positions, encoder_cache = (
            list(column) for column in zip(*kept)
        )

        batch = super().filter(request_ids)

        # Batch-level tensors no longer match the filtered requests.
        batch.pixel_values = batch.pixel_attention_mask = None
        batch.image_sizes = batch.image_grid_thw = batch.inputs_embeds = None

        batch.image_inputs = image_inputs
        batch.image_positions = image_positions
        batch.encoder_cache = encoder_cache

        # To be filled in prepare_for_prefill
        batch.has_image_inputs = False
        batch.cache_entries_to_free = []
        return batch

    @classmethod
    def batch_tokenized_inputs(
        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
    ):
        """Tokenize a batch of multimodal requests and locate their images.

        For each request, the input chunks are joined in order:

        - Text chunks go through ``preprocess_text``.
        - Each image chunk is decoded and passed through ``preprocess_image``
          and ``processor.image_processor``. It is then replaced in the prompt
          by the placeholder text returned by ``image_text_replacement``.

        The joined prompt is fixed up with ``image_text_replacement_fixup`` and
        tokenized, truncated to ``r.truncate``. When the request contains
        images, their spans are located with ``get_image_positions``.

        Note: if ``config`` lacks ``image_token_index``, it is set from
        ``config.image_token_id``. The config object is mutated.

        Returns:
            ``(batch_tokenized_inputs, batch_image_inputs, batch_image_positions)``,
            each with one entry per request. The image entries are ``None`` for
            requests without images.

        Raises:
            RuntimeError: On an unknown input chunk type.
        """
        kwargs = {}
        if (
            hasattr(processor, "image_processor_class")
            and processor.image_processor_class == "Idefics3ImageProcessor"
        ):
            kwargs["return_row_col_info"] = True

        # Fetched lazily: get_vocab() may build a full token->id mapping on each
        # call, and it is only needed for requests that contain images.
        vocab = None

        if not hasattr(config, "image_token_index"):
            config.image_token_index = config.image_token_id

        batch_tokenized_inputs: List[List[int]] = []
        batch_image_inputs: List[Optional[List[dict]]] = []
        batch_image_positions: List[Optional[List[ImagePositions]]] = []

        for r in requests:
            text_parts = []
            image_inputs = []
            image_texts = []

            image_id = 0

            for chunk in r.input_chunks.chunks:
                chunk_type = chunk.WhichOneof("chunk")
                if chunk_type == "text":
                    text = preprocess_text(config, chunk.text)
                    text_parts.append(text)
                elif chunk_type == "image":
                    img = Image.open(BytesIO(chunk.image.data))
                    img = preprocess_image(config, img)

                    image_input = processor.image_processor(
                        [img], return_tensors="pt", **kwargs
                    )
                    image_inputs.append(image_input)

                    img_text, img_start_token_str = image_text_replacement(
                        processor, image_input, config
                    )
                    text_parts.append(img_text)

                    # A mutable list on purpose: get_image_positions may rewrite
                    # the text of the following image in place (idefics2).
                    image_texts.append([image_id, img_start_token_str, img_text])
                    image_id += 1
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk_type}")

            full_text = image_text_replacement_fixup(config, "".join(text_parts))
            input_ids = tokenizer(
                full_text,
                truncation=True,
                max_length=r.truncate,
                add_special_tokens=(
                    r.add_special_tokens if config.model_type != "paligemma" else False
                ),
            )["input_ids"]

            if image_inputs:
                if vocab is None:
                    vocab = tokenizer.get_vocab()
                # All images of a request are searched with the first image's
                # start token.
                img_start_token = vocab[image_texts[0][1]]
                image_positions = cls.get_image_positions(
                    input_ids, image_texts, img_start_token, config, tokenizer
                )
            else:
                image_inputs = None
                image_positions = None

            batch_tokenized_inputs.append(input_ids)
            batch_image_inputs.append(image_inputs)
            batch_image_positions.append(image_positions)

        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions

    @classmethod
    def get_image_positions(
        cls,
        input_ids: List[int],
        image_texts: List[list],
        img_start_token: int,
        config,
        tokenizer: PreTrainedTokenizerBase,
    ) -> List[ImagePositions]:
        """Locate each image's placeholder span inside a tokenized prompt.

        Args:
            input_ids: Token ids of the full (possibly truncated) prompt.
            image_texts: One ``[image_id, img_start_token_str, img_text]`` entry
                per image, in prompt order. Entries must be mutable lists. For
                idefics2, the text of the next image may be rewritten in place.
            img_start_token: Token id that opens an image span.
            config: Model config. ``model_type`` and ``image_token_index`` are read.
            tokenizer: Tokenizer used to re-tokenize each image's text.

        Returns:
            One ``ImagePositions`` per image. Each has ``offset`` (a plain int
            index into ``input_ids``), ``length``, ``id`` and
            ``num_placeholder_tokens``. Its ``is_embed`` mask is ``None`` when
            every token in the span is ``config.image_token_index``.

        Raises:
            AssertionError: If an image span is missing from, or was cut off in,
                ``input_ids`` (e.g. because the prompt was truncated).
        """
        image_positions = []
        num_images = len(image_texts)

        input_ids_t = torch.as_tensor(input_ids)
        img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]
        num_start_tokens = img_start_token_pos.numel()
        num_tokens = input_ids_t.numel()

        last_pos = 0
        for i in range(num_images):
            image_id, img_start_token_str, img_text = image_texts[i]
            img_text = image_text_replacement_fixup(config, img_text)

            if config.model_type == "gemma3":
                img_text = img_text.replace("\n\n", "")

            tokens = tokenizer(img_text, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ][0]
            length = tokens.numel()

            assert (
                length <= num_tokens
            ), f"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens"

            # First start token at or after the end of the previous image.
            pos = int(torch.searchsorted(img_start_token_pos, last_pos, right=False))
            # Without this check, an image whose start token was truncated away
            # surfaces as an opaque IndexError on the indexing below.
            assert pos < num_start_tokens, (
                f"Start token of image {image_id} not found in input_ids; the image "
                "is truncated, try increasing --max-batch-prefill-tokens"
            )
            # Plain int so ImagePositions.offset / last_pos are not 0-d tensors.
            index = int(img_start_token_pos[pos])
            assert torch.equal(
                input_ids_t[index : index + length], tokens
            ), "Image tokens not found in input_ids"

            is_embed = tokens == config.image_token_index
            num_placeholder_tokens = int(is_embed.sum())
            if num_placeholder_tokens == length:
                # Every token is a placeholder, so no mask is needed.
                is_embed = None

            pos = ImagePositions(
                offset=index,
                length=length,
                id=image_id,
                num_placeholder_tokens=num_placeholder_tokens,
                is_embed=is_embed,
            )

            image_positions.append(pos)
            last_pos = index + length

            # idefics2: if the next image's tokens start right after this span
            # with an image token, the start token at last_pos - 1 is treated as
            # shared (presumably merged by image_text_replacement_fixup, confirm).
            # The next search is redirected to last_pos, and the next image's
            # text drops its leading start token. The bounds check avoids an
            # IndexError when the prompt was truncated exactly here; the next
            # iteration then reports the truncation.
            if (
                config.model_type == "idefics2"
                and i + 1 != num_images
                and last_pos < num_tokens
                and input_ids[last_pos] == config.image_token_index
            ):
                fake_token = last_pos - 1
                fake_token_index = torch.searchsorted(
                    img_start_token_pos, fake_token, right=False
                )
                img_start_token_pos[fake_token_index] = last_pos
                image_texts[i + 1][2] = image_texts[i + 1][2][
                    len(img_start_token_str) :
                ]

        return image_positions

    @classmethod
    def from_pb_processor(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



