def prepare_for_prefill()

in server/text_generation_server/models/vlm_causal_lm.py [0:0]


    def prepare_for_prefill(self):
        """Collect the image inputs that must be encoded for this prefill step.

        After delegating to the parent implementation, the per-step image
        state (``has_image_inputs``, ``cache_entries_to_free`` and
        ``pixel_values``) is reset. Every request that is still prefilling is
        then scanned: an image position counts for this step when its
        ``[offset, offset + length)`` span overlaps the window
        ``[cache_length, cache_length + input_length)``. Images whose id is
        not yet present in ``encoder_cache`` for that request are queued in
        ``pixel_values`` as ``(request_index, image_id, image_inputs)`` and
        their slot in ``image_inputs`` is set to ``None``.

        When no image position falls inside any window, ``pixel_values``,
        ``pixel_attention_mask``, ``image_sizes`` and ``image_grid_thw`` are
        all set to ``None``. Otherwise ``image_grid_thw`` becomes the
        concatenation (along dim 0) of the ``"image_grid_thw"`` entries of the
        queued inputs, moved to the device of ``input_ids``, or ``None`` when
        no queued input carries that key.
        """
        super().prepare_for_prefill()

        # Fresh per-step state.
        self.has_image_inputs = False
        self.cache_entries_to_free = []
        self.pixel_values = []

        # The zip below would silently truncate on a length mismatch.
        assert (
            len(self.cache_lengths)
            == len(self.input_lengths)
            == len(self.prefilling_mask)
        ), "Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask"

        per_request = zip(
            self.cache_lengths, self.input_lengths, self.prefilling_mask
        )
        for req_idx, (cached_len, new_len, is_prefilling) in enumerate(per_request):
            if not is_prefilling:
                continue
            positions = self.image_positions[req_idx]
            if positions is None:
                continue

            # First token index that is NOT covered by this step.
            window_end = cached_len + new_len

            for position in positions:
                if position is None:
                    continue
                offset = position.offset
                span = position.length

                if offset >= window_end:
                    # Image starts beyond this step's window; the remaining
                    # positions are skipped as well (presumably they are
                    # ordered by offset -- confirm against the producer).
                    break
                if offset + span <= cached_len:
                    # Image lies entirely inside the already-cached prefix.
                    continue

                self.has_image_inputs = True

                image_id = position.id
                if image_id in self.encoder_cache[req_idx]:
                    continue

                # Queue the image inputs and clear the slot they came from.
                request_images = self.image_inputs[req_idx]
                self.pixel_values.append(
                    (req_idx, image_id, request_images[image_id])
                )
                request_images[image_id] = None

        if self.has_image_inputs:
            grids = []
            for _req_idx, _image_id, inputs in self.pixel_values:
                if "image_grid_thw" in inputs:
                    grids.append(inputs["image_grid_thw"])

            if grids:
                stacked = torch.cat(grids, dim=0)
                self.image_grid_thw = stacked.to(self.input_ids.device)
            else:
                self.image_grid_thw = None
        else:
            self.pixel_values = None
            self.pixel_attention_mask = None
            self.image_sizes = None
            self.image_grid_thw = None