def post_attention()

in train.py [0:0]


def post_attention(x, a, use_cache=None, train=False, pdrop=None):
    """Post-attention half of a transformer block: projects the attention
    output `a` back into the residual stream `x`, then applies a pre-norm
    two-layer MLP, with residual dropout on both branches."""
    nx = x.shape[-1].value

    # Project the attention output back to model width, with depth-scaled init.
    a = linear('post_proj', a, nx,
               std=np.sqrt(H.post_w * 0.5 / nx / H.n_layer))
    scopename = tf.get_variable_scope().name
    a = residual_dropout(a, train, key=f'{scopename}-a', pdrop=pdrop)

    # When a cache is used, only the last position of the residual stream is kept.
    x = x[:, -1:, :] if use_cache else x
    x = bs.add(x, a)

    # Pre-norm MLP: expand to inner_dim with fast GELU, then project back to nx.
    inner_dim = int(nx * H.mlp_multiple)

    m = norm("mlp", x)
    m = linear('mlp_proj1', m, inner_dim,
               std=np.sqrt(H.mlp_w1 / nx), fast_gelu=True)
    m = linear('mlp_proj2', m, nx,
               std=np.sqrt(H.mlp_w2 / inner_dim / H.n_layer * 0.5))
    m = residual_dropout(m, train, key=f'{scopename}-m', pdrop=pdrop)

    # Second residual connection around the MLP.
    return bs.add(x, m)
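
For intuition, here is a minimal framework-free sketch of the same structure in plain NumPy. It is illustrative only: the real function relies on the surrounding codebase's linear, norm, residual_dropout, blocksparse bs.add, and the hyperparameter object H, none of which are reproduced here; the helper names, weight shapes, and the tanh GELU approximation below are assumptions made for the sketch, and dropout and caching are omitted.

import numpy as np

def gelu(x):
    # tanh approximation of GELU (in the spirit of the fast_gelu flag above)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_attention_sketch(x, a, w_post, w1, w2):
    # x: residual stream [batch, seq, nx]; a: attention output [batch, seq, nx]
    x = x + a @ w_post            # project attention output, first residual add
    m = layer_norm(x)             # pre-norm before the MLP
    m = gelu(m @ w1)              # expand to inner_dim with GELU
    m = m @ w2                    # project back to nx
    return x + m                  # second residual add

# Usage with random weights (nx=8, mlp multiple=4 -> inner_dim=32).
nx, inner = 8, 32
x = np.random.randn(2, 5, nx)
a = np.random.randn(2, 5, nx)
out = post_attention_sketch(
    x, a,
    np.random.randn(nx, nx) * np.sqrt(1.0 / nx),
    np.random.randn(nx, inner) * np.sqrt(1.0 / nx),
    np.random.randn(inner, nx) * np.sqrt(1.0 / inner))
print(out.shape)  # (2, 5, 8)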