in lib/xf.py [0:0]
def __init__(self, x_size, scale, dtype, norm, actname="relu", mlp_ratio=2):
    """Create the normalization layer and a one-hidden-layer MLP, then rescale its weights.

    Args:
        x_size: Feature width. Used as the size passed to ``util.get_norm``
            and as both the input and output size of the MLP.
        scale: Gain split across the two MLP weight matrices. Each of
            ``layers[0]`` and ``layers[1]`` has its weight multiplied by
            ``sqrt(scale)`` (on top of ``MLP0_SCALE`` / ``MLP1_SCALE``), so the
            product of the two gains is ``scale * MLP0_SCALE * MLP1_SCALE``.
            Must be non-negative: ``math.sqrt`` raises ``ValueError`` otherwise.
        dtype: Forwarded unchanged to both ``util.get_norm`` and ``mlp.MLP``.
        norm: Normalization selector forwarded to ``util.get_norm``.
        actname: Bound as the first positional argument of ``act`` to form
            the hidden activation (presumably ``act`` dispatches on this
            name -- confirm against its definition).
        mlp_ratio: Hidden-width multiplier. The hidden size is
            ``int(x_size * mlp_ratio)``, i.e. truncated toward zero.
    """
    super().__init__()
    gain = math.sqrt(scale)

    # Create the norm before the MLP. This keeps the order in which submodules
    # (and therefore parameters / state_dict entries) are registered. If either
    # constructor draws from the RNG, it also keeps the sequence of those draws.
    self.ln = util.get_norm(norm, x_size, dtype=dtype)

    hidden_width = int(x_size * mlp_ratio)
    hidden_activation = functools.partial(act, actname)
    self.mlp = mlp.MLP(
        insize=x_size,
        nhidlayer=1,
        outsize=x_size,
        hidsize=hidden_width,
        hidactiv=hidden_activation,
        dtype=dtype,
    )

    # Apply sqrt(scale) to each of the two weight matrices so their combined
    # gain is `scale` times the per-layer base constants. Only `.weight` is
    # rescaled here; any other parameters keep their constructor values.
    for idx, base_scale in enumerate((MLP0_SCALE, MLP1_SCALE)):
        self.mlp.layers[idx].weight.data *= base_scale * gain