sweep.py (32 lines of code) (raw):
import itertools
import os
import subprocess
import sys
from typing import List, Union

import fire
def _normalize_model_sizes(model_sizes) -> List[str]:
    """Return ``model_sizes`` as a list of non-empty, whitespace-stripped strings.

    Accepts a comma-separated string, a list/tuple, or a single scalar value.
    """
    if isinstance(model_sizes, str):
        model_sizes = model_sizes.split(",")
    elif not isinstance(model_sizes, (list, tuple)):
        # A CLI parser may hand over a single non-string scalar; wrap it.
        model_sizes = [model_sizes]
    # subprocess arguments must be str, so coerce anything else (e.g. ints).
    stripped = [str(size).strip() for size in model_sizes]
    # Drop empties produced by inputs like "a,,b" or a trailing comma.
    return [size for size in stripped if size]


def main(model_sizes: Union[List[str], str], **kwargs):
    """Sweep ``train_simple.py`` over ground-truth and weak-to-strong runs.

    First runs ``train_simple.py --model_size S`` for every size ``S``. Then,
    for every pair ``(weak, strong)`` where ``weak`` does not come after
    ``strong`` in ``model_sizes`` (including ``weak == strong``), runs
    ``train_simple.py --weak_model_size weak --model_size strong``.

    Each run is a separate subprocess of the current interpreter. The script
    it runs is the ``train_simple.py`` next to this file. The sweep stops at
    the first run that exits non-zero.

    Args:
        model_sizes: Model sizes as a list/tuple or as one comma-separated
            string (e.g. ``"gpt2,gpt2-medium"``). Surrounding whitespace and
            empty entries are ignored.
        **kwargs: Extra options forwarded to every run as ``--key str(value)``.

    Raises:
        ValueError: If ``model_size``, ``weak_model_size`` or
            ``weak_labels_path`` is given in ``kwargs``, or if no model sizes
            remain after parsing.
        subprocess.CalledProcessError: If any training run exits non-zero.
    """
    sizes = _normalize_model_sizes(model_sizes)

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    # The sweep controls these per-run options itself. weak_labels_path is
    # presumably derived by train_simple.py from weak_model_size
    # (TODO confirm).
    if (
        "weak_model_size" in kwargs
        or "model_size" in kwargs
        or "weak_labels_path" in kwargs
    ):
        raise ValueError("Need to use model_sizes when using sweep.py")
    if not sizes:
        raise ValueError("model_sizes must contain at least one model size")

    basic_args = [sys.executable, os.path.join(os.path.dirname(__file__), "train_simple.py")]
    for key, value in kwargs.items():
        basic_args.extend([f"--{key}", str(value)])

    # flush=True: the child writes directly to the inherited stdout. Without
    # flushing, our buffered progress lines can appear after the child's
    # output when stdout is a pipe or a file.
    print("Running ground truth models", flush=True)
    for model_size in sizes:
        subprocess.run(basic_args + ["--model_size", model_size], check=True)

    print("Running transfer models", flush=True)
    # Yields (sizes[i], sizes[j]) for every i <= j, in the same order as
    # nested index loops `for i ...: for j in range(i, n)`.
    for weak_model_size, strong_model_size in itertools.combinations_with_replacement(sizes, 2):
        print(f"Running weak {weak_model_size} to strong {strong_model_size}", flush=True)
        subprocess.run(
            basic_args
            + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
            check=True,
        )
if __name__ == "__main__":
    # Run as a script: python-fire turns `main`'s parameters into CLI flags,
    # e.g. `python sweep.py --model_sizes=<size1>,<size2> [--option=value ...]`.
    # Any extra flags land in **kwargs and are forwarded to train_simple.py.
    fire.Fire(component=main)