def stack()

in train.py


def stack(X, X_emb, train, step=None, cache=None):
    with tf.name_scope('input_processing'):
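        # Look up learned token embeddings and apply embedding dropout; the table and
        # the looked-up embeddings are also stashed on H as H.we / H.we_x.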
        we = tf.get_variable(
            "we", [H.n_vocab, H.n_embd], dtype=H.dtype,
            initializer=random_or_zeros_init(stddev=H.w_embd_std))
        h = bs.embedding_lookup(we, X)
        H.we = we
        H.we_x = h
        h = embedding_dropout(h, train)

    h = add_position_embedding(h, X_emb, train, step=step)
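    # step is None during training, where the full [n_batch, attn_ctx] context is
    # processed at once; during incremental sampling the sequence length varies.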
    if step is None:
        h = tf.reshape(h, [H.n_batch, H.attn_ctx, H.n_embd])
    else:
        h = tf.reshape(h, [H.sample_batch, -1, H.n_embd])

    with tf.variable_scope('sos_token'):
        if H.num_self_gen_in_use > 0 and not H.use_unconditional_augmentation:
            y_gen_idx = 0
            sos_tok = 0
            for typ in H.self_gen_types:
                if not typ.is_used:
                    if mpi_rank() == 0:
                        print(f" [self-gen] not using {typ.description}")
                    continue
                if mpi_rank() == 0:
                    print(f" [self-gen] using {typ.description}")
                this_sos_var = tf.get_variable(
                    typ.sos_name,
                    [typ.num_tokens, H.n_embd],
                    dtype=H.dtype,
                    initializer=random_or_zeros_init(stddev=H.w_embd_std))
                this_sos_tok = bs.embedding_lookup(this_sos_var, H.Y_gen_ph[:, y_gen_idx:y_gen_idx + 1])
                assert this_sos_tok.shape[1:] == (1, H.n_embd)
                sos_tok += this_sos_tok
                y_gen_idx += 1
            assert y_gen_idx == H.num_self_gen_in_use
        else:
            sos = tf.get_variable(
                'sos', [1, 1, H.n_embd], dtype=H.dtype,
                initializer=random_or_zeros_init(stddev=H.w_embd_std))
            batch_size = H.n_batch if step is None else H.sample_batch
            sos_tok = tf.ones(shape=[batch_size, 1, H.n_embd], dtype=H.dtype) * sos
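    # Teacher forcing at train time: prepend the sos embedding and drop the final
    # position, so the input at position t encodes only tokens before t. When
    # sampling, only the newest position is pushed through the stack.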
    if step is None:
        h = tf.concat([sos_tok, h[:, :-1, :]], axis=1)
        if H.randomly_determined_order_use_lookahead:
            print("lookahead_embd")
            with tf.variable_scope("lookahead_embedding"):
                h = add_position_embedding(h, X_emb, train, step=step)
    else:
        h = tf.concat([sos_tok, h], axis=1)[:, -1:, :]

    new_cache = []
    modes = H.attention_layers.split(',')
    assert H.n_layer % len(modes) == 0
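    # The comma-separated H.attention_layers pattern is repeated to cover all n_layer
    # blocks, hence the divisibility check above.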

    for layer_idx in range(H.n_layer):
        mode = modes[layer_idx % len(modes)]
        name = f'h{layer_idx}'
        if cache is not None:
            # We only cache the pre-qkv tensor; caching more would take up too much
            # memory on long sequences.
            h = tf.concat([cache[layer_idx], h], axis=1)
            new_cache.append(h)
            use_cache = True
        else:
            use_cache = False

        with tf.variable_scope(name):
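            # Activation recomputation (gradient checkpointing) is only enabled during
            # training; the sparse blocksparse path additionally requires fp16 and is
            # skipped when reading from the sampling cache.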
            recompute = H.recompute and train
            if H.float16 and H.blocksparse_op and not use_cache:
                h = sparse_attention(h, H.n_head, mode, use_cache=use_cache,
                                     train=train, recompute=recompute)
            else:
                h = dense_attention(h, H.n_head, mode, use_cache=use_cache,
                                    train=train, recompute=recompute)

    if cache is not None:
        return h, new_cache

    return h
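
For reference, here is a small NumPy sketch (not part of train.py; the toy shapes and the zero-valued stand-in for the learned sos embedding are illustrative only) of the two concat paths above: at train time the sos embedding is prepended and the last position dropped, so position t is built only from tokens before t; at sampling time only the newest position is kept.

import numpy as np

n_batch, attn_ctx, n_embd = 2, 4, 3                 # toy stand-ins for H.n_batch, H.attn_ctx, H.n_embd
h = np.random.randn(n_batch, attn_ctx, n_embd).astype(np.float32)
sos_tok = np.zeros((n_batch, 1, n_embd), dtype=np.float32)  # stand-in for the learned sos embedding

# Training path (step is None): prepend sos, drop the final position (right shift).
shifted = np.concatenate([sos_tok, h[:, :-1, :]], axis=1)
assert shifted.shape == (n_batch, attn_ctx, n_embd)

# Sampling path (step given): feed only the newest position through the stack.
last = np.concatenate([sos_tok, h], axis=1)[:, -1:, :]
assert last.shape == (n_batch, 1, n_embd)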