megatron_patch/data/json_sft.py

# Copyright (c) 2025 Alibaba PAI Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import io
import copy
import json
import torch

try:
    from megatron import get_args
except ImportError:
    from megatron.training import get_args

from datasets import load_dataset
from tqdm import tqdm

from megatron_patch.tokenizer import get_tokenizer


class JSONSFTDataset(torch.utils.data.Dataset):
    """
    Experimental: a conversation (SFT) dataset for arbitrary models whose
    tokenizer provides a default chat template. Not tested on all cases.
    """

    def __init__(self, path, max_padding_length, split='train'):
        super().__init__()
        self.tokenizer = get_tokenizer()
        assert hasattr(self.tokenizer, 'apply_chat_template'), \
            "The SFT-Raw Dataset is valid for tokenizers with chat template, please provide a template."
        self.IGNORE_INDEX = self.tokenizer.pad_token_id
        self.eos_token_id = self.tokenizer.eos_token_id
        self.is_pad_token_eos_token = self.tokenizer.pad_token_id == self.eos_token_id
        self.max_padding_length = max_padding_length

        # `path` is a sequence of data files; only the first entry is used here.
        list_data_dict = load_dataset(
            'json',
            data_files=path[0],
            split=split,
        )

        train_dataset = list_data_dict.map(
            self.preprocess,
            batched=True,
            batch_size=3000,
            num_proc=16,
            remove_columns=list_data_dict.column_names,
            load_from_cache_file=False,
            desc="Running Encoding"
        )

        self.input_ids = np.array(train_dataset['input_ids'])
        self.labels = np.array(train_dataset['labels'])
        self.samples = []
        for inputs, labels in tqdm(zip(self.input_ids, self.labels)):
            self.samples.append([inputs, labels])

        print(' >> total number of samples: {}'.format(len(self.samples)))

    def _make_r_io_base(self, f, mode: str):
        if not isinstance(f, io.IOBase):
            f = open(f, mode=mode, encoding='utf-8')
        return f

    def jload(self, f, mode='r'):
        """
        Load a .json file into a dictionary.

        Args:
            f: The file object or string representing the file path.
            mode: The mode in which to open the file (e.g., 'r', 'w', 'a').

        Returns:
            A dictionary containing the contents of the JSON file.
        """
        f = self._make_r_io_base(f, mode)
        jdict = json.load(f)
        f.close()
        return jdict

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]
        return self.gpt_convert_example_to_feature(raw_sample)

    def preprocess(self, examples):
        """
        Preprocess the data by tokenizing.

        Args:
            examples: a batch of columns as provided by
                `datasets.Dataset.map(batched=True)`, containing an
                `instruction` (or `query`) column, an optional `input`
                column, and an `output` (or `content`/`response`) column.

        Returns:
            dict: a dictionary containing the input_ids and labels for the examples
        """
        datas = []
        if 'instruction' in examples:
            datas = [examples['instruction']]
        elif 'query' in examples:
            datas = [examples['query']]
        else:
            raise ValueError('Cannot find key instruction or query!')

        if 'input' in examples:
            datas.append(examples['input'])

        if 'output' in examples:
            datas.append(examples['output'])
        elif 'content' in examples:
            datas.append(examples['content'])
        elif 'response' in examples:
            datas.append(examples['response'])
        else:
            raise ValueError('Cannot find output key `output`, `content` or `response`!')

        input_ids = []
        labels = []
        for data in zip(*datas):
            # All columns except the last form the user prompt; the last
            # column is the assistant response.
            text = [
                {'role': 'user', 'content': ''.join(data[:-1])},
                {'role': 'assistant', 'content': data[-1]}
            ]
            source = self.tokenizer.apply_chat_template(text[:-1])
            full = self.tokenizer.apply_chat_template(text)
            for t1, t2 in zip(source, full):
                assert t1 == t2, "The user input_ids are not a prefix of the full input_ids! Please check the template."

            # Drop samples whose prompt alone already fills the maximum length.
            if len(source) >= self.max_padding_length:
                continue
            if len(full) > self.max_padding_length:
                full = full[:self.max_padding_length]
            elif self.is_pad_token_eos_token:
                # When pad == eos, mark the trailing eos with a negative
                # surrogate value so it is not mistaken for padding later.
                assert full[-1] == self.eos_token_id, f"Assume any untruncated sample ends with <eos>! But got: {self.tokenizer.detokenize(full)}"
                full[-1] = -1 - full[-1]
            if self.max_padding_length > len(full):
                full = full + [self.IGNORE_INDEX] * (self.max_padding_length - len(full))

            # NOTE: in get_batch_on_this_tp_rank_original, tokens use [:-1] and labels use [1:]
            # we add an extra token to use the old api
            # TODO: update get_batch_on_this_tp_rank_original and replace the following line with
            # label = [self.IGNORE_INDEX] * (len(source) - 1) + full[len(input_ids):] + [self.IGNORE_INDEX]
            full = full + [self.IGNORE_INDEX]
            label = [self.IGNORE_INDEX] * len(source) + full[len(source):]
            input_ids.append(full)
            labels.append(label)

        return dict(input_ids=input_ids, labels=labels)

    def gpt_convert_example_to_feature(self, sample):
        """
        Convert a single (input_ids, labels) sample into the dict format
        expected by GPT training.
        """
        input_ids, labels = sample
        train_sample = {
            'input_ids': input_ids,
            'labels': labels
        }
        return train_sample
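

# A minimal usage sketch (illustrative only, not part of the module): it assumes
# Megatron's global args and the patched tokenizer have already been initialized
# inside a training script, and that 'path/to/sft.json' is a hypothetical JSON
# file with `instruction`/`output` fields.
#
#     dataset = JSONSFTDataset(['path/to/sft.json'], max_padding_length=2048)
#     sample = dataset[0]  # {'input_ids': ..., 'labels': ...}
#     print(len(dataset), len(sample['input_ids']))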