megatron_patch/data/energon/chatml.py (46 lines of code) (raw):

# NOTE: add a license
import pickle
import re
import warnings
from dataclasses import dataclass
from typing import List, Union

import torch
from webdataset.autodecode import Decoder, imagehandler

from megatron.energon.epathlib.epath import EPath
from megatron.energon.flavors.base_dataset import Sample
from megatron.energon.flavors.webdataset import DefaultDecoderWebdatasetFactory


@dataclass
class ChatMLSample(Sample):
    """Multi-turn complex samples with images and videos."""

    # Decoded images; presumably filled from the ``jpgs`` field that
    # NestedImagesHandler decodes -- confirm against the dataset field map.
    imgs: List[torch.Tensor]
    # Decoded videos, one list of decoded frames per video; presumably filled
    # from the ``videos`` field -- confirm against the dataset field map.
    videos: List[List[torch.Tensor]]
    # JSON string of GPT-format conversations.
    conversation: str


class NestedImagesHandler:
    """Decode pickled collections of encoded images for webdataset.

    Handles two field extensions (matched case-insensitively):

    * ``jpgs``   -- a pickled list of encoded images, decoded to a list.
    * ``videos`` -- a pickled list of videos, each a list of encoded frames,
      decoded to a list of lists.

    Each individual image/frame is passed to webdataset's ``imagehandler``
    with the key ``"jpg"``.
    """

    def __init__(self, imagespec):
        """Create a nested image handler.

        :param imagespec: short string indicating the type of decoding; it is
            forwarded unchanged to ``webdataset.autodecode.imagehandler``.
        """
        self.extensions = ['jpgs', 'videos']
        # Key handed to the inner image handler for every element of the
        # corresponding nested field.
        self.extensions_mapping = {
            "jpgs": "jpg",
            "videos": "jpg",
        }
        self.image_handler = imagehandler(imagespec)

    def __call__(self, key, data):
        """Perform nested image decoding.

        :param key: field name / file name; only the text after the last dot
            is inspected, case-insensitively.
        :param data: binary data holding a pickled (nested) list of encoded
            images.
        :returns: the decoded list (``jpgs``) or list of lists (``videos``),
            or ``None`` when the extension is not handled here, following the
            webdataset decoder-handler convention.
        """
        # Lowercase once so the membership test and the mapping lookup agree;
        # previously e.g. "sample.JPGS" passed the check and then raised
        # KeyError on the mapping lookup.
        extension = re.sub(r".*[.]", "", key).lower()
        if extension not in self.extensions:
            return None
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # data. Only use this handler on datasets from trusted sources.
        data = pickle.loads(data)
        image_key = self.extensions_mapping[extension]
        if extension == 'jpgs':
            return [self.image_handler(image_key, d) for d in data]
        return [
            [self.image_handler(image_key, frame) for frame in video]
            for video in data
        ]


class ChatMLWebdataset(DefaultDecoderWebdatasetFactory[ChatMLSample]):
    """Energon webdataset factory whose sample type is :class:`ChatMLSample`."""

    __sample_type__ = ChatMLSample

    def __init__(self, path: EPath, *, auto_decode: bool = True, **kwargs):
        """Create the factory.

        :param path: dataset path, forwarded to the base factory.
        :param auto_decode: when true, replace the decoder with one that also
            decodes the nested ``jpgs`` / ``videos`` fields.
        :param kwargs: forwarded unchanged to the base factory.
        """
        super().__init__(path, auto_decode=auto_decode, **kwargs)
        if auto_decode:
            # NOTE(review): ``self.image_decode`` and ``self._video_decoder``
            # are presumably set by the base factory's __init__ -- confirm.
            self._decoder = Decoder(
                [
                    imagehandler(self.image_decode),
                    NestedImagesHandler(self.image_decode),
                    self._video_decoder,
                ]
            )