# megatron_patch/data/dataset_helpers.py
# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from collections import defaultdict
from megatron_patch.data.image_processing import get_visual_transform
import numpy as np
import torch
from torchvision import transforms as T
import json
from megatron.energon import (
Batch,
DefaultTaskEncoder,
VQASample,
)
from megatron_patch.data.energon.chatml import ChatMLSample
from megatron.training import get_args
from megatron_patch.tokenizer import get_tokenizer
# Type for the intermediate per-sample output of encode_sample(), before batch()
@dataclass
class ImageTaskSample:
__key__: str
__subflavors__: Dict
imgs: List[np.ndarray] # (c, h, w)
videos: List[np.ndarray] # (c, h, w)
image_thw_grids: np.ndarray
video_thw_grids: np.ndarray
image_input_mask: np.ndarray
video_input_mask: np.ndarray
second_per_grid_ts: np.ndarray # (n_videos, )
text: np.ndarray
target: np.ndarray
# Type for the batched data produced by batch(); encode_batch() converts it to a dict
@dataclass
class VQATaskBatch(Batch):
__keys__: List[str]
__subflavors__: List[Dict]
# (num_tiles, c, h, w)
imgs: torch.Tensor
videos: torch.Tensor
image_thw_grids: torch.Tensor
video_thw_grids: torch.Tensor
image_input_mask: torch.Tensor
video_input_mask: torch.Tensor
second_per_grid_ts: torch.Tensor # (n_videos,), seconds per temporal grid step, read from sample metadata when available (defaults to 0.5)
# (n, seq_len)
text: torch.Tensor
# (n, seq_len)
target: torch.Tensor
class InternalWarning(Warning):
"""Raised to drop malformed or oversized samples; silenced by print_error_handler unless debug=True."""
def convert_to_qwen2vl_content(
user_input: str,
image_pattern: str = '<image>',
video_pattern: str = '<video>'
):
"""
Split user input into format Qwen2VL tokenizer accepts.
"""
pattern = r"({image}|{video})".format(image=image_pattern, video=video_pattern)
contents = []
cur = 0
mm_idx = defaultdict(int)
for matched in re.finditer(pattern, user_input):
start, end = matched.span()
if start > cur:
contents.append({
"type": "text",
"text": user_input[cur:start].strip()
})
mm_type = matched.group(0)[1:-1]  # 'image' or 'video'
contents.append({
"type": mm_type,
mm_type: str(mm_idx[mm_type])
})
cur = end
mm_idx[mm_type] += 1
if cur < len(user_input):
contents.append({
"type": "text",
"text": user_input[cur:len(user_input)].strip()
})
return contents
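# Illustrative example (not executed) of the conversion above, assuming the default
# '<image>'/'<video>' patterns:
#   convert_to_qwen2vl_content("Describe <image> briefly")
#   -> [{'type': 'text', 'text': 'Describe'},
#       {'type': 'image', 'image': '0'},
#       {'type': 'text', 'text': 'briefly'}]
# Repeated placeholders are numbered per modality ('0', '1', ...) via mm_idx.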
class TaskEncoder(DefaultTaskEncoder[Union[VQASample, ChatMLSample], ImageTaskSample, VQATaskBatch, dict]):
"""A simple task encoder for captioning."""
def __init__(
self,
):
# No batch_type is passed to the default task encoder: batching is performed
# "manually" here by overriding the `batch` method.
super().__init__()
self.args = get_args()
self.tokenizer = get_tokenizer()
self.temporal_patch_size = self.args.temporal_patch_size
self.merge_size = self.args.spatial_merge_size
self.patch_size = self.args.patch_size
self.seq_len = self.args.max_padding_length
def encode_sample(self, sample: Union[VQASample, ChatMLSample]):
if isinstance(sample, VQASample):
is_llava_training = sample.__subflavors__.get('is_llava_training', False)
if is_llava_training:
raise NotImplementedError('Sample format not supported')
else:
yield self.encode_vqa(sample)
elif isinstance(sample, ChatMLSample):
yield self.encode_chatml(sample)
else:
raise NotImplementedError('Sample format not supported')
def _flatten_visual_inputs(self, visuals, is_image: bool = True):
flattened = []
thw_grids = []
for visual in visuals:
if is_image:
resized_height, resized_width = visual.shape[-2:]
patches = np.tile(np.array(visual), (self.temporal_patch_size, 1, 1, 1))
else:
assert len(visual) % self.temporal_patch_size == 0
patches = np.array(visual)
resized_height, resized_width = patches.shape[-2:]
channel = patches.shape[1]
grid_t = patches.shape[0] // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
patches = patches.reshape(
grid_t,
self.temporal_patch_size,
channel,
grid_h // self.merge_size,
self.merge_size,
self.patch_size,
grid_w // self.merge_size,
self.merge_size,
self.patch_size,
)
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
)
flattened.append(flatten_patches)
thw_grids.append((grid_t, grid_h, grid_w))
return flattened, np.array(thw_grids)
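# Worked shape example for _flatten_visual_inputs (illustrative, assuming the common
# Qwen2-VL settings patch_size=14, temporal_patch_size=2, merge_size=2):
#   a single 336x336 image is tiled to (2, 3, 336, 336), giving
#   grid (t, h, w) = (1, 24, 24) and flatten_patches of shape
#   (1 * 24 * 24, 3 * 2 * 14 * 14) = (576, 1176).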
def encode_chatml(self, sample: ChatMLSample):
# TODO: modify get_visual_transform to add more augmentations
imgs = [get_visual_transform(img)[0] for img in sample.imgs]
videos = [[get_visual_transform(frame)[0] for frame in video] for video in sample.videos]
# NOTE: truncate each video to an even number of frames so frames can be grouped
# into temporal patches (the `// 2 * 2` assumes temporal_patch_size == 2)
for i, video in enumerate(videos):
videos[i] = video[:len(video) // 2 * 2]
# NOTE: flatten all images
flattened_imgs, image_thw_grids = self._flatten_visual_inputs(imgs, is_image=True)
flattened_videos, video_thw_grids = self._flatten_visual_inputs(videos, is_image=False)
# NOTE: generate qwen2vl conversations
conversation = json.loads(sample.conversation) if isinstance(sample.conversation, (str, bytes)) else sample.conversation
second_per_grid_ts = [1 / 2.0] * len(video_thw_grids)
if 'conversations' in conversation:
second_per_grid_ts = conversation.get('second_per_grid_ts', second_per_grid_ts)
second_per_grid_ts = [float(i) for i in second_per_grid_ts]
conversation = conversation['conversations']
role_key = 'from' if 'from' in conversation[0] else 'role'
content_key = 'value' if 'from' in conversation[0] else 'content'
# NOTE: assume the conversation format is: [System]? (User Assistant)+
converted_conversation = []
if len(conversation) % 2 == 0:
# Default Prompt
converted_conversation.append({
'role': 'system',
'content': 'You are a helpful assistant.'
})
else:
converted_conversation.append({
'role': 'system',
'content': conversation[0][content_key]
})
conversation = conversation[1:]
EXPECTED_ROLE = ['human', 'gpt']
for turn_idx, turn in enumerate(conversation):
role = turn[role_key]
if role != EXPECTED_ROLE[turn_idx % len(EXPECTED_ROLE)]:
raise InternalWarning(f"Expect conversation organized in order: [sys] human gpt human gpt..., but got role '{role}' in turn {turn_idx}")
content = turn[content_key]
if role == 'human':
role = 'user'
content = convert_to_qwen2vl_content(content)
elif role == 'gpt':
role = 'assistant'
converted_conversation.append({
'role': role,
'content': content
})
conversation = converted_conversation
# NOTE: we need to mask all system/user input tokens and assistant generation prefix tokens
input_ids = self.tokenizer.apply_chat_template(conversation, tokenize=True, return_tensors="np")[0]
target = input_ids.copy()
system_prompt_prefix = len(self.tokenizer.apply_chat_template([conversation[0]], tokenize=True))
assistant_generation_prefix = 3  # tokens masked at the start of each assistant turn (the "<|im_start|>assistant\n" generation prefix in the Qwen2 chat template)
pad_token_id = self.tokenizer.pad_token_id
target[:system_prompt_prefix] = pad_token_id
offset = system_prompt_prefix
for turn_idx, turn in enumerate(conversation[1:]):
turn_tokens = self.tokenizer.apply_chat_template([turn], tokenize=True, return_tensors="np")[0]
turn_content = turn_tokens[system_prompt_prefix:]
n_tokens = len(turn_content)
if (target[offset: offset + n_tokens] != turn_content).any():
raise InternalWarning("Encode Error")
if turn['role'] == 'user':
target[offset: offset + n_tokens] = pad_token_id
elif turn['role'] == 'assistant':
target[offset: offset + assistant_generation_prefix] = pad_token_id
offset += n_tokens
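# After this loop the label layout is: the system prompt and every user turn are fully
# set to pad_token_id, and each assistant turn keeps its tokens except the first
# `assistant_generation_prefix` ones, so the loss is computed only on assistant replies.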
# NOTE: expand image_pad & video_pad
merge_length = self.merge_size**2
image_token_id, video_token_id = self.tokenizer.encode(['<|image_pad|>', '<|video_pad|>'])
image_token_indices = np.where(input_ids == image_token_id)[0]
assert len(image_token_indices) == len(image_thw_grids), f"Sample has {len(image_thw_grids)} images but {len(image_token_indices)} image placeholders!"
video_token_indices = np.where(input_ids == video_token_id)[0]
assert len(video_token_indices) == len(video_thw_grids), f"Sample has {len(video_thw_grids)} videos but {len(video_token_indices)} video placeholders!"
image_thw_grids, video_thw_grids = np.array(image_thw_grids, dtype=np.int64), np.array(video_thw_grids, dtype=np.int64)
target_length = (
input_ids.shape[0]
- image_thw_grids.shape[0] + image_thw_grids.prod(axis=-1).sum() // merge_length
- video_thw_grids.shape[0] + video_thw_grids.prod(axis=-1).sum() // merge_length
)
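# Worked example (illustrative): one image with grid (1, 24, 24) and merge_size=2
# contributes 576 patches -> 576 // 4 = 144 <|image_pad|> tokens, so the single
# placeholder in input_ids adds 143 extra positions to target_length.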
if target_length > self.seq_len:
raise InternalWarning(f"Long sequence with length {target_length} found, dropped...")
final_input_ids = np.zeros(target_length, dtype=input_ids.dtype)
final_input_masks = final_input_ids.copy()
image_idx, video_idx = 0, 0
indices = np.sort(np.concatenate([image_token_indices, video_token_indices]))
cur_x, cur_y = 0, 0
for idx in indices:
token_id = input_ids[idx]
if token_id == image_token_id:
size = image_thw_grids[image_idx].prod() // merge_length
image_idx += 1
elif token_id == video_token_id:
size = video_thw_grids[video_idx].prod() // merge_length
video_idx += 1
# NOTE:
# input_ids[cur_x:idx] -> final_input_ids[cur_y:cur_y + idx - cur_x]
# input_ids[idx] -> final_input_ids[cur_y + idx - cur_x: cur_y + idx - cur_x + size]
final_input_ids[cur_y: cur_y + idx - cur_x] = input_ids[cur_x:idx]
final_input_masks[cur_y: cur_y + idx - cur_x] = target[cur_x:idx]
cur_y += idx - cur_x
final_input_ids[cur_y: cur_y + size] = token_id
final_input_masks[cur_y: cur_y + size] = pad_token_id
cur_y += size
cur_x = idx + 1
if cur_x < len(input_ids):
final_input_ids[cur_y:] = input_ids[cur_x:]
final_input_masks[cur_y:] = target[cur_x:]
target = np.roll(final_input_masks, shift=-1)
target[-1] = pad_token_id
if (target == pad_token_id).all():
raise InternalWarning("Sample with all masked label, dropped.")
image_input_mask = final_input_ids == self.tokenizer.image_token_id
video_input_mask = final_input_ids == self.tokenizer.video_token_id
# collect data
return ImageTaskSample(
__key__=sample.__key__,
__subflavors__=sample.__subflavors__,
imgs=flattened_imgs,
videos=flattened_videos,
image_thw_grids=image_thw_grids,
video_thw_grids=video_thw_grids,
second_per_grid_ts = np.array(second_per_grid_ts, dtype=np.float32),
image_input_mask=image_input_mask,
video_input_mask=video_input_mask,
text=final_input_ids,
target=target,
)
def encode_vqa(self, sample: VQASample):
augment = sample.__subflavors__.get('augmentation', False)
has_video = sample.__subflavors__.get('has_video', False)
if has_video:
raise NotImplementedError("You should use sharegpt dataset to train with videos.")
else:
# TODO: add args
imgs = get_visual_transform(sample.image)
flatten_patches, thw_grids = self._flatten_visual_inputs(imgs, is_image=True)
assert "<image>" in sample.context # ?
# NOTE: we expect a context is a string with <image> conetnt
if isinstance(sample.answers, list):
answer_list = sample.answers
weight_list = np.array(sample.answer_weights).astype(np.float32)
weight_list = weight_list / np.sum(weight_list)
answer_idx = np.random.choice(weight_list.shape[0], 1, p=weight_list)[0]
answer = answer_list[answer_idx]
else:
answer = sample.answers
conversation = [
{"role": "user", "content": convert_to_qwen2vl_content(sample.context)},
{"role": "assistant", "content": answer},
]
user_inputs = self.tokenizer.apply_chat_template(conversation[:-1], tokenize=False)
text = self.tokenizer.apply_chat_template(conversation, tokenize=False)
# text, target = self.tokenizer.tokenize_conversation(conversation, False, False)
# replace each <|image_pad|> token with (t*h*w // merge_length) copies of itself
merge_length = self.merge_size**2
image_token = '<|image_pad|>'
assert len(thw_grids) == 1, "Only one image per sample is supported!"
index = 0
while image_token in text:
grid_t, grid_h, grid_w = thw_grids[index]
num_patches = grid_t * grid_h * grid_w
text = text.replace(
image_token, "<|placeholder|>" * (num_patches // merge_length), 1
)
user_inputs = user_inputs.replace(
image_token, "<|placeholder|>" * (num_patches // merge_length), 1
)
index += 1
text = text.replace("<|placeholder|>", image_token)
user_inputs = user_inputs.replace("<|placeholder|>", image_token)
input_ids = self.tokenizer.tokenize(text)
user_input_ids = self.tokenizer.tokenize(user_inputs)
if len(input_ids) > self.seq_len:
raise InternalWarning(f"Long sequence with length {len(input_ids)} found, dropped...")
target = np.array(input_ids[1:] + [self.tokenizer.pad_token_id])
if len(user_input_ids) >= len(input_ids):
raise InternalWarning(f"Sample not supported, dropped...")
# ensure user inputs is a prefix of full text
if not (np.array(user_input_ids) == np.array(input_ids[:len(user_input_ids)])).all():
raise InternalWarning(f"Sample not supported, dropped...")
# mask input
target[:len(user_input_ids)-1] = self.tokenizer.pad_token_id
img_token_id = self.tokenizer.image_token_id
image_input_mask = np.array(input_ids) == img_token_id
# collect data
return ImageTaskSample(
__key__=sample.__key__,
__subflavors__=sample.__subflavors__,
imgs=flatten_patches,
videos=list(),
image_thw_grids=thw_grids,
video_thw_grids=torch.empty([0, 3], dtype=torch.long),
image_input_mask=image_input_mask,
video_input_mask=None,
second_per_grid_ts=np.zeros(0, dtype=np.float32),
text=input_ids,
target=target,
)
def batch(self, samples: List[ImageTaskSample]) -> VQATaskBatch:
# Concatenate flattened image patches to [total_patches, c * temporal_patch_size * patch_size**2]. If there are no images (text-only), use an empty tensor.
imgs = [img for s in samples for img in s.imgs]
if len(imgs) > 0:
imgs = torch.cat([torch.from_numpy(img) for img in imgs])
else:
imgs = torch.empty([0, 3 * self.temporal_patch_size * self.patch_size * self.patch_size], dtype=torch.float32)
image_thw_grids = [thw_grids for s in samples for thw_grids in s.image_thw_grids]
if len(image_thw_grids) > 0:
image_thw_grids = torch.from_numpy(np.array(image_thw_grids)).long()
assert image_thw_grids.prod(dim=-1).sum() == imgs.shape[0]
else:
image_thw_grids = torch.empty([0, 3], dtype=torch.long)
# Concatenate flattened video patches to [total_patches, c * temporal_patch_size * patch_size**2]. If there are no videos, use an empty tensor.
videos = [video for s in samples for video in s.videos]
if len(videos) > 0:
videos = torch.cat([torch.from_numpy(video) for video in videos])
else:
videos = torch.empty([0, 3 * self.temporal_patch_size * self.patch_size * self.patch_size], dtype=torch.float32)
second_per_grid_ts = [second_per_grid for s in samples for second_per_grid in s.second_per_grid_ts]
if len(second_per_grid_ts) > 0:
second_per_grid_ts = torch.from_numpy(np.array(second_per_grid_ts)).float()
else:
second_per_grid_ts = torch.empty([0, ], dtype=torch.float32)
video_thw_grids = [thw_grids for s in samples for thw_grids in s.video_thw_grids]
if len(video_thw_grids) > 0:
video_thw_grids = torch.from_numpy(np.array(video_thw_grids)).long()
assert video_thw_grids.prod(dim=-1).sum() == videos.shape[0]
else:
video_thw_grids = torch.empty([0, 3], dtype=torch.long)
# If the user hasn't defined a target sequence length, then use the max along the sample lengths.
max_seq_len = self.seq_len
if not max_seq_len:
max_seq_len = max(len(s.text) for s in samples)
text_mat = np.full((len(samples), max_seq_len), self.tokenizer.pad_token_id, dtype=np.int64)
# Targets are already shifted left by one in encode_chatml/encode_vqa, so no extra column is needed.
target_mat = np.full((len(samples), max_seq_len), self.tokenizer.pad_token_id, dtype=np.int64)
image_input_masks = np.zeros_like(text_mat, dtype=bool)
video_input_masks = np.zeros_like(text_mat, dtype=bool)
for i, s in enumerate(samples):
# If the sample/target length exceeds the target sequence length, then truncate.
text_len = min(max_seq_len, len(s.text))
target_len = min(max_seq_len, len(s.target))
text_mat[i, :text_len] = np.array(s.text)[:text_len]
# NOTE: we should assert that the user input portion is never truncated
if s.image_input_mask is not None:
image_input_masks[i, :text_len] = np.array(s.image_input_mask)[:text_len]
if s.video_input_mask is not None:
video_input_masks[i, :text_len] = np.array(s.video_input_mask)[:text_len]
target_mat[i, :target_len] = np.array(s.target)[:target_len]
batch = VQATaskBatch(
__keys__=[s.__key__ for s in samples],
__subflavors__=[s.__subflavors__ for s in samples],
imgs=imgs,
videos=videos,
image_thw_grids=image_thw_grids,
video_thw_grids=video_thw_grids,
second_per_grid_ts=second_per_grid_ts,
image_input_mask=torch.from_numpy(image_input_masks),
video_input_mask=torch.from_numpy(video_input_masks),
text=torch.from_numpy(text_mat),
target=torch.from_numpy(target_mat),
)
return batch
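# Rough shape sketch for a batch of two ChatML samples that each contain one 336x336
# image (illustrative, assuming patch_size=14, temporal_patch_size=2, merge_size=2):
#   imgs:            (1152, 1176)   # 2 images * 576 patches each, 1176 = 3*2*14*14
#   image_thw_grids: (2, 3)
#   videos:          (0, 1176), video_thw_grids: (0, 3), second_per_grid_ts: (0,)
#   text/target/image_input_mask/video_input_mask: (2, max_padding_length)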
def encode_batch(self, batch: VQATaskBatch) -> dict:
raw = dataclasses.asdict(batch)
del raw["__subflavors__"]
return raw
def print_error_handler(exc: Exception, key: Optional[str], debug=False):
"""Energon error handler: silently skip samples dropped via InternalWarning, report anything else to stderr."""
if not debug and isinstance(exc, InternalWarning):
return
print(
f"The following exception occurred in the dataloader for sample {key} and is skipped",
file=sys.stderr,
)
traceback.print_exc()
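# A minimal usage sketch (commented out; not executed on import) showing how this
# TaskEncoder and print_error_handler are typically wired into a megatron.energon
# dataloader. The exact keyword arguments depend on the installed megatron-energon
# version, and the dataset path is hypothetical; Megatron arguments and the tokenizer
# must already be initialized because TaskEncoder.__init__ calls get_args()/get_tokenizer().
#
#   from megatron.energon import WorkerConfig, get_loader, get_train_dataset
#
#   worker_config = WorkerConfig.default_worker_config()
#   train_ds = get_train_dataset(
#       "/path/to/energon/dataset",      # hypothetical energon-prepared dataset path
#       batch_size=2,
#       shuffle_buffer_size=100,
#       max_samples_per_sequence=100,
#       task_encoder=TaskEncoder(),
#       worker_config=worker_config,
#       handler=print_error_handler,     # samples raising InternalWarning are skipped quietly
#   )
#   for batch_dict in get_loader(train_ds):
#       ...  # batch_dict is the dict produced by encode_batch()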