def linear()

in train.py [0:0]


def linear(scope, x, nf, std, relu=False, fast_gelu=False):
    """Apply a dense layer (matmul plus bias) over the last axis of `x`.

    Inside `tf.variable_scope(scope)` this fetches or creates two variables:
    "w" with shape [nx, nf] and dtype `H.dtype`, and "b" with shape [nf] and
    dtype float32, where nx is the static size of the last axis of `x`.

    Inputs with more than two dimensions are flattened to [-1, nx] for the
    matmul and the result is reshaped back to the original leading
    dimensions followed by `nf`.

    Args:
        scope: Variable scope name under which "w" and "b" live.
        x: Input tensor. Its last dimension is read from the static shape
            (`x.shape[-1].value`) and its rank from `x.shape.ndims`.
        nf: Number of output features.
        std: Passed as `stddev` to `random_or_zeros_init` for "w".
        relu: Forwarded unchanged to `bs.bias_relu`.
        fast_gelu: Forwarded unchanged to `bs.bias_relu`.

    Returns:
        The tensor produced by `bs.bias_relu`, reshaped to the leading
        dimensions of `x` plus `nf` when `x` has more than two dimensions.
    """
    with tf.variable_scope(scope):
        in_features = x.shape[-1].value

        # The variables are requested under a control dependency on `x` so
        # that the cast of `w` is delayed until just before it is used; per
        # the original author this can save a lot of memory for models with
        # large parameter counts.
        with tf.control_dependencies([x]):
            weight = tf.get_variable(
                "w",
                [in_features, nf],
                dtype=H.dtype,
                initializer=random_or_zeros_init(stddev=std),
            )
            bias = tf.get_variable(
                "b",
                [nf],
                dtype=tf.float32,
                initializer=zeros_init(),
            )

        rank = x.shape.ndims
        needs_flatten = rank > 2
        if needs_flatten:
            # Record the dynamic leading dimensions before collapsing them,
            # so the output can be restored to the same layout afterwards.
            leading_dims = tf.shape(x)[:rank - 1]
            out_shape = tf.concat([leading_dims, [nf]], axis=0)
            x = tf.reshape(x, [-1, in_features])

        out = bs.bias_relu(
            tf.matmul(x, weight), bias, relu=relu, fast_gelu=fast_gelu)

        if needs_flatten:
            out = tf.reshape(out, out_shape)
    return out