def image_embedding()

in maga_transformer/models/llava_vit.py [0:0]


    def image_embedding(self, images: List[Image.Image], mm_type = MMUrlType.IMAGE):
        """Encode images (or video frames) and merge them into per-input token features.

        All inputs are preprocessed with ``process_images``, encoded in a single
        ``encode_images`` call, split back per input, and then merged according
        to ``mm_patch_merge_type`` / ``mm_newline_position`` /
        ``image_aspect_ratio`` read from ``self.config.mm_related_params.config``.

        Args:
            images: Inputs to embed. On the ``spatial*`` merge path each item's
                ``.size`` is read and used as the original size when resolving
                the anyres grid shape and when unpadding.
            mm_type: ``MMUrlType.IMAGE`` (default) or ``MMUrlType.VIDEO``. Video
                features are pooled with ``get_2dPool`` and, on the ``spatial*``
                path, merged with the newline strategy chosen by
                ``mm_newline_position``.

        Returns:
            A list with exactly one feature tensor per entry produced by
            ``process_images`` (presumably one per element of ``images`` --
            TODO confirm against ``process_images``); ``[]`` for empty input.
            For merge types other than ``flat`` / ``spatial*`` the split encoder
            outputs are returned unmerged.

        Raises:
            ValueError: unknown ``mm_newline_position`` for video input.
            NotImplementedError: video ``grid`` mode with ``add_faster_video`` enabled.
            AssertionError: (when assertions are enabled) a multi-view input whose
                base view token count is not ``num_patches_per_side ** 2``.
        """
        if not images:
            # Nothing to encode; torch.cat below would raise on an empty list.
            return []

        config = self.config.mm_related_params.config
        image_aspect_ratio = config["image_aspect_ratio"]
        mm_patch_merge_type = config.get("mm_patch_merge_type", "flat")
        mm_newline_position = config.get("mm_newline_position", "one_token")

        processed_images = process_images(images,
                                          image_aspect_ratio,
                                          self.vision_tower.image_processor,
                                          self._device,
                                          self._data_type,
                                          mm_type,
                                          image_grid_pinpoints=config.get("image_grid_pinpoints", []))

        # Give 3-D entries a leading batch dim so every entry can be concatenated,
        # encoded in one forward pass, and split back into per-input chunks.
        processed_images = [image.unsqueeze(0) if image.ndim == 3 else image for image in processed_images]
        split_sizes = [processed_image.shape[0] for processed_image in processed_images]
        image_features = self.encode_images(torch.cat(processed_images))
        image_features = list(torch.split(image_features, split_sizes, dim=0))

        if mm_type == MMUrlType.VIDEO:
            image_features = [self.get_2dPool(feature) for feature in image_features]

        if mm_patch_merge_type == "flat":
            return [x.flatten(0, 1) for x in image_features]
        if not mm_patch_merge_type.startswith("spatial"):
            # Unrecognized merge types are passed through unmerged.
            return image_features

        # Parse the optional token budget N from "anyres_max_<N>" once, not per image.
        max_num_patches = None
        if "anyres_max" in image_aspect_ratio:
            matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
            if matched_anyres_max_num_patches:
                max_num_patches = int(matched_anyres_max_num_patches.group(1))

        image_sizes = [image.size for image in images]
        new_image_features = []
        for image_idx, image_feature in enumerate(image_features):
            if mm_type == MMUrlType.VIDEO:
                image_feature = self._spatial_merge_video_feature(
                    image_feature, mm_patch_merge_type, mm_newline_position, config)
            elif image_feature.shape[0] > 1:
                image_feature = self._spatial_merge_anyres_feature(
                    image_feature, image_sizes[image_idx], config,
                    image_aspect_ratio, mm_patch_merge_type, max_num_patches)
            else:
                # Single view: drop the view axis; "unpad" appends one newline token.
                image_feature = image_feature[0]
                if 'unpad' in mm_patch_merge_type:
                    image_feature = torch.cat((
                        image_feature,
                        self.image_newline[None].to(image_feature.device)
                    ), dim=0)
            # The branches above only compute the merged tensor; appending exactly
            # once here keeps the output aligned one-to-one with the inputs.
            new_image_features.append(image_feature)
        return new_image_features

    def _spatial_merge_video_feature(self, image_feature, mm_patch_merge_type, mm_newline_position, config):
        """Flatten one video's pooled features, inserting newline tokens per ``mm_newline_position``.

        Supported positions: ``grid`` (``add_token_per_grid``), ``frame``
        (``add_token_per_frame`` then flatten), ``one_token`` (flatten, plus a
        single trailing ``image_newline`` when the merge type contains
        ``unpad``) and ``no_token`` (flatten only).

        Raises:
            NotImplementedError: ``grid`` with ``add_faster_video`` enabled.
            ValueError: any other ``mm_newline_position``.
        """
        if mm_newline_position == "grid":
            image_feature = self.add_token_per_grid(image_feature)
            # Missing key means "disabled" rather than a KeyError.
            if config.get("add_faster_video", False):
                # TODO: slow/fast video token interleaving is not ported yet.
                raise NotImplementedError("add_faster_video is not implemented")
            return image_feature
        if mm_newline_position == "frame":
            return self.add_token_per_frame(image_feature).flatten(0, 1)
        if mm_newline_position == "one_token":
            image_feature = image_feature.flatten(0, 1)
            if 'unpad' in mm_patch_merge_type:
                image_feature = torch.cat((
                    image_feature,
                    self.image_newline[None].to(image_feature.device)
                ), dim=0)
            return image_feature
        if mm_newline_position == "no_token":
            return image_feature.flatten(0, 1)
        raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")

    def _spatial_merge_anyres_feature(self, image_feature, image_size, config,
                                      image_aspect_ratio, mm_patch_merge_type, max_num_patches):
        """Merge one input's base view and its sub-patch views into a single token sequence.

        ``image_feature[0]`` is the base view and ``image_feature[1:]`` are the
        sub-patch views. The sub-patches are re-laid out via ``view`` as a
        ``(grid_h, grid_w, h, w, D)`` grid and merged per ``mm_patch_merge_type``.
        The base view is prepended unless the merge type contains ``nobase``.

        Args:
            image_size: The original input's ``.size`` (used for grid shape / unpadding).
            max_num_patches: N from an ``anyres_max_<N>`` aspect ratio, or ``None``.
        """
        base_image_feature = image_feature[0]
        image_feature = image_feature[1:]
        height = width = self.vision_tower.num_patches_per_side
        assert height * width == base_image_feature.shape[0]

        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
            try:
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_size, config["image_grid_pinpoints"], self.vision_tower.config.image_size)
            except Exception as e:
                # Deliberate best-effort: fall back to a 2x2 grid instead of propagating.
                logging.error(f"exception {e}, set num_patch_width and num_patch_height to 2")
                num_patch_width, num_patch_height = 2, 2
            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
        else:
            image_feature = image_feature.view(2, 2, height, width, -1)

        if "maxpool2x2" in mm_patch_merge_type:
            # (gh, gw, h, w, D) -> (D, gh*h, gw*w) -> 2x2 max-pool -> (tokens, D).
            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
            image_feature = nn.functional.max_pool2d(image_feature, 2)
            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        elif "unpad" in mm_patch_merge_type and max_num_patches is not None:
            unit = image_feature.shape[2]
            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
            image_feature = unpad_image(image_feature, image_size)
            _, h, w = image_feature.shape
            # Bilinearly shrink the map when it exceeds the max_num_patches * unit**2
            # position budget by more than a 1.1x linear factor.
            times = math.sqrt(h * w / (max_num_patches * unit**2))
            if times > 1.1:
                image_feature = image_feature[None]
                image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
            # Append a newline column to every row, then flatten rows -> (tokens, D).
            image_feature = torch.cat((
                image_feature,
                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
            ), dim=-1)
            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        elif 'unpad' in mm_patch_merge_type:
            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
            image_feature = unpad_image(image_feature, image_size)
            # Append a newline column to every row, then flatten rows -> (tokens, D).
            image_feature = torch.cat((
                image_feature,
                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
            ), dim=-1)
            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
        else:
            # (gh, gw, h, w, D) -> (gh, h, gw, w, D) -> row-major (tokens, D).
            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
            image_feature = image_feature.flatten(0, 3)

        if "nobase" in mm_patch_merge_type:
            return image_feature
        return torch.cat((base_image_feature, image_feature), dim=0)