def sample_model()

in train.py


def sample_model():
    X = tf.zeros(shape=[H.sample_batch, 0], dtype=tf.int32)
    current_step = tf.constant(0, dtype=tf.int64)
    accumulated_output = X[:, :current_step]   # Everything generated so far.
    current_input = X[:, current_step - 1:current_step]
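    # One empty attention cache per layer; each cache grows along the time
    # axis as tokens are generated.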
    cache_vars = [tf.zeros(shape=[H.sample_batch, 0, H.n_embd],
                           dtype=H.dtype) for _ in range(H.n_layer)]
    cacheshapes = [tf.TensorShape([H.sample_batch, None, H.n_embd])
                   for _ in range(H.n_layer)]
    embd_index = tf.constant([0] * H.sample_batch, dtype=tf.int32)
    first_embd = tf.zeros(shape=[H.sample_batch, H.emb_number, 0],
                          dtype=tf.int32)

    loop_vars = [current_step, accumulated_output, current_input,
                 first_embd, embd_index, cache_vars]
    shape_invariants = [current_step.get_shape(),
                        tf.TensorShape([H.sample_batch, None]),
                        tf.TensorShape([H.sample_batch, None]),
                        tf.TensorShape([H.sample_batch, H.emb_number, None]),
                        embd_index.get_shape(),
                        cacheshapes]
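    # The None dims above let the accumulated output, the current embedding
    # coordinates and the per-layer caches grow by one position on every
    # iteration of the while_loop below.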
    embd_shapes = tf.constant(H.emb_vocabs, dtype=tf.int32)

    def cond(step, acc, curr, curr_embd, embd_index, cache):
        return step < H.attn_ctx

    def body(step, acc, curr, curr_embd, embd_index, cache):
        with tf.variable_scope('model', custom_getter=f32_storage_getter):
            h, cache = stack(curr, curr_embd, train=False, step=step,
                             cache=cache)
            h = norm('final_norm', h, epsilon=1e-6)

            h = h[:, -1:, :]
            logits = tf.cast(get_logits('gen_logits', h, H.n_vocab), tf.float32)
            logits = tf.reshape(logits, [H.sample_batch, H.n_vocab])
            temp = H.temperature
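            # tf.multinomial (tf.random.categorical in newer TF) draws one
            # token id per row from softmax(logits / temp); temp < 1 sharpens
            # the distribution, temp > 1 flattens it.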
            symbol = tf.cast(tf.multinomial(logits / temp, 1), tf.int32)
            with tf.device('/cpu:0'):
                next_embd = tf.unravel_index(embd_index, embd_shapes)
                # unravel_index yields an [emb_number, n_batch] tensor.
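                # Example with hypothetical dims: if H.emb_vocabs were
                # [8, 8, 3], a flat index of 5 would unravel to the
                # row-major coordinates (0, 1, 2).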
                next_embd = tf.transpose(next_embd, [1, 0])
                next_embd = tf.reshape(next_embd, [
                    H.sample_batch, H.emb_number, 1])
                next_index = embd_index + 1
        return (step + 1, tf.concat([acc, symbol], axis=1), symbol,
                next_embd, next_index, cache)

    _, output_seq, _, _, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=loop_vars, back_prop=False,
        shape_invariants=shape_invariants, parallel_iterations=1)

    # Gather the sampled sequences generated across all ranks: each rank
    # contributes its own output to a tensor that is zero everywhere else,
    # so a sum allreduce acts as an allgather. The tensor is cast to float32
    # for the allreduce and back to int32 afterwards.
    all_samples = [tf.zeros_like(output_seq) for _ in range(mpi_size())]
    all_samples[mpi_rank()] = output_seq
    all_samples = tf.cast(allreduce(tf.cast(
        tf.concat(all_samples, axis=0), tf.float32)), tf.int32)
    return all_samples
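
For orientation, a minimal sketch of how the returned op might be driven, assuming a standard TF1 graph/session setup; the initialization/restore step below is illustrative and not taken from train.py:

samples_op = sample_model()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # or restore trained weights
    samples = sess.run(samples_op)
    # samples has shape [mpi_size() * H.sample_batch, H.attn_ctx]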