def cached_ds()

in distilvit/utils.py [0:0]


def cached_ds(cache_name):
    """Build a decorator that caches a function's dataset result on disk.

    The decorated function is expected to return an object exposing
    ``save_to_disk(path)``. Its result is stored under
    ``<cache_dir>/<cache_name>`` and, on later calls, read back with
    ``datasets.load_from_disk`` instead of calling the function again.

    Cache settings are taken from the ``args`` *keyword* argument of the
    decorated call when it is present (its ``cache_dir`` and ``prune_cache``
    attributes are read). Otherwise the cache directory is ``".cache"`` and
    pruning is disabled. An ``args`` object passed positionally is not
    inspected.

    Args:
        cache_name: Name of the entry created inside the cache directory.

    Returns:
        A decorator that wraps the function with the caching behavior
        described above.
    """

    def _cached_ds(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if "args" in kwargs:
                cache_dir = kwargs["args"].cache_dir
                prune_cache = kwargs["args"].prune_cache
            else:
                cache_dir = ".cache"
                prune_cache = False

            # Named cache_path so it does not shadow the enclosing decorator.
            cache_path = os.path.join(cache_dir, cache_name)

            # exist_ok=True avoids the check-then-create race: another
            # process creating the directory between an exists() test and
            # makedirs() would otherwise raise FileExistsError.
            os.makedirs(cache_dir, exist_ok=True)

            if os.path.exists(cache_path):
                if prune_cache:
                    # Drop the stale entry and fall through to rebuild it.
                    print("Pruning cache...")
                    shutil.rmtree(cache_path)
                else:
                    # Imported lazily so the dependency is only needed when
                    # a cached entry is actually loaded.
                    from datasets import load_from_disk

                    return load_from_disk(cache_path)

            ds = func(*args, **kwargs)
            ds.save_to_disk(cache_path)
            return ds

        return wrapper

    return _cached_ds