def gather_vision_embeds()

in server/text_generation_server/models/vlm_causal_lm.py
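
This method collects the image-embedding rows that become visible during the current prefill step. For each request that is still prefilling, it walks the request's image positions, slices each cached encoder output down to the part that falls inside this step's token window, and concatenates the resulting chunks into one tensor. Entries whose encoder output has been fully consumed are scheduled to be freed. (`torch` and `gather_image_embeds` are module-level imports not shown in this excerpt.)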


    def gather_vision_embeds(self):
        """Gather the image-embedding rows that become visible in this
        prefill step into a single tensor, or return None if no request
        consumes encoder output this step."""
        device = self.input_ids.device
        chunks = []
        # Walk each request together with its cache/input lengths and
        # prefilling status.
        for i, (cache_length, input_length, request_prefilling) in enumerate(
            zip(self.cache_lengths, self.input_lengths, self.prefilling_mask)
        ):
            # Only requests that are still prefilling and carry images
            # contribute encoder output.
            if not request_prefilling or self.image_positions[i] is None:
                continue

            for image_position in self.image_positions[i]:
                if image_position is None:
                    # Already fully consumed in an earlier prefill step.
                    continue
                # The image occupies prompt positions
                # [start_pos, start_pos + length).
                start_pos = image_position.offset
                length = image_position.length

                if start_pos >= cache_length + input_length:
                    # The image starts beyond this step's window, so no
                    # encoder input is required yet.
                    break
                if start_pos + length <= cache_length:
                    # The encoder input was fully processed in an earlier
                    # step.
                    continue

                # Slice of the encoder output that falls inside this step's
                # window [cache_length, cache_length + input_length).
                start_idx = max(cache_length - start_pos, 0)
                end_idx = min(cache_length - start_pos + input_length, length)

                assert (
                    image_position.id in self.encoder_cache[i]
                ), f"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}"
                encoder_output = self.encoder_cache[i][image_position.id]

                # Restrict the embedding mask, if any, to the visible slice.
                is_embed = image_position.is_embed
                if is_embed is not None:
                    is_embed = is_embed[start_idx:end_idx]

                # Debug trace of the slice consumed for this image.
                from loguru import logger

                logger.info(
                    f"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}"
                )

                embeds = gather_image_embeds(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                if embeds is not None:
                    chunks.append(embeds)

                if end_idx == length:
                    # The encoder output is now fully consumed: free its
                    # cache entry and clear the position.
                    self.cache_entries_to_free.append((i, image_position.id))
                    self.image_positions[i][image_position.id] = None

        if len(chunks) == 0:
            return None
        # Concatenate all visible rows and move them to the input device.
        return torch.cat(chunks, dim=0).to(device)
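
The slicing above is the heart of chunked prefill for images: a prefill step covers prompt positions [cache_length, cache_length + input_length), an image occupies [offset, offset + length), and only the overlap of the two ranges is gathered. A minimal, self-contained sketch of that window arithmetic follows; the helper name `visible_slice` and the example numbers are illustrative assumptions, not part of the source.


    def visible_slice(offset, length, cache_length, input_length):
        # Return the (start_idx, end_idx) slice of an image's encoder
        # output that falls inside the prefill window
        # [cache_length, cache_length + input_length), or None when the
        # image lies entirely outside it. (Hypothetical helper.)
        window_end = cache_length + input_length
        if offset >= window_end:
            return None  # image starts after this step
        if offset + length <= cache_length:
            return None  # image already fully consumed
        start_idx = max(cache_length - offset, 0)
        end_idx = min(cache_length - offset + input_length, length)
        return start_idx, end_idx


    # An image of 576 embedding rows at prompt position 100, consumed
    # across two chunked-prefill steps of 512 tokens each:
    assert visible_slice(100, 576, 0, 512) == (0, 412)
    assert visible_slice(100, 576, 512, 512) == (412, 576)
    assert visible_slice(100, 576, 1024, 512) is None

When end_idx reaches length (the second step in the example), the method frees the cached encoder output, so a request never holds an image's embeddings longer than its prefill needs them.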