in scripts/run_prediction.py [0:0]
def run_prediction(sweep_dir, force):
    """Submit one slurm prediction job per trained model under *sweep_dir*.

    Scans the sweep directory for model directories whose training has
    finished, submits a ``workers.Predictor`` job for each (in a single
    slurm batch), starts a Beholder to monitor them, waits for completion
    via ``utils.handle_jobs``, and returns the finished jobs' results.

    Args:
        sweep_dir: Path-like root of the sweep; model subdirectories and
            ``clustering.pkl`` are expected directly underneath it.
        force: When True, re-run prediction even for models already marked
            as done by ``utils.prediction_done``.

    Returns:
        List of ``job.results()`` values, one per finished job.

    Raises:
        ValueError: If a model directory contains neither
            ``train_cfg.yaml`` nor ``cfg_rank_0.yaml``.
    """
    sweep_path = Path(sweep_dir)
    clustering_path = sweep_path / 'clustering.pkl'
    # Candidates: model directories whose training has already completed.
    models_paths = [p for p in sweep_path.iterdir()
                    if utils.is_model(p) and utils.train_done(p)]
    prediction_cfg = OmegaConf.create(PREDICTION_CFG)
    root = os.getenv('DATASETS_ROOT')
    # NOTE(review): 'ImageNat' looks like a typo for 'ImageNet' — confirm
    # against the dataset names load_datasets actually accepts.
    name = 'ImageNat'
    datasets = load_datasets(root=root,
                             name=name,
                             clustering_path=clustering_path
                             )
    # deepcopy so executor setup cannot mutate the shared slurm config.
    executor = utils.get_slurm_executor(copy.deepcopy(prediction_cfg.slurm),
                                        log_folder=str(sweep_path / 'logs' / 'run_prediction'))
    jobs, paths = [], []
    # Submit everything inside one batch context so slurm receives the
    # whole set of jobs as a single array submission.
    with executor.batch():
        for model_path in models_paths:
            if utils.prediction_done(model_path) and not force:
                print(f'{model_path} is done. Skipping.')
                continue
            # The training config filename differs between run variants;
            # try both known names before giving up.
            if (model_path / 'train_cfg.yaml').is_file():
                train_cfg = utils.load_cfg(model_path / 'train_cfg.yaml')
            elif (model_path / 'cfg_rank_0.yaml').is_file():
                train_cfg = utils.load_cfg(model_path / 'cfg_rank_0.yaml')
            else:
                err_msg = 'train config not found'
                raise ValueError(err_msg)
            Algorithm = utils.load_model_cls(train_cfg)
            worker_args = (
                prediction_cfg,
                train_cfg,
                Algorithm,
                datasets['train']['in'],
                datasets['val']['in'])
            worker = workers.Predictor()
            job = executor.submit(worker, *worker_args)
            jobs.append(job)
            paths.append(model_path)
            # Trace file lets external tooling see a pending prediction.
            utils.write_trace('prediction.pending', dir_=str(model_path))
    # Monitor the submitted jobs alongside their model directories.
    beholder = utils.Beholder(list(zip(jobs, paths)), stem='prediction')
    beholder.start()
    finished_jobs, jobs = utils.handle_jobs(jobs)
    # Collecting results
    jobs_results = [job.results() for job in finished_jobs]
    return jobs_results