in scripts/super_res_sample.py [0:0]
def load_data_for_worker(base_samples, batch_size, class_cond):
    """Yield an endless stream of low-resolution batches for this rank.

    The archive at ``base_samples`` is opened with ``bf.BlobFile`` and read
    with ``np.load``; images are taken from key ``"arr_0"`` and, when
    ``class_cond`` is true, labels from key ``"arr_1"``. Samples are sharded
    across distributed workers by striding over the image indices: this
    worker takes indices ``rank, rank + world_size, rank + 2 * world_size, ...``.

    The generator never terminates: after reaching the end of the archive it
    starts over. A partially filled batch is NOT dropped at the end of a pass;
    its samples are carried over and completed from the start of the next pass.

    Args:
        base_samples: Path of the sample archive, passed to ``bf.BlobFile``.
        batch_size: Number of samples per yielded batch; must be >= 1.
        class_cond: If true, also load labels and add them to each batch.

    Yields:
        dict with key ``"low_res"`` holding a float tensor built from
        ``batch_size`` stacked images, rescaled with ``x / 127.5 - 1.0`` and
        permuted with ``(0, 3, 1, 2)``; when ``class_cond`` is true, also key
        ``"y"`` holding the stacked labels as a tensor.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1, or if this worker's
            shard is empty (its rank is not below the number of images).
            Both conditions previously made the generator loop forever
            without ever yielding. Because this is a generator, the error is
            raised on the first ``next()`` call, not when the function is
            called.
    """
    if batch_size < 1:
        # A batch can never reach a size below 1, so the loop would never yield.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    with bf.BlobFile(base_samples, "rb") as f:
        obj = np.load(f)
        # Read the arrays while the underlying file is still open.
        image_arr = obj["arr_0"]
        label_arr = obj["arr_1"] if class_cond else None

    rank = dist.get_rank()
    num_ranks = dist.get_world_size()

    if rank >= len(image_arr):
        # range(rank, len(image_arr), num_ranks) would be empty on every pass,
        # turning the ``while True`` below into a silent busy loop.
        raise ValueError(
            f"no samples available for rank {rank}: archive holds "
            f"{len(image_arr)} image(s) shared by {num_ranks} worker(s)"
        )

    # Buffers intentionally live outside the loops so that leftovers from one
    # pass over the data roll over into the next pass.
    buffer = []
    label_buffer = []
    while True:
        for i in range(rank, len(image_arr), num_ranks):
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = th.from_numpy(np.stack(buffer)).float()
                # Presumably maps 8-bit pixel values 0..255 onto [-1, 1] --
                # assumes the archive stores images in that range; TODO confirm.
                batch = batch / 127.5 - 1.0
                # Moves the last axis to position 1 (looks like NHWC -> NCHW;
                # confirm the stored layout against the producer of the file).
                batch = batch.permute(0, 3, 1, 2)
                res = dict(low_res=batch)
                if class_cond:
                    res["y"] = th.from_numpy(np.stack(label_buffer))
                yield res
                buffer, label_buffer = [], []