in tools/sweep_setup.py [0:0]
def sweep_setup():
    """Samples cfgs for the sweep and writes them to the sweep directory.

    Reads its settings from the module-level ``sweep_cfg`` (``ROOT_DIR``,
    ``NAME``, ``NUM_PROC``, ``SWEEP_CFG_FILE`` and the ``SETUP`` sub-config).
    Under ``<ROOT_DIR>/<NAME>`` it writes:
      * ``sweep_cfg.yaml``    - a copy of ``sweep_cfg.SWEEP_CFG_FILE``
      * ``cfgs_summary.yaml`` - mapping from cfg file name to its cfg key
      * ``cfgs/NNNNNN.yaml``  - one file per kept cfg, written via ``dump_cfg``

    Cfgs are sampled by calling ``sample_cfgs`` with a distinct seed per call
    on a pool of worker processes, until either ``SETUP.NUM_SAMPLES`` samples
    have been drawn or ``SETUP.NUM_CONFIGS`` cfgs have been collected. The kept
    cfgs are then shuffled and truncated to at most ``SETUP.NUM_CONFIGS``.

    Raises:
        AssertionError: if ``<sweep_dir>/logs`` already exists.
        OSError: if the sweep cfg file cannot be copied.
    """
    # Function-scope import so this block does not depend on the file header.
    import shutil

    setup_cfg = sweep_cfg.SETUP
    # Create output directories
    sweep_dir = os.path.join(sweep_cfg.ROOT_DIR, sweep_cfg.NAME)
    cfgs_dir = os.path.join(sweep_dir, "cfgs")
    logs_dir = os.path.join(sweep_dir, "logs")
    print("Sweep directory is: {}".format(sweep_dir))
    # Raise explicitly (rather than `assert`) so the guard survives `python -O`
    if os.path.exists(logs_dir):
        raise AssertionError("Sweep already started: " + sweep_dir)
    if os.path.exists(cfgs_dir):
        # NOTE(review): files already present in cfgs_dir are not removed here,
        # so cfgs from a larger previous setup may be left behind - confirm
        # whether that is intended.
        print("Overwriting sweep which has not yet launched")
    # Creating cfgs_dir also creates sweep_dir (its parent)
    os.makedirs(cfgs_dir, exist_ok=True)
    # Dump the original sweep_cfg
    sweep_cfg_file = os.path.join(sweep_dir, "sweep_cfg.yaml")
    try:
        # shutil avoids the shell: safe for paths with spaces and raises on failure
        shutil.copy(sweep_cfg.SWEEP_CFG_FILE, sweep_cfg_file)
    except shutil.SameFileError:
        # The sweep cfg was loaded from its dumped location; nothing to copy
        pass
    n_proc, chunk = sweep_cfg.NUM_PROC, setup_cfg.CHUNK_SIZE
    # Create worker pool for sampling and saving configs; the context manager
    # shuts the workers down on exit, including when an error is raised
    with multiprocessing.Pool(n_proc) as process_pool:
        # Fix random number generator seed and generate per chunk seeds
        np.random.seed(setup_cfg.RNG_SEED)
        n_chunks = int(np.ceil(setup_cfg.NUM_SAMPLES / chunk))
        chunk_seeds = np.random.choice(1000000, size=n_chunks, replace=False)
        # Sample configs in chunks using multiple workers each with a unique seed
        info_str = "Number configs sampled: {}, configs kept: {} [t={:.2f}s]"
        n_samples, n_cfgs, i, cfgs, timer = 0, 0, 0, {}, Timer()
        while n_samples < setup_cfg.NUM_SAMPLES and n_cfgs < setup_cfg.NUM_CONFIGS:
            timer.tic()
            seeds = chunk_seeds[i * n_proc : i * n_proc + n_proc]
            cfgs_all = process_pool.map(sample_cfgs, seeds)
            # Merge in place; on duplicate keys the most recent cfg wins
            for chunk_cfgs in cfgs_all:
                cfgs.update(chunk_cfgs)
            # The last slice may hold fewer than n_proc seeds, so count the
            # seeds actually used (assumes each sample_cfgs call draws
            # CHUNK_SIZE configs - verify against sample_cfgs)
            n_samples += chunk * len(seeds)
            n_cfgs, i = len(cfgs), i + 1
            timer.toc()
            print(info_str.format(n_samples, n_cfgs, timer.total_time))
        # Randomize cfgs order and subsample if oversampled
        keys, cfgs = list(cfgs.keys()), list(cfgs.values())
        n_cfgs = min(n_cfgs, setup_cfg.NUM_CONFIGS)
        ids = np.random.choice(len(cfgs), n_cfgs, replace=False)
        keys, cfgs = [keys[j] for j in ids], [cfgs[j] for j in ids]
        # Save the cfgs and a cfgs_summary
        timer.tic()
        cfg_names = ["{:06}.yaml".format(j) for j in range(n_cfgs)]
        cfgs_summary = dict(zip(cfg_names, keys))
        with open(os.path.join(sweep_dir, "cfgs_summary.yaml"), "w") as f:
            # Infinite width keeps each summary entry on a single line
            yaml.dump(cfgs_summary, f, width=float("inf"))
        cfg_files = [os.path.join(cfgs_dir, cfg_name) for cfg_name in cfg_names]
        process_pool.starmap(dump_cfg, zip(cfg_files, cfgs))
        timer.toc()
        print(info_str.format(n_samples, n_cfgs, timer.total_time))