maga_transformer/models/llava_vit.py
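# NOTE: this excerpt assumes the file's module-level imports: math, re, logging,
# torch, torch.nn as nn, typing.List, PIL.Image, the MMUrlType enum, and the
# local helpers process_images, get_anyres_image_grid_shape, and unpad_image.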
def image_embedding(self, images: List[Image.Image], mm_type: MMUrlType = MMUrlType.IMAGE):
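# Encode a batch of images (or video frames) into projected visual features;
# returns a list with one (num_tokens, hidden_dim) tensor per image or video.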
config = self.config.mm_related_params.config
image_aspect_ratio = config["image_aspect_ratio"]
mm_patch_merge_type = config.get("mm_patch_merge_type", "flat")
mm_newline_position = config.get("mm_newline_position", "one_token")
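# Preprocess the PIL images into model-ready tensors; with "anyres"-style
# aspect ratios a single image can expand into several tiles along dim 0.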
processed_images = process_images(images,
image_aspect_ratio,
self.vision_tower.image_processor,
self._device,
self._data_type,
mm_type,
image_grid_pinpoints=config.get("image_grid_pinpoints", []))
processed_images = [image.unsqueeze(0) if image.ndim == 3 else image for image in processed_images]
split_sizes = [processed_image.shape[0] for processed_image in processed_images]
processed_images = torch.cat(processed_images)
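# Encode all tiles in one batched forward pass, then split the features back
# into per-image groups using the tile counts recorded above.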
image_features = self.encode_images(processed_images)
image_features = list(torch.split(image_features, split_sizes, dim=0))
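# Videos: pool each frame's token grid (get_2dPool) to cut the per-frame
# sequence length before merging.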
if mm_type == MMUrlType.VIDEO:
image_features = [self.get_2dPool(feature) for feature in image_features]
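# "flat": merge the tile and token dimensions into a single token sequence.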
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
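# "spatial*": reassemble tiles into their 2-D layout before flattening, with
# optional unpadding, pooling, and newline tokens.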
elif mm_patch_merge_type.startswith("spatial"):
image_sizes = [image.size for image in images]
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
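# Videos: mm_newline_position controls where newline tokens are inserted
# between frames or grid rows.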
if mm_type == MMUrlType.VIDEO: # video operations
if mm_newline_position == "grid":
image_feature = self.add_token_per_grid(image_feature)
if config["add_faster_video"]:
raise NotImplementedError("add_faster_video is not implemented")
# faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
# concat_slow_faster_token = []
# for _ in range(image_feature.shape[0]):
#     if _ % self.config.faster_token_stride == 0:
#         concat_slow_faster_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
#     else:
#         concat_slow_faster_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
# image_feature = torch.cat(concat_slow_faster_token)
new_image_features.append(image_feature)
elif mm_newline_position == "frame":
image_feature = self.add_token_per_frame(image_feature)
new_image_features.append(image_feature.flatten(0, 1))
elif mm_newline_position == "one_token":
# "one_token": flatten all frames into one sequence, optionally followed by a single newline token
image_feature = image_feature.flatten(0, 1)
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat((
image_feature,
self.image_newline[None].to(image_feature.device)
), dim=0)
new_image_features.append(image_feature)
elif mm_newline_position == "no_token":
new_image_features.append(image_feature.flatten(0, 1))
else:
raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")
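# Multi-tile images (anyres): row 0 is the base (global) view; the remaining
# rows are the high-resolution tiles.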
elif image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.vision_tower.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
try:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], config["image_grid_pinpoints"], self.vision_tower.config.image_size)
except Exception as e:
logging.error(f"exception {str(e)}, set num_path_width and num_patch_height to 2")
num_patch_width, num_patch_height = 2, 2
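# Lay the tiles out on their 2-D grid: (grid_h, grid_w, tokens_h, tokens_w, C).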
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
image_feature = image_feature.view(2, 2, height, width, -1)
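# "maxpool2x2": stitch the tiles into one (C, H, W) map, halve it with a 2x2
# max-pool, then flatten back to (num_tokens, C).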
if "maxpool2x2" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = nn.functional.max_pool2d(image_feature, 2)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
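# "unpad" + "anyres_max_N": strip the aspect-ratio padding, bilinearly
# downsample if the map is more than 1.1x over the N-tile token budget
# (in linear scale), and append a newline token to the end of every row.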
elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
image_feature = torch.cat((image_feature, self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
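# Plain "unpad": strip the aspect-ratio padding and append a newline token
# to every row, with no resolution cap.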
elif 'unpad' in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat((
image_feature,
self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
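# Default spatial merge: row-major flatten of the tile grid, with no
# unpadding and no newline tokens.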
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
if "nobase" in mm_patch_merge_type:
pass
else:
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
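# Single-tile image: take the lone feature map and optionally append one
# trailing newline token.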
else:
image_feature = image_feature[0]
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat((
image_feature,
self.image_newline[None].to(image_feature.device)
), dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
return image_features
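
# Illustrative usage sketch (not part of the original file; `model`, the image
# path, and the surrounding setup are hypothetical):
#
#   from PIL import Image
#   imgs = [Image.open("example.jpg")]
#   feats = model.image_embedding(imgs, mm_type=MMUrlType.IMAGE)
#   feats[0].shape  # -> (num_visual_tokens, hidden_dim)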