in models/utils.py [0:0]
def initmod(m, gain=1.0, weightinitfunc=xavier_uniform_, blockstride=2):
    """Initialize the parameters of a single layer in place.

    Modules that are not one of the supported layer types are left untouched,
    so this can be applied blindly to every submodule (e.g. via
    ``nn.Module.apply``).

    For ``nn.Linear``, ``nn.Conv{1,2,3}d`` and ``nn.ConvTranspose{1,2,3}d``:
      * ``weightinitfunc(m, gain)`` is called with the *module itself* (not its
        weight tensor).
      * The bias is zeroed when the layer has one.

    For ``nn.ConvTranspose2d`` / ``nn.ConvTranspose3d``, the kernel is then made
    blockwise-constant. Along every spatial kernel axis, each tap is overwritten
    with the tap at the start of its ``blockstride``-sized block, i.e.
    ``w[..., i, j] = w[..., (i // s) * s, (j // s) * s]``.

    Args:
        m: The module to initialize.
        gain: Forwarded unchanged to ``weightinitfunc``.
        weightinitfunc: Callable invoked as ``weightinitfunc(m, gain)``.
            NOTE(review): the default ``xavier_uniform_`` must therefore
            accept a module. ``torch.nn.init.xavier_uniform_`` expects a tensor,
            so this presumably refers to a helper defined elsewhere in this
            file. Confirm.
        blockstride: Block size used for the transposed-conv tiling. It is NOT
            read from ``m.stride``. The default of 2 matches the previous
            hard-coded behaviour.

    Raises:
        PyTorch raises a shape-mismatch error during tensor assignment if a
        transposed-conv kernel's spatial size is not divisible by
        ``blockstride``.
    """
    validclasses = (
        nn.Linear,
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    )
    if isinstance(m, validclasses):
        weightinitfunc(m, gain)
        # Layers built with bias=False still expose a ``bias`` attribute (set to
        # None), so hasattr() alone would let ``None.data`` blow up here.
        if getattr(m, "bias", None) is not None:
            m.bias.data.zero_()

    # Blockwise initialization for transposed convs. Presumably this gives every
    # output phase of a stride-`blockstride` transposed conv the same effective
    # sub-kernel (e.g. to avoid checkerboard artifacts) -- TODO confirm intent.
    if isinstance(m, (nn.ConvTranspose2d, nn.ConvTranspose3d)):
        _tile_first_phase(m.weight.data, blockstride)


def _tile_first_phase(weight, blockstride):
    """Copy the phase-0 taps of a conv kernel over every other phase, in place.

    ``weight`` is indexed as ``[out, in, *spatial]``. For every combination of
    per-axis offsets in ``range(blockstride)`` except all-zeros, the strided
    slice starting at that offset is overwritten with the all-zeros-offset
    slice.
    """
    import itertools

    nspatial = weight.dim() - 2
    lead = (slice(None), slice(None))
    # The all-zeros phase is never written below, so this view stays valid as
    # the source for every copy.
    src = weight[lead + (slice(0, None, blockstride),) * nspatial]
    for offsets in itertools.product(range(blockstride), repeat=nspatial):
        if any(offsets):
            dst = lead + tuple(slice(o, None, blockstride) for o in offsets)
            weight[dst] = src