in scripts/pre_encode.py [0:0]
def distribute_shards(start_shard_all, end_shard_all, slurm_ntasks):
    """Split an inclusive shard range into contiguous per-task sub-ranges.

    The range ``[start_shard_all, end_shard_all]`` is divided into
    ``slurm_ntasks`` consecutive pieces. Each piece gets
    ``total // slurm_ntasks`` shards, and the first ``total % slurm_ntasks``
    pieces get one extra shard so the remainder is spread over the
    lowest-indexed tasks.

    Args:
        start_shard_all: Index of the first shard (inclusive).
        end_shard_all: Index of the last shard (inclusive).
        slurm_ntasks: Number of pieces to produce; must be >= 1.
            (Presumably the Slurm task count, with list position
            corresponding to the task's proc id -- confirm against caller.)

    Returns:
        A list of ``slurm_ntasks`` tuples ``(start_shard, end_shard)``, both
        ends inclusive, in ascending order. A piece that receives zero shards
        (more tasks than shards) is represented as an empty range with
        ``end_shard == start_shard - 1``.

    Raises:
        ValueError: If ``slurm_ntasks`` is less than 1.
    """
    # Validate explicitly: a zero count would otherwise surface as an opaque
    # ZeroDivisionError, and a negative one as a failed (or, under ``-O``,
    # skipped) internal assertion.
    if slurm_ntasks < 1:
        raise ValueError(
            f"slurm_ntasks must be a positive integer, got {slurm_ntasks!r}"
        )

    total_shards = end_shard_all - start_shard_all + 1
    base_count, left_over_shards = divmod(total_shards, slurm_ntasks)

    distributed_shards = []
    start_shard = start_shard_all
    for slurm_procid in range(slurm_ntasks):
        # The first `left_over_shards` tasks absorb the remainder, one each.
        count = base_count + 1 if slurm_procid < left_over_shards else base_count
        end_shard = start_shard + count - 1
        distributed_shards.append((start_shard, end_shard))
        # The next piece begins right after this one ends.
        start_shard = end_shard + 1

    # Internal invariant: the pieces tile the whole range exactly.
    assert start_shard == end_shard_all + 1
    return distributed_shards