# main(args)
# From: domainbed_measures/compute_gen_correlations.py


def main(args):
    """Launch generalization-measure computation jobs for trained models.

    Reads the space-delimited job-done file, keeps only models matching the
    requested algorithm/dataset with status "done", then runs one job per
    (model, measure) pair — or one job per model when
    ``args.all_measures_one_job`` is set — either locally (debug / no slurm
    partition) or via submitit job arrays of at most ``MAX_JOBS_IN_ARRAY``
    entries. Blocks until every submitted job finishes and re-raises any
    job-side exception.

    Args:
        args: Parsed command-line namespace. Notable fields: ``run_dir``,
            ``job_done_file``, ``algorithm``, ``dataset``, ``measures``,
            ``debug``, ``slurm_partition``, ``max_num_jobs``,
            ``all_measures_one_job``, ``calc_variance``.

    Raises:
        ValueError: If ``args.run_dir`` is empty, or if the submitted-job
            sanity check fails.
    """
    if args.run_dir == "":
        raise ValueError(
            "Please provide a working directory for storing generalization measure values."
        )

    model_filters = {
        "algorithm": args.algorithm,
        "dataset": args.dataset,
        "status": "done",
    }

    models_and_info = pd.read_csv(
        args.job_done_file,
        delimiter=" ",
        names=["path", "algorithm", "dataset", "status"])

    # Keep only the rows matching every filter column.
    # (Loop variable renamed from `filter` to avoid shadowing the builtin.)
    for column, value in model_filters.items():
        models_and_info = models_and_info[models_and_info[column] == value]

    if args.job_str == '' and args.debug:
        job_str = 'debug'
    else:
        job_str = args.job_str

    # BUG FIX: the original used ``rstrip('.txt')``, which strips any
    # trailing run of the characters '.', 't', 'x' (e.g. "runs_t.txt" ->
    # "runs_"), not the ".txt" suffix. Remove the suffix explicitly.
    job_file_base = os.path.basename(args.job_done_file)
    if job_file_base.endswith('.txt'):
        job_file_base = job_file_base[:-len('.txt')]

    out_folder = os.path.join(
        args.run_dir, job_file_base, "%s_%s_%s_%s" %
        (job_str, args.dirty_ood_split, args.algorithm, args.dataset))

    logging.info("Using directory %s for storing runs", out_folder)

    gpus_per_node = 1 if args.device == "cuda" else 0

    if args.measures == "":
        # Empty means "compute everything the registry knows about".
        measures_to_compute = MeasureRegistry._VALID_MEASURES
    elif '.json' in args.measures:
        logging.info('Using measure list file %s', args.measures)
        with open(args.measures, 'r') as f:
            measures_to_compute = json.load(f)
    else:
        measures_to_compute = args.measures.split(",")

    if args.calc_variance:
        experiment_to_use = VarianceExperiment(
            dirty_ood_split=args.dirty_ood_split)
    else:
        experiment_to_use = Experiment(dirty_ood_split=args.dirty_ood_split)

    jobs = []

    model_paths = list(models_and_info["path"])
    if args.debug_model_path != "":
        # Debug escape hatch: run on a single explicitly-given checkpoint.
        model_paths = [args.debug_model_path]

    if args.all_measures_one_job:
        # One job per model; each job computes the full measure list.
        all_jobs = list(model_paths)
    else:
        # One job per (model, measure) pair.
        all_jobs = list(itertools.product(model_paths, measures_to_compute))

    current_idx = 0
    current_jobs_in_array = 0

    # Set random seed for reproducible slurm file directory names.
    random.seed(_RANDOM_SEED)
    # Each pass of this loop submits one job array; we never place more jobs
    # in an array than can be run concurrently (MAX_JOBS_IN_ARRAY).
    while current_idx < len(all_jobs):

        if args.max_num_jobs != -1 and current_idx >= args.max_num_jobs:
            break

        job_path = os.path.join(
            out_folder, 'slurm_files', ''.join(
                random.choices(string.ascii_lowercase + string.digits, k=10)))
        logging.info("Launching jobs with path %s", job_path)

        ex = submitit.AutoExecutor(
            job_path,
        )

        if args.slurm_partition != "":
            ex.update_parameters(
                slurm_partition=args.slurm_partition,
                gpus_per_node=gpus_per_node,
                cpus_per_task=4,
                nodes=1,
                timeout_min=args.slurm_timeout_min,
                slurm_mem=_DATASET_TO_MEMORY[args.dataset],
            )

        with ex.batch():
            for idx in range(current_idx, len(all_jobs)):
                if args.max_num_jobs != -1 and idx >= args.max_num_jobs:
                    break
                if args.all_measures_one_job:
                    path = all_jobs[idx]
                    measure = measures_to_compute
                else:
                    path, measure = all_jobs[idx]

                if args.debug or args.slurm_partition == "":
                    # Run synchronously in-process instead of via slurm.
                    experiment_to_use(path, measure, args.dataset, job_path)
                    if args.debug:
                        break
                else:
                    jobs.append(
                        ex.submit(experiment_to_use, path, measure,
                                  args.dataset, job_path))
                    current_jobs_in_array += 1

                if current_jobs_in_array >= MAX_JOBS_IN_ARRAY:
                    # Array is full: close it and start a new one at idx+1.
                    logging.info("Starting new job array..at %d", idx + 1)
                    current_idx = idx + 1
                    current_jobs_in_array = 0
                    break

                if len(jobs) > len(all_jobs):
                    raise ValueError(
                        "Submitted more jobs (%d) than work items (%d); "
                        "this indicates a bookkeeping bug." %
                        (len(jobs), len(all_jobs)))

            # After a debug break or loop exhaustion, resume after the last
            # index we processed (no-op after the MAX_JOBS_IN_ARRAY break,
            # which already set current_idx to the same value).
            current_idx = idx + 1
        if args.debug:
            break

    logging.info("Launching %d jobs for %s", len(jobs), args.dataset)

    # Poll until every submitted job reports completion.
    start_time = time.time()
    while not all(j.done() for j in jobs):
        time.sleep(_SLEEP_TIME)
        jobs_done = sum(j.done() for j in jobs)
        # +1 guards against division by zero before the first job finishes.
        logging.info("%d/%d jobs done (%f sec per job)",
                     jobs_done, len(jobs),
                     (time.time() - start_time) / (jobs_done + 1))

    # Propagate any exception raised inside a job.
    _ = [j.result() for j in jobs]