def monkey_patch_layers()

in uimnet/modules/spectral_normalization/utils.py [0:0]


def monkey_patch_layers(module, sn_coef=1, sn_bn=False, verbose=True):
  """
  Return the spectral-normalization counterpart of a supported layer.

  Linear layers are replaced by ``SNLinear``, 1-d/2-d convolutions by
  ``SNConv1d``/``SNConv2d`` and, only when ``sn_bn`` is set, batch
  normalization layers by ``SpectralBatchNorm{1,2,3}d``. Every other module
  is returned untouched.

  The replacement is built from the hyper-parameters of ``module`` only;
  nothing in this function hands the original layer's weights to the new
  layer.

  Args:
    module: The layer to inspect.
    sn_coef: Forwarded as ``sn_coef`` to the linear/convolutional wrappers
      and as ``coeff`` to the spectral batch normalization layers.
    sn_bn: When true, batch normalization layers are replaced as well;
      otherwise they are returned unchanged.
    verbose: When true, a message is emitted through ``utils.message`` for
      every layer that gets replaced.

  Returns:
    The newly constructed replacement layer, or ``module`` itself when no
    replacement applies.

  Raises:
    NotImplementedError: If ``module`` is a convolution for which no
      spectral-normalization wrapper is registered here (e.g. ``nn.Conv3d``).
  """
  is_linear = isinstance(module, (torch.nn.Linear, ))
  is_conv = isinstance(module,
                       (nn.Conv1d, nn.Conv2d, nn.Conv3d)
                       )
  is_bn = isinstance(module,
                     (nn.BatchNorm1d, nn.BatchNorm2d, torch.nn.BatchNorm3d)
                     )

  if is_linear:
    if verbose:
      utils.message(f'Wrapping linear module {module}')
    return SNLinear(in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    sn_coef=sn_coef,
                    num_itrs=1,
                    eps=1e-12
                    )

  if is_conv:
    cls_name = module.__class__.__name__
    # Only 1-d and 2-d wrappers are available in this function; fail with an
    # explicit error instead of an opaque KeyError from a dict lookup.
    if cls_name not in ('Conv1d', 'Conv2d'):
      raise NotImplementedError(
        f'No spectral normalization wrapper available for {cls_name}')
    conv_cls = SNConv1d if cls_name == 'Conv1d' else SNConv2d
    if verbose:
      utils.message(f'Wrapping convolutional module {module}')
    # NOTE(review): `groups` and `padding_mode` of the original layer are not
    # forwarded -- confirm whether the SNConv* constructors accept them.
    return conv_cls(in_channels=module.in_channels,
                    out_channels=module.out_channels,
                    kernel_size=module.kernel_size,
                    stride=module.stride,
                    padding=module.padding,
                    dilation=module.dilation,
                    bias=module.bias is not None,
                    sn_coef=sn_coef,
                    num_itrs=1,
                    eps=1e-12
                    )

  if is_bn and sn_bn:
    # Swap batch normalization for spectral batch normalization.
    bn_cls = dict(
      BatchNorm1d=SpectralBatchNorm1d,
      BatchNorm2d=SpectralBatchNorm2d,
      BatchNorm3d=SpectralBatchNorm3d)[module.__class__.__name__]

    if verbose:
      utils.message(f'Wrapping Batch Norm layer {module}')
    # Carry over the configuration of the layer being replaced rather than
    # resetting it to fixed values.
    return bn_cls(module.num_features, coeff=sn_coef,
                  eps=module.eps,
                  momentum=module.momentum,
                  affine=module.affine
                  )

  return module