Dassl.pytorch/dassl/modeling/network/ddaig_fcn.py (269 lines of code) (raw):

""" Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix """ import functools import torch import torch.nn as nn from torch.nn import functional as F from .build import NETWORK_REGISTRY def init_network_weights(model, init_type="normal", gain=0.02): def _init_func(m): classname = m.__class__.__name__ if hasattr(m, "weight") and ( classname.find("Conv") != -1 or classname.find("Linear") != -1 ): if init_type == "normal": nn.init.normal_(m.weight.data, 0.0, gain) elif init_type == "xavier": nn.init.xavier_normal_(m.weight.data, gain=gain) elif init_type == "kaiming": nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") elif init_type == "orthogonal": nn.init.orthogonal_(m.weight.data, gain=gain) else: raise NotImplementedError( "initialization method {} is not implemented". format(init_type) ) if hasattr(m, "bias") and m.bias is not None: nn.init.constant_(m.bias.data, 0.0) elif classname.find("BatchNorm2d") != -1: nn.init.constant_(m.weight.data, 1.0) nn.init.constant_(m.bias.data, 0.0) elif classname.find("InstanceNorm2d") != -1: if m.weight is not None and m.bias is not None: nn.init.constant_(m.weight.data, 1.0) nn.init.constant_(m.bias.data, 0.0) model.apply(_init_func) def get_norm_layer(norm_type="instance"): if norm_type == "batch": norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == "instance": norm_layer = functools.partial( nn.InstanceNorm2d, affine=False, track_running_stats=False ) elif norm_type == "none": norm_layer = None else: raise NotImplementedError( "normalization layer [%s] is not found" % norm_type ) return norm_layer class ResnetBlock(nn.Module): def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): super().__init__() self.conv_block = self.build_conv_block( dim, padding_type, norm_layer, use_dropout, use_bias ) def build_conv_block( self, dim, padding_type, norm_layer, use_dropout, use_bias ): conv_block = [] p = 0 if padding_type == "reflect": conv_block += [nn.ReflectionPad2d(1)] elif padding_type == "replicate": conv_block += [nn.ReplicationPad2d(1)] elif padding_type == "zero": p = 1 else: raise NotImplementedError( "padding [%s] is not implemented" % padding_type ) conv_block += [ nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True), ] if use_dropout: conv_block += [nn.Dropout(0.5)] p = 0 if padding_type == "reflect": conv_block += [nn.ReflectionPad2d(1)] elif padding_type == "replicate": conv_block += [nn.ReplicationPad2d(1)] elif padding_type == "zero": p = 1 else: raise NotImplementedError( "padding [%s] is not implemented" % padding_type ) conv_block += [ nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), ] return nn.Sequential(*conv_block) def forward(self, x): return x + self.conv_block(x) class LocNet(nn.Module): """Localization network.""" def __init__( self, input_nc, nc=32, n_blocks=3, use_dropout=False, padding_type="zero", image_size=32, ): super().__init__() backbone = [] backbone += [ nn.Conv2d( input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False ) ] backbone += [nn.BatchNorm2d(nc)] backbone += [nn.ReLU(True)] for _ in range(n_blocks): backbone += [ ResnetBlock( nc, padding_type=padding_type, norm_layer=nn.BatchNorm2d, use_dropout=use_dropout, use_bias=False, ) ] backbone += [nn.MaxPool2d(2, stride=2)] self.backbone = nn.Sequential(*backbone) reduced_imsize = int(image_size * 0.5**(n_blocks + 1)) self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2) def forward(self, x): x = self.backbone(x) x = x.view(x.size(0), -1) x = self.fc_loc(x) x = torch.tanh(x) x = x.view(-1, 2, 2) theta = x.data.new_zeros(x.size(0), 2, 3) theta[:, :, :2] = x return theta class FCN(nn.Module): """Fully convolutional network.""" def __init__( self, input_nc, output_nc, nc=32, n_blocks=3, norm_layer=nn.BatchNorm2d, use_dropout=False, padding_type="reflect", gctx=True, stn=False, image_size=32, ): super().__init__() backbone = [] p = 0 if padding_type == "reflect": backbone += [nn.ReflectionPad2d(1)] elif padding_type == "replicate": backbone += [nn.ReplicationPad2d(1)] elif padding_type == "zero": p = 1 else: raise NotImplementedError backbone += [ nn.Conv2d( input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False ) ] backbone += [norm_layer(nc)] backbone += [nn.ReLU(True)] for _ in range(n_blocks): backbone += [ ResnetBlock( nc, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=False, ) ] self.backbone = nn.Sequential(*backbone) # global context fusion layer self.gctx_fusion = None if gctx: self.gctx_fusion = nn.Sequential( nn.Conv2d( 2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False ), norm_layer(nc), nn.ReLU(True), ) self.regress = nn.Sequential( nn.Conv2d( nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True ), nn.Tanh(), ) self.locnet = None if stn: self.locnet = LocNet( input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size ) def init_loc_layer(self): """Initialize the weights/bias with identity transformation.""" if self.locnet is not None: self.locnet.fc_loc.weight.data.zero_() self.locnet.fc_loc.bias.data.copy_( torch.tensor([1, 0, 0, 1], dtype=torch.float) ) def stn(self, x): """Spatial transformer network.""" theta = self.locnet(x) grid = F.affine_grid(theta, x.size()) return F.grid_sample(x, grid), theta def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False): """ Args: x (torch.Tensor): input mini-batch. lmda (float): multiplier for perturbation. return_p (bool): return perturbation. return_stn_output (bool): return the output of stn. """ theta = None if self.locnet is not None: x, theta = self.stn(x) input = x x = self.backbone(x) if self.gctx_fusion is not None: c = F.adaptive_avg_pool2d(x, (1, 1)) c = c.expand_as(x) x = torch.cat([x, c], 1) x = self.gctx_fusion(x) p = self.regress(x) x_p = input + lmda*p if return_stn_output: return x_p, p, input if return_p: return x_p, p return x_p @NETWORK_REGISTRY.register() def fcn_3x32_gctx(**kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer) init_network_weights(net, init_type="normal", gain=0.02) return net @NETWORK_REGISTRY.register() def fcn_3x64_gctx(**kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer) init_network_weights(net, init_type="normal", gain=0.02) return net @NETWORK_REGISTRY.register() def fcn_3x32_gctx_stn(image_size=32, **kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN( 3, 3, nc=32, n_blocks=3, norm_layer=norm_layer, stn=True, image_size=image_size ) init_network_weights(net, init_type="normal", gain=0.02) net.init_loc_layer() return net @NETWORK_REGISTRY.register() def fcn_3x64_gctx_stn(image_size=224, **kwargs): norm_layer = get_norm_layer(norm_type="instance") net = FCN( 3, 3, nc=64, n_blocks=3, norm_layer=norm_layer, stn=True, image_size=image_size ) init_network_weights(net, init_type="normal", gain=0.02) net.init_loc_layer() return net