in train.py [0:0]
def linear(scope, x, nf, std, relu=False, fast_gelu=False):
    """Dense layer with fused bias + activation: y = act(x @ w + b).

    Args:
        scope: variable-scope name for the layer's parameters.
        x: input tensor; last dimension is the feature size. Inputs with
            rank > 2 are flattened to 2-D for the matmul and restored after.
        nf: number of output features.
        std: stddev passed to the weight initializer.
        relu: apply ReLU inside the fused bias op.
        fast_gelu: apply fast-GELU inside the fused bias op.

    Returns:
        Tensor with the same leading dimensions as `x` and last dim `nf`.
    """
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        # Creating `w` under a control dependency on `x` delays the weight
        # cast/materialization until just before use, which can save a lot
        # of memory for large-parameter models.
        with tf.control_dependencies([x]):
            w = tf.get_variable("w", [nx, nf], dtype=H.dtype,
                                initializer=random_or_zeros_init(stddev=std))
        b = tf.get_variable("b", [nf], dtype=tf.float32,
                            initializer=zeros_init())

        rank = x.shape.ndims
        needs_restore = rank > 2
        if needs_restore:
            # Capture the dynamic output shape BEFORE flattening, then
            # collapse all leading dims into one batch dim for the matmul.
            out_shape = tf.concat([tf.shape(x)[:rank - 1], [nf]], axis=0)
            x = tf.reshape(x, [-1, nx])

        # Fused bias-add + optional activation (blocksparse kernel).
        out = tf.matmul(x, w)
        out = bs.bias_relu(out, b, relu=relu, fast_gelu=fast_gelu)

        if needs_restore:
            out = tf.reshape(out, out_shape)
        return out