in train.py [0:0]
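# Note: this excerpt depends on names defined elsewhere in train.py:
# tensorflow (tf, v1 graph mode), the hyperparameter namespace H, the
# model helpers stack/norm/get_logits, the mixed-precision variable
# getter f32_storage_getter, and the MPI collectives
# mpi_size/mpi_rank/allreduce.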
def sample_model():
    X = tf.zeros(shape=[H.sample_batch, 0], dtype=tf.int32)
    current_step = tf.constant(0, dtype=tf.int64)
    accumulated_output = X[:, :current_step]  # Everything generated so far.
    current_input = X[:, current_step - 1:current_step]
    # Per-layer attention caches start empty and grow along the time axis.
    cache_vars = [tf.zeros(shape=[H.sample_batch, 0, H.n_embd],
                           dtype=H.dtype) for _ in range(H.n_layer)]
    cacheshapes = [tf.TensorShape([H.sample_batch, None, H.n_embd])
                   for _ in range(H.n_layer)]
    embd_index = tf.constant([0] * H.sample_batch, dtype=tf.int32)
    first_embd = tf.zeros(shape=[H.sample_batch, H.emb_number, 0],
                          dtype=tf.int32)
    loop_vars = [current_step, accumulated_output, current_input,
                 first_embd, embd_index, cache_vars]
    # The None dimensions let the corresponding tensors grow between
    # loop iterations.
    shape_invariants = [current_step.get_shape(),
                        tf.TensorShape([H.sample_batch, None]),
                        tf.TensorShape([H.sample_batch, None]),
                        tf.TensorShape([H.sample_batch, H.emb_number, None]),
                        embd_index.get_shape(),
                        cacheshapes]
    embd_shapes = tf.constant(H.emb_vocabs, dtype=tf.int32)
    def cond(step, acc, curr, curr_embd, embd_index, cache):
        # Keep sampling until the full attention context is filled.
        return step < H.attn_ctx

    def body(step, acc, curr, curr_embd, embd_index, cache):
        with tf.variable_scope('model', custom_getter=f32_storage_getter):
            h, cache = stack(curr, curr_embd, train=False, step=step,
                             cache=cache)
            h = norm('final_norm', h, epsilon=1e-6)
            h = h[:, -1:, :]  # Only the newest position is needed for sampling.
            logits = tf.cast(get_logits('gen_logits', h, H.n_vocab), tf.float32)
            logits = tf.reshape(logits, [H.sample_batch, H.n_vocab])
            temp = H.temperature
            # tf.multinomial is the TF1 name for tf.random.categorical.
            symbol = tf.cast(tf.multinomial(logits / temp, 1), tf.int32)
            with tf.device('/cpu:0'):
                next_embd = tf.unravel_index(embd_index, embd_shapes)
            # unravel_index yields an [emb_number, batch] tensor of
            # per-dimension coordinates for each flat position index.
            next_embd = tf.transpose(next_embd, [1, 0])
            next_embd = tf.reshape(next_embd, [
                H.sample_batch, H.emb_number, 1])
            next_index = embd_index + 1
            return (step + 1, tf.concat([acc, symbol], axis=1), symbol,
                    next_embd, next_index, cache)
    _, output_seq, _, _, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=loop_vars, back_prop=False,
        shape_invariants=shape_invariants, parallel_iterations=1)
    # Now gather the images across all the ranks that generated them. Each
    # rank contributes a tensor that is zero everywhere except in its own
    # slot, so a sum-allreduce of the concatenation acts as an allgather.
    all_samples = [tf.zeros_like(output_seq) for _ in range(mpi_size())]
    all_samples[mpi_rank()] = output_seq
    all_samples = tf.cast(allreduce(tf.cast(
        tf.concat(all_samples, axis=0), tf.float32)), tf.int32)
    return all_samples
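
The sampling loop above depends on tf.while_loop shape invariants: the
accumulated output starts with a zero-length time axis and grows by one
column per iteration, which the [batch, None] invariant permits. Below is a
minimal, self-contained sketch of that pattern (a hypothetical toy example in
TF1-style graph mode; the batch size 2 and length 4 are arbitrary, standing
in for H.sample_batch and H.attn_ctx):

import tensorflow as tf

def toy_sampling_loop():
    acc0 = tf.zeros([2, 0], dtype=tf.int32)    # empty sequence, like X above
    step0 = tf.constant(0)

    def cond(step, acc):
        return step < 4                        # stand-in for H.attn_ctx

    def body(step, acc):
        symbol = tf.fill([2, 1], step)         # stand-in for sampling a token
        return step + 1, tf.concat([acc, symbol], axis=1)

    _, seq = tf.while_loop(
        cond, body, [step0, acc0],
        shape_invariants=[step0.get_shape(), tf.TensorShape([2, None])])
    return seq  # static shape [2, None]; evaluates to a [2, 4] tensor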
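
The final allreduce doubles as an allgather: because every rank's tensor is
zero outside its own slot, elementwise summation across ranks reconstructs
the full concatenation. A numpy sketch of the idea (illustrative only;
fake_allreduce and the sizes are made up, standing in for the real MPI sum):

import numpy as np

def fake_allreduce(per_rank_tensors):
    # Stand-in for a sum-allreduce: every rank ends up with the sum.
    return sum(per_rank_tensors)

n_ranks, batch, length = 3, 2, 4
outputs = [np.full((batch, length), rank + 1) for rank in range(n_ranks)]

contributions = []
for rank in range(n_ranks):
    slots = [np.zeros((batch, length), dtype=int) for _ in range(n_ranks)]
    slots[rank] = outputs[rank]                # only this rank's slot is nonzero
    contributions.append(np.concatenate(slots, axis=0))

gathered = fake_allreduce(contributions)
assert (gathered == np.concatenate(outputs, axis=0)).all()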