def get_datasets_from_cfg()

in fairdiplomacy/models/diplomacy_model/train_sl.py [0:0]


def get_datasets_from_cfg(args):
    """Returns a 3-tuple (train_set, val_set, extra_val_sets)"""
    cache = {}

    # Memoize torch.load so a cache file referenced by more than one cfg
    # field is only deserialized from disk once.
    def cached_torch_load(fpath):
        if fpath not in cache:
            cache[fpath] = torch.load(fpath)
        return cache[fpath]

    # search for data and create train/val splits
    if args.data_cache and os.path.exists(args.data_cache):
        logger.info(f"Found dataset cache at {args.data_cache}")
        train_dataset, val_dataset = cached_torch_load(args.data_cache)
    else:
        dataset_params = args.dataset_params
        assert args.metadata_path is not None
        assert dataset_params.data_dir is not None
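        # Read the metadata DB and build the train/val split: per-game
        # metadata, the rating cutoff implied by min_rating_percentile,
        # and the game-id lists for each split.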
        game_metadata, min_rating, train_game_ids, val_game_ids = get_sl_db_args(
            args.metadata_path, args.min_rating_percentile, args.max_games, args.val_set_pct
        )
        dataset_params_dict = MessageToDict(dataset_params, preserving_proto_field_name=True)

        train_dataset = Dataset(
            game_ids=train_game_ids,
            game_metadata=game_metadata,
            min_rating=min_rating,
            **dataset_params_dict,
        )
        train_dataset.preprocess()
        val_dataset = Dataset(
            game_ids=val_game_ids,
            game_metadata=game_metadata,
            min_rating=min_rating,
            **dataset_params_dict,
        )
        val_dataset.preprocess()

        if args.data_cache:
            logger.info(f"Saving datasets to {args.data_cache}")
            torch.save((train_dataset, val_dataset), args.data_cache)
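            # Note the cache layout: a pickled (train, val) tuple, so
            # cached_torch_load(path)[0] is a train split and [1] a val
            # split (which may be None for train-only caches).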

    logger.info(f"Train dataset: {train_dataset.stats_str()}")
    logger.info(f"Val dataset: {val_dataset.stats_str()}")

    # optionally append more data caches to the train/val sets, controlled by cfg args
    train_dataset = [train_dataset]
    val_dataset = [val_dataset]

    # only the train split (t) of each cache gets appended, to the train set
    if args.extra_train_data_caches:
        for path in args.extra_train_data_caches:
            train_dataset.append(cached_torch_load(path)[0])
            logger.info(f"Append train dataset: {train_dataset[-1].stats_str()}")

    # the train and val splits (t, v) get appended to their respective datasets
    if args.glob_append_data_cache:
        for path in glob.glob(args.glob_append_data_cache):
            t, v = cached_torch_load(path)
            train_dataset.append(t)
            logger.info(f"Append train dataset: {train_dataset[-1].stats_str()}")
            if v is not None:
                val_dataset.append(v)
                logger.info(f"Append val dataset: {val_dataset[-1].stats_str()}")

    # both the train and val splits (t, v) get appended to the val set
    if args.glob_append_data_cache_as_val:
        for path in glob.glob(args.glob_append_data_cache_as_val):
            t, v = cached_torch_load(path)
            for x in [t, v]:
                if x is not None:
                    val_dataset.append(x)
                    logger.info(f"Append val dataset: {val_dataset[-1].stats_str()}")

    # concat datasets
    train_dataset = (
        Dataset.from_merge(train_dataset) if len(train_dataset) > 1 else train_dataset[0]
    )
    val_dataset = Dataset.from_merge(val_dataset) if len(val_dataset) > 1 else val_dataset[0]
    logger.info(f"Final dataset lens: train={len(train_dataset)} val={len(val_dataset)}")

    # extra val data caches, returned separately
    extra_val_datasets = {}
    for name, path in args.extra_val_data_caches.items():
        extra_val_datasets[name] = cached_torch_load(path)[1]
        logger.info(f"Extra val dataset ({name}): {extra_val_datasets[name].stats_str()}")

    # Clear the cache to release references to the loaded (train, val) tuples.
    cache = {}

    return train_dataset, val_dataset, extra_val_datasets
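
For reference, a minimal invocation sketch. Everything below is illustrative: the real train_sl.py builds `args` from a protobuf config (note the MessageToDict call above), and all field names are copied from the reads in the function body while the values are hypothetical placeholders.

from types import SimpleNamespace

# Hypothetical stand-in for the real cfg object; values are invented.
args = SimpleNamespace(
    data_cache="/tmp/sl_data_cache.pt",   # (train, val) tuple loaded/saved here
    metadata_path="games/metadata.json",  # hypothetical metadata DB path
    dataset_params=dataset_params,        # proto message built elsewhere; data_dir must be set
    min_rating_percentile=0.5,
    max_games=None,
    val_set_pct=0.01,
    extra_train_data_caches=[],           # cache paths; train splits appended to train
    glob_append_data_cache=None,          # glob; t/v appended to train/val
    glob_append_data_cache_as_val=None,   # glob; both splits appended to val
    extra_val_data_caches={},             # name -> cache path, returned separately
)

train_set, val_set, extra_val_sets = get_datasets_from_cfg(args)

Because every load goes through cached_torch_load, a path that appears in more than one of these fields is only deserialized from disk once per call.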