def model()

in model.py [0:0]


def model(sess, hps, train_iterator, test_iterator, data_init):
    """Build the flow model graph and wrap it via `abstract_model_xy`.

    Args:
        sess: session used by the attached `sample` / `encode` / `decode`
            helpers and forwarded to `abstract_model_xy`.
        hps: hyperparameter namespace. Read here: image_size, n_bits_x, n_y,
            weight_y, ycond, direct_iterator, inference. NOTE: this function
            also writes `hps.n_bins` and `hps.top_shape` as side effects.
        train_iterator, test_iterator: forwarded to `abstract_model_xy`;
            `f_loss` pulls batches from them directly when
            `hps.direct_iterator` is set.
        data_init: forwarded unchanged to `abstract_model_xy`.

    Returns:
        The object `m` returned by `abstract_model_xy`, extended with:
          m.eps_std -- float32 placeholder of shape [None] (sampling std),
          m.sample(y, eps_std) -> uint8 images,
        and, only when `hps.inference` is true:
          m.encode(x, y) -> numpy array [batch, total_eps_size] of flattened eps,
          m.decode(y, feps) -> uint8 images.
    """

    # Placeholders used for init, sampling and encode/decode; the loss reads
    # batches from the iterators directly when hps.direct_iterator is set.
    with tf.name_scope('input'):
        X = tf.placeholder(
            tf.uint8, [None, hps.image_size, hps.image_size, 3], name='image')
        Y = tf.placeholder(tf.int32, [None], name='label')
        lr = tf.placeholder(tf.float32, None, name='learning_rate')

    encoder, decoder = codec(hps)
    # Number of discrete levels per subpixel after bit-depth reduction.
    hps.n_bins = 2. ** hps.n_bits_x

    def preprocess(x):
        """Map uint8 pixels to floats in roughly [-0.5, 0.5), at n_bits_x bits."""
        x = tf.cast(x, 'float32')
        if hps.n_bits_x < 8:
            # Drop the low (8 - n_bits_x) bits of every subpixel.
            x = tf.floor(x / 2 ** (8 - hps.n_bits_x))
        x = x / hps.n_bins - .5
        return x

    def postprocess(x):
        """Inverse of `preprocess`: map floats back to uint8 pixels in [0, 255]."""
        return tf.cast(tf.clip_by_value(tf.floor((x + .5)*hps.n_bins)*(256./hps.n_bins), 0, 255), 'uint8')

    def _one_hot(y):
        """Float32 one-hot encoding of integer labels over hps.n_y classes."""
        return tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

    def _dequantize(x):
        """Turn discrete pixels into continuous inputs plus an initial objective.

        Returns `(z, objective)`. `objective` is a per-example tensor that
        starts at -log(n_bins) times the number of subpixels. Shared by the
        training loss and the encoder so both apply the exact same input
        transformation.
        """
        # [:, 0, 0, 0] keeps one zero per example (x is rank 4).
        objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
        z = preprocess(x)
        # Uniform noise spanning one quantization bin; this makes the
        # result (and therefore `encode`) stochastic.
        z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
        objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])
        return z, objective

    def _f_loss(x, y, is_training, reuse=False):
        """Per-example generative and (optional) classification losses.

        `is_training` is accepted for interface compatibility but unused here.

        Returns:
            bits_x: negated objective, in bits per subpixel.
            bits_y: label cross-entropy in bits, or zeros when the
                classifier is disabled (weight_y <= 0 or not ycond).
            classification_error: 1.0 for misclassified examples else 0.0,
                or all ones when the classifier is disabled.
        """
        with tf.variable_scope('model', reuse=reuse):
            y_onehot = _one_hot(y)

            # Discrete -> Continuous
            z, objective = _dequantize(x)

            # Encode (squeeze factor 2, e.g. 32x32x3 -> 16x16x12; presumed
            # from Z.squeeze2d -- confirm)
            z = Z.squeeze2d(z, 2)
            z, objective, _ = encoder(z, objective)

            # Prior; the top latent shape is stored for the sampling path.
            hps.top_shape = Z.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, hps)
            objective += logp(z)

            # Generative loss
            nobj = - objective
            bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
                x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

            # Predictive loss
            if hps.weight_y > 0 and hps.ycond:

                # Classification loss from spatially averaged top latent
                h_y = tf.reduce_mean(z, axis=[1, 2])
                y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
                bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_onehot, logits=y_logits) / np.log(2.)

                # Classification accuracy
                y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
                classification_error = 1 - \
                    tf.cast(tf.equal(y_predicted, y), tf.float32)
            else:
                bits_y = tf.zeros_like(bits_x)
                classification_error = tf.ones_like(bits_x)

        return bits_x, bits_y, classification_error

    def f_loss(iterator, is_training, reuse=False):
        """Scalar training loss plus globally reduced stats.

        Uses `iterator` batches when hps.direct_iterator is set and an
        iterator is given, otherwise the X/Y placeholders.

        Returns:
            (mean of bits_x + weight_y * bits_y,
             Z.allreduce_mean of the batch means of
             [local_loss, bits_x, bits_y, classification_error]).
        """
        if hps.direct_iterator and iterator is not None:
            x, y = iterator.get_next()
        else:
            x, y = X, Y

        bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse)
        local_loss = bits_x + hps.weight_y * bits_y
        stats = [local_loss, bits_x, bits_y, pred_loss]
        global_stats = Z.allreduce_mean(
            tf.stack([tf.reduce_mean(i) for i in stats]))

        return tf.reduce_mean(local_loss), global_stats

    feeds = {'x': X, 'y': Y}
    m = abstract_model_xy(sess, hps, feeds, train_iterator,
                          test_iterator, data_init, lr, f_loss)

    # === Sampling function
    def f_sample(y, eps_std):
        """Graph that samples images for labels `y` with prior std `eps_std`."""
        with tf.variable_scope('model', reuse=True):
            y_onehot = _one_hot(y)

            _, sample, _ = prior("prior", y_onehot, hps)
            z = sample(eps_std=eps_std)
            z = decoder(z, eps_std=eps_std)
            # Inverse of the squeeze in _f_loss (presumed from Z.unsqueeze2d)
            z = Z.unsqueeze2d(z, 2)
            x = postprocess(z)

        return x

    m.eps_std = tf.placeholder(tf.float32, [None], name='eps_std')
    x_sampled = f_sample(Y, m.eps_std)

    def sample(_y, _eps_std):
        """Run the sampling graph; returns a uint8 image batch."""
        return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std})
    m.sample = sample

    if hps.inference:
        # === Encoder-Decoder functions
        def f_encode(x, y, reuse=True):
            """Graph returning the list of eps tensors for (x, y).

            The encoder's eps come first; the prior's eps is appended last.
            Inputs are dequantized with random noise, so encoding is stochastic.
            """
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = _one_hot(y)

                # Discrete -> Continuous
                z, objective = _dequantize(x)

                # Encode
                z = Z.squeeze2d(z, 2)
                z, objective, eps = encoder(z, objective)

                # Prior
                hps.top_shape = Z.int_shape(z)[1:]
                logp, _, _eps = prior("prior", y_onehot, hps)
                objective += logp(z)
                eps.append(_eps(z))

            return eps

        def f_decode(y, eps, reuse=True):
            """Graph mapping an eps list (layout as in f_encode) to uint8 images."""
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = _one_hot(y)

                _, sample, _ = prior("prior", y_onehot, hps)
                z = sample(eps=eps[-1])
                z = decoder(z, eps=eps[:-1])
                z = Z.unsqueeze2d(z, 2)
                x = postprocess(z)

            return x

        enc_eps = f_encode(X, Y)
        # One placeholder per encoded eps tensor, same shape (incl. batch dim).
        dec_eps = [
            tf.placeholder(tf.float32, _eps.get_shape().as_list(),
                           name="dec_eps_" + str(i))
            for i, _eps in enumerate(enc_eps)]
        dec_x = f_decode(Y, dec_eps)

        # Per-example shapes of each eps tensor (batch dim dropped).
        eps_shapes = [_eps.get_shape().as_list()[1:] for _eps in enc_eps]

        def flatten_eps(eps):
            """Concatenate a list of eps arrays into one [BS, eps_size] array."""
            return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)

        def unflatten_eps(feps):
            """Split a [BS, eps_size] array back into per-tensor eps arrays."""
            index = 0
            eps = []
            bs = feps.shape[0]
            for shape in eps_shapes:
                # int(): np.prod([]) is the float 1.0, which cannot be a slice bound.
                size = int(np.prod(shape))
                eps.append(np.reshape(feps[:, index: index + size], (bs, *shape)))
                index += size
            return eps

        # If model is unconditional, always pass y = np.zeros([bs], dtype=np.int32)
        def encode(x, y):
            """Encode uint8 images into a flattened [BS, eps_size] numpy array."""
            return flatten_eps(sess.run(enc_eps, {X: x, Y: y}))

        def decode(y, feps):
            """Decode a flattened [BS, eps_size] eps array into uint8 images."""
            eps = unflatten_eps(feps)
            feed_dict = {Y: y}
            feed_dict.update(zip(dec_eps, eps))
            return sess.run(dec_x, feed_dict)

        m.encode = encode
        m.decode = decode

    return m