in glide_text2im/xf.py [0:0]
def convert_module_to_f16(l):
    """
    Cast a primitive layer's parameters to float16, in place.

    Only ``nn.Linear``, ``nn.Conv2d`` and ``nn.ConvTranspose2d`` instances are
    affected; any other module is left unchanged. The weight is always cast,
    and the bias is cast only when the layer has one. Returns ``None``.
    """
    if not isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        # Not a layer type we convert; nothing to do.
        return

    # Reassigning ``.data`` swaps the underlying tensor while keeping the same
    # Parameter object registered on the module.
    weight = l.weight
    weight.data = weight.data.half()

    bias = l.bias
    if bias is not None:
        bias.data = bias.data.half()