in train.py [0:0]
def stack(X, X_emb, train, step=None, cache=None):
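    """Embed the inputs, prepend a (possibly conditioning) SOS token, and run
    the stack of H.n_layer attention layers.

    Returns the final hidden states; when `cache` is given (incremental
    sampling), also returns the updated per-layer cache.
    """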
    with tf.name_scope('input_processing'):
        we = tf.get_variable(
            "we", [H.n_vocab, H.n_embd], dtype=H.dtype,
            initializer=random_or_zeros_init(stddev=H.w_embd_std))
        h = bs.embedding_lookup(we, X)
        H.we = we
        H.we_x = h
        h = embedding_dropout(h, train)
        h = add_position_embedding(h, X_emb, train, step=step)
        if step is None:
            h = tf.reshape(h, [H.n_batch, H.attn_ctx, H.n_embd])
        else:
            h = tf.reshape(h, [H.sample_batch, -1, H.n_embd])
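    # Build the SOS token: a sum of learned conditioning embeddings (one per
    # self-generated target in use), or otherwise a single learned vector
    # broadcast over the batch.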
    with tf.variable_scope('sos_token'):
        if H.num_self_gen_in_use > 0 and not H.use_unconditional_augmentation:
            y_gen_idx = 0
            sos_tok = 0
            for typ in H.self_gen_types:
                if not typ.is_used:
                    if mpi_rank() == 0:
                        print(f" [self-gen] not using {typ.description}")
                    continue
                if mpi_rank() == 0:
                    print(f" [self-gen] using {typ.description}")
                this_sos_var = tf.get_variable(
                    typ.sos_name,
                    [typ.num_tokens, H.n_embd],
                    dtype=H.dtype,
                    initializer=random_or_zeros_init(stddev=H.w_embd_std))
                this_sos_tok = bs.embedding_lookup(
                    this_sos_var, H.Y_gen_ph[:, y_gen_idx:y_gen_idx + 1])
                assert this_sos_tok.shape[1:] == (1, H.n_embd)
                sos_tok += this_sos_tok
                y_gen_idx += 1
            assert y_gen_idx == H.num_self_gen_in_use
        else:
            sos = tf.get_variable(
                'sos', [1, 1, H.n_embd], dtype=H.dtype,
                initializer=random_or_zeros_init(stddev=H.w_embd_std))
            batch_size = H.n_batch if step is None else H.sample_batch
            sos_tok = tf.ones(shape=[batch_size, 1, H.n_embd],
                              dtype=H.dtype) * sos
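    # Shift right: prepend the SOS token and drop the last input position for
    # training; when sampling incrementally, keep only the newest position.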
    if step is None:
        h = tf.concat([sos_tok, h[:, :-1, :]], axis=1)
        if H.randomly_determined_order_use_lookahead:
            print("lookahead_embd")
            with tf.variable_scope("lookahead_embedding"):
                h = add_position_embedding(h, X_emb, train, step=step)
    else:
        h = tf.concat([sos_tok, h], axis=1)[:, -1:, :]
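    # The comma-separated H.attention_layers spec is cycled across the stack,
    # so layer i uses the (i % len(modes))-th attention pattern.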
    new_cache = []
    modes = H.attention_layers.split(',')
    assert H.n_layer % len(modes) == 0
    for layer_idx in range(H.n_layer):
        mode = modes[layer_idx % len(modes)]
        name = f'h{layer_idx}'
        if cache is not None:
            # We only cache the pre-qkv tensor, as it takes up too much
            # memory otherwise on long sequences.
            h = tf.concat([cache[layer_idx], h], axis=1)
            new_cache.append(h)
            use_cache = True
        else:
            use_cache = False
        with tf.variable_scope(name):
            recompute = H.recompute and train
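            # The blocksparse attention kernel is used only with fp16, when
            # the blocksparse op is enabled, and on the uncached path;
            # otherwise fall back to dense attention.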
            if H.float16 and H.blocksparse_op and not use_cache:
                h = sparse_attention(h, H.n_head, mode, use_cache=use_cache,
                                     train=train, recompute=recompute)
            else:
                h = dense_attention(h, H.n_head, mode, use_cache=use_cache,
                                    train=train, recompute=recompute)
    if cache is not None:
        return h, new_cache
    return h