def model()

in train.py [0:0]


def _rand_augment_network_input(network_input):
    """Apply RandAugment to a batch of flattened 32x32x3 images.

    The input is cast to ``uint8`` and reshaped to ``[-1, 32, 32, 3]``. Then each
    of the first ``H.n_batch`` images is passed through
    ``distort_image_with_randaugment`` on the CPU. When
    ``H.rand_augment_conditioning`` is set, an image is only augmented if its
    ``'sos_aa'`` column in ``H.Y_gen_ph`` is non-zero. Otherwise every image is
    augmented.

    NOTE(review): only ``H.n_batch`` images are processed; this presumably
    equals the batch size of ``network_input`` -- TODO confirm.

    Args:
        network_input: tensor whose last dimension must be 3072 (32 * 32 * 3).

    Returns:
        The processed batch reshaped to ``[-1, 3072]`` and cast back to
        ``H.X_ph.dtype``.

    Raises:
        NotImplementedError: if the last dimension is not 3072, or if both
            ``H.rand_augment_conditioning`` and
            ``H.use_unconditional_augmentation`` are set.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # and a mismatched width would then be silently mis-reshaped below.
    # Written as `not (... == ...)` rather than `!=` because TF1 `Dimension`
    # comparisons return None for an unknown size, and `!=` would let it through.
    if not (network_input.shape[-1] == 3072):
        raise NotImplementedError('TODO: support other image sizes')
    images = tf.reshape(tf.cast(network_input, tf.uint8), [-1, 32, 32, 3])
    n, m = H.rand_augment_n, H.rand_augment_m

    if H.rand_augment_conditioning:
        if H.use_unconditional_augmentation:
            raise NotImplementedError
        # Column of H.Y_gen_ph that flags whether an example is augmented.
        rand_augment_idx = [t.sos_name for t in H.self_gen_types if t.is_used].index('sos_aa')
        processed = []
        with tf.device('/cpu:0'):
            for i in range(H.n_batch):
                image = images[i]
                should_augment = tf.cast(H.Y_gen_ph[i, rand_augment_idx], tf.bool)
                # RandAugment is built inside the true branch. Tensors created
                # outside tf.cond's branch functions execute whatever the
                # predicate is, so building it beforehand wasted CPU work on
                # every non-augmented example. Default args bind this iteration's image.
                processed.append(tf.cond(
                    should_augment,
                    lambda image=image: distort_image_with_randaugment(image, n, m),
                    lambda image=image: image,
                ))
    else:
        with tf.device('/cpu:0'):
            processed = [distort_image_with_randaugment(images[i], n, m) for i in range(H.n_batch)]

    # Assuming distort_image_with_randaugment preserves the [32, 32, 3] shape,
    # concatenating on axis 0 and flattening to 3072 columns restores one row
    # per example.
    return tf.cast(tf.reshape(tf.concat(processed, axis=0), [-1, 3072]), H.X_ph.dtype)


def model(train=False):
    """Build the generative model graph under variable scope ``'model'``.

    Variables are created through the ``f32_storage_getter`` custom getter.

    Args:
        train: whether the graph is built for training. It is forwarded to
            ``stack`` and ``get_logits``. Together with ``H.rand_augment``, it
            also enables RandAugment on the inputs.

    Returns:
        ``(gen_loss, gen_losses)`` from ``get_losses`` on the ``'gen_logits'``
        head. The targets are the network input itself, i.e. the augmented
        images when augmentation is applied, otherwise ``H.X_ph``.

    Raises:
        NotImplementedError: propagated from ``_rand_augment_network_input``
            for unsupported input sizes or augmentation configurations.
    """
    with tf.variable_scope('model', custom_getter=f32_storage_getter):
        network_input = H.X_ph
        if H.rand_augment and train:
            network_input = _rand_augment_network_input(network_input)
        # The loss targets are the same tensor fed to the network.
        network_target = network_input

        h = stack(network_input, H.X_emb_ph, train=train)
        h = norm('final_norm', h, epsilon=1e-6)
        gen_logits = get_logits('gen_logits', h, H.n_vocab, train=train)
        gen_loss, gen_losses = get_losses(gen_logits, network_target)
        return gen_loss, gen_losses