in vision/smolvlm2/smolvlm/datasets/dataset.py [0:0]
def _get_item(self, idx: int) -> Dict[str, torch.Tensor]:
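    """Build a single training example: load the sample's visual media (if any),
    format the conversation with the chat template, tokenize it, and return
    input_ids, attention_mask, labels, and (when present) pixel_values."""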
    sources = self.list_data_dict[idx]
    if isinstance(idx, int):
        sources = [sources]

    content_type = sources[0].get("type", self.modality).lower()

    frames: List[Image.Image] = []
    timestamps: List[str] = []
    duration_seconds = None

    if content_type == "video":
        # Video: resize frames to the video target size and disable image splitting.
        # self.processor.image_processor.size = (self.video_target_size, self.video_target_size)
        self.processor.image_processor.size = {"longest_edge": self.video_target_size}
        self.processor.image_processor.do_resize = True
        self.processor.image_processor.do_image_splitting = False

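        # The media path may point to a video file or to a directory of pre-extracted frames.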
        media = sources[0].get("video") or sources[0].get("image")
        if media:
            path = os.path.join(self.mm_path, media)
            if os.path.isdir(path):
                # TODO: can we simplify this logic?
                frames, timestamps, duration_seconds = load_image_directory_as_frames(
                    folder_path=path,
                    source_fps=self.source_fps,
                    target_fps=self.target_fps,
                    max_frames=self.max_frames,
                )
            else:
                # skip_secs is how many seconds to skip at the start/end of the video before
                # sampling frames; those frames are often noisy, so it is better to skip them.
                # TODO: expose this as a data argument.
                frames, timestamps, duration_seconds = load_video(
                    path,
                    max_frames=self.max_frames,
                    target_fps=self.target_fps,
                    skip_secs=1.0,  # or data_args.skip_secs if you want
                )
    elif content_type == "image" or content_type == "multiimage":
        # Images and multi-image samples: resize to the image target size and enable image splitting.
        self.processor.image_processor.size = {"longest_edge": self.image_target_size}
        self.processor.image_processor.do_resize = True
        self.processor.image_processor.do_image_splitting = True

        media = sources[0].get("image", False)
        if media:
            if isinstance(media, str):
                media = [media]
            paths = [os.path.join(self.mm_path, m) for m in media]
            frames = [load_single_image(path) for path in paths]
        else:
            raise ValueError("No image found for sample")
    else:
        frames = None

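    # Deep-copy so that the placeholder replacement below does not mutate the cached dataset entries.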
    conversations = copy.deepcopy([e["conversations"] for e in sources])

    # Get the system message, allowing a per-sample override.
    system_message = DEFAULT_SYSTEM_MESSAGE
    for k, v in sources[0].items():
        if isinstance(k, str) and "system" in k.lower() and "message" in k.lower() and isinstance(v, str):
            system_message = v
            break

    # Ensure each conversation has a system turn at index 0.
    for conv in conversations:
        system_idx = next((i for i, t in enumerate(conv) if t.get("from", "").lower() == "system"), None)
        if system_idx is not None:
            # Move the existing system turn to index 0.
            conv.insert(0, conv.pop(system_idx))
        else:
            # If there is no system turn, add one.
            conv.insert(0, {"from": "system", "value": system_message})
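
    # Convert LLaVA-style turns to the OpenAI chat format and expand the multimodal
    # placeholders using the loaded frames (and, for video, their timestamps).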
    conversations = [[self._convert_llava_to_openai_format(turn) for turn in conversation] for conversation in conversations]
    conversations = [self._replace_multimodal_tokens(conversation, content_type, frames, timestamps) for conversation in conversations]

    if self.add_media_intro_outro:
        for conv in conversations:
            if content_type == "text":
                continue
            elif content_type == "image" or content_type == "multiimage":
                if conv[1]['content'][0]['type'] == "image":
                    conv[1]['content'].insert(0, {'type': 'text', 'text': DEFAULT_IMAGE_INTRO})
            elif content_type == "video":
                if conv[1]['content'][0]['type'] == "image" or (
                    conv[1]['content'][0]['type'] == "text" and FRAME_TIMESTAMP_MESSAGE in conv[1]['content'][0]['text']
                ):
                    # conv[1]['content'].insert(0, {'type': 'text', 'text': DEFAULT_VIDEO_INTRO})
                    conv[1]['content'].insert(0, {
                        'type': 'text',
                        'text': DEFAULT_VIDEO_INTRO.format(
                            frame_count=num2words(len(frames)),
                            video_duration=str(datetime.timedelta(seconds=duration_seconds)),
                        ),
                    })
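
            # Find the last image entry across all turns so the outro can be placed right after it.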
            target_message_index = -1
            last_image_index = -1
            for i, message in enumerate(conv):
                if 'content' in message:
                    for j, content in enumerate(message['content']):
                        if content.get('type') == 'image':
                            target_message_index = i
                            last_image_index = j

            # If we found an image, insert the outro right after it in the content list.
            if target_message_index != -1 and last_image_index != -1:
                conv[target_message_index]['content'].insert(
                    last_image_index + 1, {'type': 'text', 'text': DEFAULT_MEDIA_OUTTRO}
                )

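    # Render the first conversation with the chat template and tokenize it together with the frames.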
    text_input = self.processor.apply_chat_template(conversations[0], add_generation_prompt=False)
    encoded = self.processor(
        text=text_input,
        images=frames,
        return_tensors="pt",
        padding=False,
    )

    if encoded["input_ids"][0].size(0) > self.processor.tokenizer.model_max_length:
        raise ValueError(
            f"Sequence length {encoded['input_ids'][0].size(0)} exceeds maximum "
            f"{self.processor.tokenizer.model_max_length}"
        )

    # Each encoded tensor has shape [1, seq_len]; take the single row.
    input_ids = encoded["input_ids"][0]
    attention_mask = encoded["attention_mask"][0]

    # Labels start as a copy of input_ids; special tokens, and optionally system/user
    # turns, are then masked out so they do not contribute to the loss.
    labels = input_ids.clone()
    self._mask_special_tokens(input_ids, labels)
    if self.mask_system_tokens:
        _mask_system_tokens(input_ids, labels, self.tokenizer)
    if self.mask_user_tokens:
        _mask_user_tokens(input_ids, labels, self.tokenizer)

    out = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
    if "pixel_values" in encoded:
        out["pixel_values"] = encoded["pixel_values"][0]
    return out