in train.py [0:0]
def model(train=False):
    """Build the model graph and return its generative loss.

    Variables are created under the 'model' variable scope, using
    ``f32_storage_getter`` as the custom getter. Augmentation is applied only
    when ``H.rand_augment`` is set and ``train`` is true. In that case
    RandAugment runs on the CPU for each image of the batch. The (possibly
    augmented) pixels are then used both as the network input and as the loss
    target.

    Args:
        train: whether the graph is built for training. Enables RandAugment
            (together with ``H.rand_augment``) and is forwarded to ``stack`` and
            ``get_logits``.

    Returns:
        The ``(gen_loss, gen_losses)`` pair produced by ``get_losses``.

    Raises:
        NotImplementedError: when augmenting with both
            ``H.rand_augment_conditioning`` and
            ``H.use_unconditional_augmentation`` set.
        ValueError: when augmenting with ``H.rand_augment_conditioning`` and no
            used entry of ``H.self_gen_types`` has ``sos_name == 'sos_aa'``.
    """
    with tf.variable_scope('model', custom_getter=f32_storage_getter):
        network_input = H.X_ph
        network_target = H.X_ph
        if H.rand_augment and train:
            # Augmentation treats each row as a flattened 32x32x3 image.
            assert network_input.shape[-1] == 3072, 'TODO: support other image sizes'
            network_input = tf.reshape(tf.cast(network_input, tf.uint8), [-1, 32, 32, 3])
            if H.rand_augment_conditioning:
                if H.use_unconditional_augmentation:
                    raise NotImplementedError
                # Column of H.Y_gen_ph holding the per-example "augment" flag.
                # NOTE(review): assumes Y_gen_ph columns follow the order of the
                # used self_gen_types -- confirm where Y_gen_ph is fed.
                rand_augment_idx = [t.sos_name for t in H.self_gen_types if t.is_used].index('sos_aa')
                batch = []
                with tf.device('/cpu:0'):
                    for i in range(H.n_batch):
                        example = network_input[i]
                        should_autoaugment = tf.cast(H.Y_gen_ph[i, rand_augment_idx], tf.bool)
                        # The augmentation is created inside the true branch on
                        # purpose. Ops built outside tf.cond's branch functions run
                        # unconditionally, so augmenting first would spend CPU on
                        # images whose result is then discarded.
                        # Default args pin the current image explicitly.
                        example = tf.cond(
                            should_autoaugment,
                            lambda img=example: distort_image_with_randaugment(
                                img, H.rand_augment_n, H.rand_augment_m),
                            lambda img=example: img)
                        batch.append(example)
                network_input = batch
            else:
                with tf.device('/cpu:0'):
                    network_input = [
                        distort_image_with_randaugment(network_input[i], H.rand_augment_n, H.rand_augment_m)
                        for i in range(H.n_batch)
                    ]
            # Each entry is expected to be a [32, 32, 3] image (assumes
            # distort_image_with_randaugment preserves the input shape).
            # Concatenating along axis 0 keeps each image's 3072 values
            # contiguous, so the reshape restores one flattened row per example,
            # in the original order.
            network_input = tf.cast(tf.reshape(tf.concat(network_input, axis=0), [-1, 3072]), H.X_ph.dtype)
            # The loss target is the augmented image, not the raw placeholder.
            network_target = network_input
        h = stack(network_input, H.X_emb_ph, train=train)
        h = norm('final_norm', h, epsilon=1e-6)
        gen_logits = get_logits('gen_logits', h, H.n_vocab, train=train)
        gen_loss, gen_losses = get_losses(gen_logits, network_target)
        return gen_loss, gen_losses