# in train.py [0:0]
def aug_eval(partition, epoch):
    """Evaluate the model's loss under every combination of self-generated
    augmentation types, saving per-combination losses (and, for non-test
    partitions, the augmented inputs) to .npz files under H.model_dir.

    For each augmentation-type tuple produced by go_over(...), the data is
    run through the corresponding generator functions, fed to the TF session,
    and the per-example losses are gathered across MPI ranks. Work is skipped
    for combinations whose losses file already exists (resumable).

    Args:
        partition: dataset split name passed to get_data; image dumps are
            skipped when it equals "test".
        epoch: recorded in the saved .npz metadata (not otherwise used).

    Side effects: writes "<desc>_<types>_losses.npz" (rank 0 only) and, for
    non-test partitions, "<desc>_<types>_imgs.npz"; logs via logprint/print.
    NOTE(review): relies on module globals H, sess, dataset, and the mpi_*
    helpers — all ranks must call this together (collective barriers/gathers).
    """
    tx, ty = get_data(partition)
    # Optionally truncate the evaluation set to a fixed number of examples.
    if H.aug_eval_n_examples is not None:
        tx = tx[:H.aug_eval_n_examples]
        if ty is not None:
            ty = ty[:H.aug_eval_n_examples]
    # Only evaluate generators that are actually enabled; if none are, fall
    # back to a single identity "augmentation" so the loop below still runs.
    gen_in_use = [gen for gen in H.self_gen_types if gen.is_used]
    if not gen_in_use:
        gen_in_use = [AugmentationType("sos", "identity", 1, True, identity)]
    # One choice count per generator; go_over presumably yields the cartesian
    # product of token indices — TODO confirm against go_over's definition.
    aug_choices = [gen.num_tokens for gen in gen_in_use]
    for aug_types in go_over(aug_choices):
        fname = os.path.join(
            H.model_dir,
            f"{H.desc}_" + "_".join(map(str, aug_types)) + "_losses.npz")
        # Resumability: skip combinations already evaluated in a prior run.
        if os.path.exists(fname):
            if mpi_rank() == 0:
                print(f" Evaluated {fname}")
            continue
        if mpi_rank() == 0:
            print(f"Evaluating {fname}")
        losses = []
        imgs = []
        for data in iter_data_mpi(tx, n_batch=H.n_batch, log=logprint,
                                  split_by_rank=dataset.full_dataset_valid):
            # Default feeds: raw batch and the (unbatched) base embedding.
            feeds = {H.X_ph: data[0], H.X_emb_ph: H.x_emb}
            # Per-example copy of the embedding so generators can permute it.
            x_emb = np.concatenate([H.x_emb.copy() for _ in range(H.n_batch)], axis=0)
            d_in = data[0]
            if H.num_self_gen_in_use > 0:
                y_gen_list = []
                # Apply each enabled generator with its fixed token for this
                # combination; generators transform both data and embeddings.
                for aug_type, gen in zip(aug_types, gen_in_use):
                    if gen.sos_name == 'sos_data':
                        raise NotImplementedError("sos_data is not supported in aug_eval")
                    # Force the same augmentation token for every example.
                    yy = np.full((H.n_batch, 1), aug_type, dtype=np.int32)
                    d_in, x_emb, y_gen = gen.fn(d_in, x_emb, yy=yy)
                    # Sanity-check the generator echoed back the forced token.
                    assert (y_gen == yy).all()
                    y_gen_list.append(y_gen)
                feeds[H.X_ph] = d_in
                if H.permute_embeddings:
                    feeds[H.X_emb_ph] = x_emb
                if not H.use_unconditional_augmentation:
                    # Conditioning tokens: one column per generator.
                    feeds[H.Y_gen_ph] = np.concatenate(y_gen_list, axis=1)
                    assert (feeds[H.Y_gen_ph] == np.stack([aug_types] * H.n_batch)).all()
            # Keep the (possibly augmented) inputs for the optional image dump.
            imgs.append(d_in)
            # One loss value per example in the batch.
            cur_loss = sess.run(H.eval_gen_losses, feeds)
            assert cur_loss.shape[0] == H.n_batch
            losses.append(cur_loss)
        losses = np.concatenate(losses, axis=0).astype(np.float32)
        # Each rank saw an equal 1/world-size share of the examples.
        assert losses.shape[0] == tx.shape[0] // mpi_size()
        mpi_barrier()
        # Gather every rank's losses so all ranks hold the full array.
        losses = mpi_allgather(losses)
        assert losses.shape[0] == tx.shape[0]
        loss = losses.mean()
        # bpd = bits per dimension, assuming `loss` is nats — TODO confirm
        # the loss is already normalized per dimension.
        content = dict(epoch=epoch, aug_types=aug_types, loss=loss, bpd=loss / np.log(2.0))
        logprint(**content)
        content["losses"] = losses
        # Only rank 0 writes, to avoid concurrent writes to the same file.
        if mpi_rank() == 0:
            np.savez(fname, **content)
        imgs = np.concatenate(imgs, axis=0)
        assert imgs.shape[0] == tx.shape[0] // mpi_size()
        mpi_barrier()
        imgs = mpi_allgather(imgs)
        assert imgs.shape == tx.shape
        # Dump the augmented inputs themselves (skipped for the test split).
        if mpi_rank() == 0 and partition != "test":
            fname = os.path.join(H.model_dir, f"{H.desc}_" + "_".join(map(str, aug_types)) + "_imgs.npz")
            np.savez(fname, imgs=imgs.reshape(dataset.orig_shape))
        mpi_barrier()