in src/open-r1-multimodal/src/open_r1/grpo_rec.py [0:0]
def __init__(self, data_path: str, script_args: GRPOScriptArguments):
    """Load, sample and index the training records listed in a YAML manifest.

    ``data_path`` must point to a ``.yaml`` file of the form::

        datasets:
          - json_path: xxxx1.json
            sampling_strategy: first:1000
          - json_path: xxxx2.jsonl
            sampling_strategy: end:3000
          - json_path: xxxx3.json
            sampling_strategy: random:10%

    Each ``json_path`` is read as a JSON document (``.json``) or as one JSON
    value per line (``.jsonl``; blank lines are skipped). The loaded records
    are appended, in manifest order, to ``self.list_data_dict``.

    ``sampling_strategy`` has the form ``<name>[:<count>]``:

    * ``<name>`` is one of ``all``, ``first``, ``end`` or ``random``.
    * ``<count>`` is a non-negative absolute number of samples, or a
      percentage of the file such as ``"10%"`` or ``"12.5%"``, rounded up.

    A missing or null strategy means ``all``. A name given without a count
    keeps every record.

    After loading, record indices are split into two lists:

    * ``self.image_indices`` holds records whose ``"image"`` field is
      non-empty.
    * ``self.text_indices`` holds all other records. An empty or null
      ``"image"`` field is removed from those records.

    Args:
        data_path: Path to the YAML manifest.
        script_args: Script arguments, stored as ``self.script_args``.

    Raises:
        ValueError: If ``data_path`` or a ``json_path`` has an unsupported
            extension, the manifest has no ``datasets`` list, an entry has
            no ``json_path``, or a sampling strategy is malformed or unknown.
    """
    super(LazySupervisedDataset, self).__init__()
    self.script_args = script_args
    self.list_data_dict = []

    if not data_path.endswith(".yaml"):
        raise ValueError(f"Unsupported file type: {data_path}")

    # Parse the manifest and close it before loading the data files it lists.
    with open(data_path, "r", encoding="utf-8") as file:
        yaml_data = yaml.safe_load(file)
    datasets = yaml_data.get("datasets") if isinstance(yaml_data, dict) else None
    if not isinstance(datasets, list):
        raise ValueError(f"Expected a 'datasets' list in {data_path}")

    valid_strategies = ("all", "first", "end", "random")
    for data in datasets:
        json_path = data.get("json_path")
        if not json_path:
            raise ValueError(f"Dataset entry without 'json_path' in {data_path}: {data}")
        # A missing or null strategy means "keep everything". str() turns a
        # bare YAML number into a string so it is rejected below with a
        # clear error instead of a TypeError.
        sampling_strategy = str(data.get("sampling_strategy") or "all")
        sampling_number = None

        if json_path.endswith(".jsonl"):
            with open(json_path, "r", encoding="utf-8") as json_file:
                # Skip blank lines (e.g. a trailing newline); json.loads("")
                # would raise.
                cur_data_dict = [json.loads(line) for line in json_file if line.strip()]
        elif json_path.endswith(".json"):
            with open(json_path, "r", encoding="utf-8") as json_file:
                cur_data_dict = json.load(json_file)
        else:
            raise ValueError(f"Unsupported file type: {json_path}")

        if ":" in sampling_strategy:
            sampling_strategy, count = sampling_strategy.split(":", 1)
            count = count.strip()
            if count.startswith("-"):
                raise ValueError(
                    f"Invalid sampling count {count!r} for {json_path}: must be non-negative"
                )
            if count.endswith("%"):
                # Keep the multiply-then-divide order: p * n / 100 is exact
                # for whole percentages, whereas p / 100 * n can round up
                # wrongly (e.g. 7 / 100 * 100 == 7.000000000000001).
                sampling_number = math.ceil(float(count[:-1]) * len(cur_data_dict) / 100)
            else:
                sampling_number = int(count)
        sampling_strategy = sampling_strategy.strip()

        # A typo in the strategy name must not silently load the whole file.
        if sampling_strategy not in valid_strategies:
            raise ValueError(
                f"Unknown sampling strategy {sampling_strategy!r} for {json_path}; "
                f"expected one of {valid_strategies}"
            )

        # Apply the sampling strategy
        if sampling_number is not None:
            if sampling_strategy == "first":
                cur_data_dict = cur_data_dict[:sampling_number]
            elif sampling_strategy == "end":
                # Use an explicit start index: `[-0:]` would return the whole
                # list when sampling_number is 0.
                cur_data_dict = cur_data_dict[max(len(cur_data_dict) - sampling_number, 0):]
            elif sampling_strategy == "random":
                random.shuffle(cur_data_dict)
                cur_data_dict = cur_data_dict[:sampling_number]
        print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
        self.list_data_dict.extend(cur_data_dict)

    # Split record indices into samples that carry an image and text-only
    # samples.
    self.image_indices = []
    self.text_indices = []
    for idx, item in enumerate(self.list_data_dict):
        # Truthiness covers a missing key, None, "" and [].
        if item.get("image"):
            self.image_indices.append(idx)
        else:
            # Drop an empty or null "image" field so that only image samples
            # carry the key.
            item.pop("image", None)
            self.text_indices.append(idx)