in distilvit/utils.py [0:0]
def cached_ds(cache_name):
    """Return a decorator that caches a function's dataset result on disk.

    The decorated function is expected to return an object exposing
    ``save_to_disk(path)``. Its result is stored under
    ``<cache_dir>/<cache_name>`` and, on later calls, reloaded with
    ``datasets.load_from_disk`` instead of calling the function again.

    Cache settings are read from the ``args`` *keyword* argument of the
    decorated call, when present:

    * ``args.cache_dir``   -- directory that holds the cached datasets.
    * ``args.prune_cache`` -- when truthy, an existing cache entry is
      deleted and the dataset is rebuilt.

    When no ``args`` keyword is passed (including when it is passed
    positionally), ``cache_dir`` is ``".cache"`` and pruning is disabled.

    Args:
        cache_name: Name of the cache entry inside the cache directory.

    Returns:
        A decorator that wraps a dataset-building function with the
        load-or-build-and-save behaviour described above.
    """

    def _cached_ds(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if "args" in kwargs:
                cache_dir = kwargs["args"].cache_dir
                prune_cache = kwargs["args"].prune_cache
            else:
                cache_dir = ".cache"
                prune_cache = False

            cache_path = os.path.join(cache_dir, cache_name)

            # exist_ok avoids the check-then-create race: another process
            # may create the directory between an exists() test and
            # makedirs(), which would otherwise raise FileExistsError.
            os.makedirs(cache_dir, exist_ok=True)

            if os.path.exists(cache_path):
                if prune_cache:
                    print("Pruning cache...")
                    shutil.rmtree(cache_path)
                else:
                    # Imported lazily so `datasets` is only needed when a
                    # cached entry is actually loaded.
                    from datasets import load_from_disk

                    return load_from_disk(cache_path)

            # Cache miss (or freshly pruned): build, persist, and return.
            ds = func(*args, **kwargs)
            ds.save_to_disk(cache_path)
            return ds

        return wrapper

    return _cached_ds