vision/smolvlm2/scripts/create_mixture.py (115 lines of code) (raw):
"""
Usage:
python create_mixture.py input.yaml --output_yaml updated.yaml \
--text_percent 14 --image_percent 50 --multiimage_percent 5 --video_percent 36 \
--short_video_factor 0.5 --mammoth_fraction 0.3
"""
import yaml
import argparse
import re
import sys
from collections import defaultdict
# -------------------------------------------------------------------
# Helper functions for sampling logic.
# -------------------------------------------------------------------
def always_full(dataset_entry):
    """
    Return True if the dataset entry should always be sampled at 100%.

    The "name", "path" and "json_path" fields are searched case-insensitively
    for the substrings "gpt", "visualwebinstruct" or "vision_flan". Matching is
    substring-based, so a name such as "sharegpt" also qualifies through "gpt".

    Fields that are missing, null (an empty YAML value loads as None) or not
    strings are skipped instead of raising.
    """
    always_full_keywords = ("gpt", "visualwebinstruct", "vision_flan")
    for key in ("name", "path", "json_path"):
        field = dataset_entry.get(key)
        # Only text can contain a keyword; None and other types are ignored.
        if not isinstance(field, str):
            continue
        # Lowercase once per field rather than once per keyword.
        lowered = field.lower()
        if any(kw in lowered for kw in always_full_keywords):
            return True
    return False
def base_sampling_fraction(dataset_entry, text_percent, image_percent, multiimage_percent, video_percent, short_video_factor):
    """
    Compute the overall target fraction for a dataset entry from its modality.

    The *_percent arguments are percentages (e.g. 14 means 14%) and are
    converted to fractions by dividing by 100.

    Rules, checked in this order:
      - modality "text": text_percent.
      - modality "video": video_percent, multiplied by short_video_factor when
        a pattern like _X_Y_s is found in the name (or, failing that, in the
        json_path) and its upper bound Y is <= 60. Without a pattern, or with
        Y > 60, the full video_percent is used.
      - modality "image": image_percent.
      - modality "multiimage", or "multiimage" appearing in the name:
        multiimage_percent.
      - anything else: 1.0 (100%).

    The modality, name and json_path fields are compared case-insensitively;
    missing or null fields are treated as empty strings.
    """
    # `or ""` also covers keys that are present in the YAML but have no value
    # (they load as None, which has no .lower()).
    modality = (dataset_entry.get("modality") or "").lower()
    name = (dataset_entry.get("name") or "").lower()
    json_path = (dataset_entry.get("json_path") or "").lower()

    if modality == "text":
        return text_percent / 100.0

    if modality == "video":
        # Look for a pattern like _X_Y_s; the name takes precedence over json_path.
        pattern = r'_(\d+)_([\d]+)_s'
        match = re.search(pattern, name) or re.search(pattern, json_path)
        if match and int(match.group(2)) <= 60:
            # Upper duration bound of at most 60 -> treated as a short video.
            return (video_percent / 100.0) * short_video_factor
        # Long video, or no duration pattern at all: full video_percent.
        return video_percent / 100.0

    if modality == "image":
        return image_percent / 100.0

    # Checked after "image", so an image entry whose name mentions
    # "multiimage" still uses image_percent.
    if modality == "multiimage" or "multiimage" in name:
        return multiimage_percent / 100.0

    # Unknown modality: default to 100%.
    return 1.0
def format_sampling(fraction):
    """
    Build the sampling_strategy string for a fraction between 0 and 1.

    The fraction is scaled to a percentage and rounded to two decimals,
    so 0.33 yields "random:33.0%".
    """
    rounded_percent = round(fraction * 100, 2)
    return "random:{}%".format(rounded_percent)
# -------------------------------------------------------------------
# Functions for aligning datasets using the 'name' field.
# -------------------------------------------------------------------
def compute_aligned_name(dataset_entry):
    """
    Derive the "aligned name" used to group entries, based on the name field.

    Names are expected to look like "mammoth:sharegpt" or "onevision:sharegpt".
    Everything after the first ":" (stripped and lowercased) is the aligned
    name. A name without a colon is used whole, stripped and lowercased.
    """
    raw_name = dataset_entry.get("name", "").strip()
    # partition splits on the first colon only: "source:aligned_name".
    _source, colon, remainder = raw_name.partition(":")
    if not colon:
        return raw_name.lower()
    return remainder.strip().lower()
def is_onevision(dataset_entry):
    """
    Tell whether the entry comes from OneVision.

    An entry qualifies when its name, lowercased and stripped, begins with
    "onevision:".
    """
    prefix = "onevision:"
    normalized = dataset_entry.get("name", "").lower().strip()
    return normalized.startswith(prefix)
def is_mammoth(dataset_entry):
    """
    Tell whether the entry comes from MammothVL.

    An entry qualifies when its name, lowercased and stripped, begins with
    "mammoth:".
    """
    prefix = "mammoth:"
    normalized = dataset_entry.get("name", "").lower().strip()
    return normalized.startswith(prefix)
# -------------------------------------------------------------------
# YAML load/write functions.
# -------------------------------------------------------------------
def load_yaml(file_path):
    """Read the UTF-8 YAML file at file_path and return what yaml.safe_load parses from it."""
    with open(file_path, mode='r', encoding='utf-8') as stream:
        parsed = yaml.safe_load(stream)
    return parsed
def write_yaml(data, file_path):
    """Dump data as YAML to file_path (UTF-8), keeping keys in their existing order."""
    with open(file_path, mode='w', encoding='utf-8') as stream:
        # sort_keys=False keeps mapping keys in insertion order.
        yaml.safe_dump(data, stream=stream, sort_keys=False)
# -------------------------------------------------------------------
# Main update logic.
# -------------------------------------------------------------------
def update_sampling_strategies(data, text_percent, image_percent, multiimage_percent, video_percent, short_video_factor, mammoth_fraction):
    """
    Update each dataset entry's sampling_strategy in place and return data.

    Rules:
      - If always_full applies, use "random:100%".
      - Otherwise, compute an overall target fraction using modality rules.
      - Entries are grouped by aligned name (e.g. "sharegpt"). When a group
        contains both a OneVision and a MammothVL entry, the target is split:
            OneVision gets overall_fraction * (1 - mammoth_fraction)
            MammothVL gets overall_fraction * mammoth_fraction
        Entries in such a group that come from neither source keep the
        overall fraction.
      - In every other group, entries simply use the overall fraction.

    Prints an error and exits with status 1 when data is not a mapping with a
    top-level 'datasets' list (this includes an empty YAML document, which
    loads as None).
    """
    # Check the container type first: `"datasets" not in None` would raise
    # TypeError instead of producing the intended error message.
    if not isinstance(data, dict) or not isinstance(data.get("datasets"), list):
        print("Error: YAML must contain a top-level 'datasets' list.")
        sys.exit(1)
    datasets = data["datasets"]

    # Group dataset entries by aligned name.
    groups = defaultdict(list)
    for ds in datasets:
        groups[compute_aligned_name(ds)].append(ds)

    # Update each entry.
    for group_entries in groups.values():
        # Split only when both sources are really present. A group made of
        # duplicates, or of one prefixed and one un-prefixed entry, has
        # nothing to share the target with and must not be scaled down.
        shared = (
            any(is_onevision(ds) for ds in group_entries)
            and any(is_mammoth(ds) for ds in group_entries)
        )
        for ds in group_entries:
            if always_full(ds):
                ds["sampling_strategy"] = "random:100%"
                continue
            overall_fraction = base_sampling_fraction(
                ds, text_percent, image_percent, multiimage_percent, video_percent, short_video_factor
            )
            if shared and is_onevision(ds):
                final_fraction = overall_fraction * (1 - mammoth_fraction)
            elif shared and is_mammoth(ds):
                final_fraction = overall_fraction * mammoth_fraction
            else:
                # Not a shared group, or the source is unclear: use overall.
                final_fraction = overall_fraction
            ds["sampling_strategy"] = format_sampling(final_fraction)
    return data
def main():
    """
    Command-line entry point.

    Parses the options, loads the input YAML, rewrites the sampling strategies,
    writes the result to the output path and prints a confirmation line.
    """
    parser = argparse.ArgumentParser(description=(
        "Update dataset sampling strategies in a YAML file. For datasets that are shared (i.e. have the same aligned name), "
        "split the overall target between the OneVision and MammothVL versions based on --mammoth_fraction."
    ))
    parser.add_argument("input_yaml", help="Path to input YAML file.")
    parser.add_argument("--output_yaml", default="updated_datasets.yaml", help="Path to output YAML file.")

    # All tuning knobs are float options: (flag, default, help text).
    float_options = (
        ("--text_percent", 14.0, "Target percentage for text datasets."),
        ("--image_percent", 50.0, "Target percentage for image datasets."),
        ("--multiimage_percent", 5.0, "Target percentage for multi-image datasets."),
        ("--video_percent", 36.0, "Target percentage for video datasets."),
        ("--short_video_factor", 0.5,
         "Multiplier for video datasets with an upper duration bound <= 60 seconds (as determined by the _X_Y_s pattern)."),
        ("--mammoth_fraction", 0.3,
         "Fraction of the overall shared dataset to assign to the MammothVL version (the remainder goes to OneVision)."),
    )
    for flag, default_value, help_text in float_options:
        parser.add_argument(flag, type=float, default=default_value, help=help_text)

    options = parser.parse_args()

    source_data = load_yaml(options.input_yaml)
    result = update_sampling_strategies(
        source_data,
        text_percent=options.text_percent,
        image_percent=options.image_percent,
        multiimage_percent=options.multiimage_percent,
        video_percent=options.video_percent,
        short_video_factor=options.short_video_factor,
        mammoth_fraction=options.mammoth_fraction,
    )
    write_yaml(result, options.output_yaml)
    print(f"[INFO] Updated YAML written to {options.output_yaml}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()