in fairdiplomacy/models/diplomacy_model/train_sl.py [0:0]
# Assumed module-level imports (this snippet omits train_sl.py's import block);
# get_sl_db_args, Dataset, and logger are project-internal fairdiplomacy names.
import glob
import os

import torch
from google.protobuf.json_format import MessageToDict


def get_datasets_from_cfg(args):
    """Returns a 3-tuple (train_set, val_set, extra_val_sets)"""
    # Memoize torch.load so a cache file referenced by multiple cfg fields is
    # deserialized only once.
    cache = {}

    def cached_torch_load(fpath):
        if fpath not in cache:
            cache[fpath] = torch.load(fpath)
        return cache[fpath]
    # search for data and create train/val splits
    if args.data_cache and os.path.exists(args.data_cache):
        logger.info(f"Found dataset cache at {args.data_cache}")
        train_dataset, val_dataset = cached_torch_load(args.data_cache)
    else:
        dataset_params = args.dataset_params
        assert args.metadata_path is not None
        assert dataset_params.data_dir is not None
        game_metadata, min_rating, train_game_ids, val_game_ids = get_sl_db_args(
            args.metadata_path, args.min_rating_percentile, args.max_games, args.val_set_pct
        )
        dataset_params_dict = MessageToDict(dataset_params, preserving_proto_field_name=True)
        train_dataset = Dataset(
            game_ids=train_game_ids,
            game_metadata=game_metadata,
            min_rating=min_rating,
            **dataset_params_dict,
        )
        train_dataset.preprocess()
        val_dataset = Dataset(
            game_ids=val_game_ids,
            game_metadata=game_metadata,
            min_rating=min_rating,
            **dataset_params_dict,
        )
        val_dataset.preprocess()
        if args.data_cache:
            logger.info(f"Saving datasets to {args.data_cache}")
            torch.save((train_dataset, val_dataset), args.data_cache)

    logger.info(f"Train dataset: {train_dataset.stats_str()}")
    logger.info(f"Val dataset: {val_dataset.stats_str()}")
    # possibly append more data caches to train/val with various cfg args
    train_dataset = [train_dataset]
    val_dataset = [val_dataset]

    # extra_train_data_caches: only the train split (t) of each cache is added
    if args.extra_train_data_caches:
        for path in args.extra_train_data_caches:
            train_dataset.append(cached_torch_load(path)[0])
            logger.info(f"Append train dataset: {train_dataset[-1].stats_str()}")
    # glob_append_data_cache: t and v are added to their respective datasets
    if args.glob_append_data_cache:
        for path in glob.glob(args.glob_append_data_cache):
            t, v = cached_torch_load(path)
            train_dataset.append(t)
            logger.info(f"Append train dataset: {train_dataset[-1].stats_str()}")
            if v is not None:
                val_dataset.append(v)
                logger.info(f"Append val dataset: {val_dataset[-1].stats_str()}")
    # glob_append_data_cache_as_val: both t and v are added to the val set
    if args.glob_append_data_cache_as_val:
        for path in glob.glob(args.glob_append_data_cache_as_val):
            t, v = cached_torch_load(path)
            for x in [t, v]:
                if x is not None:
                    val_dataset.append(x)
                    logger.info(f"Append val dataset: {val_dataset[-1].stats_str()}")
    # concat datasets
    train_dataset = (
        Dataset.from_merge(train_dataset) if len(train_dataset) > 1 else train_dataset[0]
    )
    val_dataset = Dataset.from_merge(val_dataset) if len(val_dataset) > 1 else val_dataset[0]
    logger.info(f"Final dataset lens: train={len(train_dataset)} val={len(val_dataset)}")
    # extra val data caches, returned separately; only the val split (index 1)
    # of each cache is used
    extra_val_datasets = {}
    for name, path in args.extra_val_data_caches.items():
        extra_val_datasets[name] = cached_torch_load(path)[1]
        logger.info(f"Extra val dataset ({name}): {extra_val_datasets[name].stats_str()}")

    # Clear the cache so datasets that were loaded but are no longer
    # referenced can be garbage-collected.
    cache = {}
    return train_dataset, val_dataset, extra_val_datasets
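
The dict-plus-closure memoizer in cached_torch_load can also be expressed with the standard library. A minimal sketch (not code from the repo), assuming only the torch dependency:

import functools

import torch


@functools.lru_cache(maxsize=None)
def cached_torch_load(fpath):
    # torch.load runs at most once per path; repeat calls return the cached object.
    return torch.load(fpath)

One trade-off: the closure version frees its memory by rebinding cache = {}, whereas the lru_cache version holds references until cached_torch_load.cache_clear() is called.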
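
A hypothetical caller sketch showing how the returned 3-tuple might be consumed; summarize_datasets is an illustrative name, not a function from the repo, and it relies only on len() and the dict of extra val sets, both established above:

def summarize_datasets(args):
    train_set, val_set, extra_val_sets = get_datasets_from_cfg(args)
    print(f"train={len(train_set)} val={len(val_set)}")
    for name, dataset in extra_val_sets.items():
        # each entry is the val split of one extra data cache, keyed by cfg name
        print(f"extra val ({name}): len={len(dataset)}")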