def get_dataset_size()

in src/chug/wds/dataset_info.py


import ast
import json
import os

# `expand_urls` is expected to be available in this module; it expands a shard
# specification into a list of concrete shard paths plus weights (unused here).

def get_dataset_size(shards):
    # Expand the shard spec into concrete shard paths; weights are ignored.
    shardlist, _ = expand_urls(shards)
    dir_path = os.path.dirname(shardlist[0])

    # Two metadata conventions are supported in the shard directory:
    #   sizes.json -- maps each shard basename to its per-shard sample count
    #   __len__    -- a single literal holding the total sample count
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')

    if os.path.exists(sizes_filename):
        with open(sizes_filename, 'r') as f:
            sizes = json.load(f)
        total_size = sum(int(sizes[os.path.basename(shard)]) for shard in shardlist)
    elif os.path.exists(len_filename):
        with open(len_filename, 'r') as f:
            total_size = ast.literal_eval(f.read())
    else:
        total_size = None  # number of samples is undefined

    num_shards = len(shardlist)

    return total_size, num_shards
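
A minimal usage sketch. The import path is assumed from the file location above, `expand_urls` is assumed to perform standard brace expansion of the shard spec, and the directory layout and sample counts are invented for illustration:

import json
import os
import tempfile

from chug.wds.dataset_info import get_dataset_size  # assumed import path

with tempfile.TemporaryDirectory() as tmp:
    # sizes.json maps each shard basename to its per-shard sample count.
    with open(os.path.join(tmp, 'sizes.json'), 'w') as f:
        json.dump({'train-0000.tar': 1000, 'train-0001.tar': 500}, f)

    # The shard files themselves need not exist; only the metadata is read.
    # Assumes expand_urls brace-expands specs like {0000..0001}.
    spec = os.path.join(tmp, 'train-{0000..0001}.tar')
    total_size, num_shards = get_dataset_size(spec)
    print(total_size, num_shards)  # expected: 1500 2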