in lib/torch_util.py [0:0]
def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):
    """
    nn.Linear but with normalized fan-in init.

    Builds a linear layer, then rescales the weight so that every slice
    along dim 0 has L2 norm equal to ``scale`` (the norm is taken over
    dim 1, which for ``nn.Linear`` is the ``in_features`` axis). The bias,
    when the layer has one, is set to zero.

    Args:
        *args: Positional arguments forwarded unchanged to the layer
            constructor (for ``nn.Linear``: ``in_features``,
            ``out_features`` and optionally ``bias``).
        scale: Target L2 norm of each weight row after initialization.
        dtype: Selects the layer class: ``th.float32`` -> ``nn.Linear``,
            ``th.float16`` -> ``LinearF16``. The value is first passed
            through ``parse_dtype`` (presumably so aliases such as strings
            are accepted -- confirm against that helper).
        **kwargs: Keyword arguments forwarded unchanged to the layer
            constructor.

    Returns:
        The constructed layer with re-initialized weight and bias.

    Raises:
        ValueError: If ``dtype`` is neither float32 nor float16.
    """
    dtype = parse_dtype(dtype)
    if dtype == th.float32:
        out = nn.Linear(*args, **kwargs)
    elif dtype == th.float16:
        out = LinearF16(*args, **kwargs)
    else:
        raise ValueError(dtype)

    # Parameter initialization must not be recorded by autograd.
    with th.no_grad():
        out.weight.mul_(scale / out.weight.norm(dim=1, p=2, keepdim=True))
        # Inspect the layer itself rather than kwargs: ``bias`` may have been
        # passed positionally, in which case kwargs would not contain it and
        # ``out.bias`` would be None.
        if out.bias is not None:
            out.bias.zero_()
    return out