in train.py [0:0]
def post_attention(x, a, use_cache=None, train=False, pdrop=None):
    nx = x.shape[-1].value
    # Project the attention output back to the model width; the init std is
    # scaled down with depth (H.n_layer) to keep residual contributions small.
    a = linear('post_proj', a, nx,
               std=np.sqrt(H.post_w * 0.5 / nx / H.n_layer))
    scopename = tf.get_variable_scope().name
    a = residual_dropout(a, train, key=f'{scopename}-a', pdrop=pdrop)
    # When decoding with a cache, only the newest position is produced, so
    # keep just the last timestep of the residual stream before the add.
    x = x[:, -1:, :] if use_cache else x
    x = bs.add(x, a)
    # Position-wise MLP: norm, expand to inner_dim with fast GELU, project back.
    inner_dim = int(nx * H.mlp_multiple)
    m = norm("mlp", x)
    m = linear('mlp_proj1', m, inner_dim,
               std=np.sqrt(H.mlp_w1 / nx), fast_gelu=True)
    m = linear('mlp_proj2', m, nx,
               std=np.sqrt(H.mlp_w2 / inner_dim / H.n_layer * 0.5))
    m = residual_dropout(m, train, key=f'{scopename}-m', pdrop=pdrop)
    # Second residual connection completes the transformer block.
    return bs.add(x, m)
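
A minimal sketch of how post_attention might be wired into one transformer layer. The attention helper, the block function, and the scope naming are assumptions made for illustration and are not taken from train.py; only post_attention and norm come from the excerpt above.

def block(x, layer_idx, use_cache=None, train=False, pdrop=None):
    # Hypothetical per-layer wrapper: pre-norm attention followed by the
    # post_attention residual/MLP path defined above.
    with tf.variable_scope(f'h{layer_idx}'):
        a = attention(norm('attn', x), use_cache=use_cache)  # assumed attention op
        return post_attention(x, a, use_cache=use_cache, train=train, pdrop=pdrop)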