# Dassl.pytorch/dassl/modeling/network/ddaig_fcn.py
"""
Credit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
"""
import functools
import torch
import torch.nn as nn
from torch.nn import functional as F
from .build import NETWORK_REGISTRY
def init_network_weights(model, init_type="normal", gain=0.02):
    """Initialize the parameters of every submodule of ``model`` in place.

    Conv*/Linear layers get their weights drawn according to ``init_type``
    and their biases (when present) set to zero. BatchNorm2d and
    InstanceNorm2d layers that have learnable affine parameters get
    weight=1 and bias=0; norm layers built with ``affine=False`` (whose
    weight/bias are None) are left untouched.

    Args:
        model (nn.Module): network to initialize; traversed via
            ``model.apply``.
        init_type (str): "normal", "xavier", "kaiming" or "orthogonal".
        gain (float): std for "normal"; gain for "xavier" and
            "orthogonal". Not used by "kaiming".

    Raises:
        NotImplementedError: if ``init_type`` is unknown and ``model``
            contains at least one Conv/Linear layer.
    """

    def _init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (
            "Conv" in classname or "Linear" in classname
        ):
            if init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    "initialization method {} is not implemented".
                    format(init_type)
                )
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname or "InstanceNorm2d" in classname:
            # With affine=False both weight and bias are None; there is
            # nothing to initialize (previously crashed for BatchNorm2d).
            if m.weight is not None and m.bias is not None:
                nn.init.constant_(m.weight.data, 1.0)
                nn.init.constant_(m.bias.data, 0.0)

    model.apply(_init_func)
def get_norm_layer(norm_type="instance"):
    """Return a constructor for 2D normalization layers.

    Args:
        norm_type (str): "batch" -> BatchNorm2d with learnable affine
            parameters; "instance" -> InstanceNorm2d without affine
            parameters or running statistics; "none" -> no normalization.

    Returns:
        A ``functools.partial`` that takes the channel count and builds the
        layer, or ``None`` when ``norm_type`` is "none".

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == "none":
        return None
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    raise NotImplementedError(
        "normalization layer [%s] is not found" % norm_type
    )
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: [pad] -> 3x3 conv -> norm -> ReLU -> [dropout] ->
    [pad] -> 3x3 conv -> norm, with ``dim`` channels throughout and 1 pixel
    of padding per conv, so the output has the same shape as the input.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super().__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias
        )

    @staticmethod
    def _padding(padding_type):
        """Return ``(pad_layers, conv_padding)`` for a size-preserving 3x3 conv."""
        if padding_type == "reflect":
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == "replicate":
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == "zero":
            # Let the conv itself zero-pad instead of adding a layer.
            return [], 1
        raise NotImplementedError(
            "padding [%s] is not implemented" % padding_type
        )

    def build_conv_block(
        self, dim, padding_type, norm_layer, use_dropout, use_bias
    ):
        """Build the residual branch.

        Args:
            dim (int): number of input/output channels.
            padding_type (str): "reflect", "replicate" or "zero".
            norm_layer (callable): called with ``dim`` to build a norm layer.
            use_dropout (bool): insert ``nn.Dropout(0.5)`` between the convs.
            use_bias (bool): whether the convs have a bias term.

        Returns:
            nn.Sequential: the residual branch.

        Raises:
            NotImplementedError: for an unknown ``padding_type``.
        """
        # Layer order defines the Sequential indices used as state_dict
        # keys, so it must stay stable for checkpoint compatibility.
        conv_block = []

        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Fresh padding module instances for the second conv.
        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)
class LocNet(nn.Module):
    """Localization network.

    Regresses one affine matrix ``theta`` of shape (N, 2, 3) per input image.
    The 2x2 linear part is ``tanh`` of a linear layer's output, so each entry
    lies in [-1, 1]; the translation column is always zero.

    Args:
        input_nc (int): number of input channels.
        nc (int): channel width of the backbone.
        n_blocks (int): number of (ResnetBlock, 2x2 max-pool) stages.
        use_dropout (bool): enable dropout inside the residual blocks.
        padding_type (str): padding used by the residual blocks.
        image_size (int): side length of the (square) inputs. ``fc_loc`` is
            sized from the backbone's output resolution for this input size,
            so inputs of a different size will not fit it.

    Raises:
        ValueError: if ``image_size`` is too small for the downsampling,
            i.e. the backbone's output would have zero spatial size.
    """

    def __init__(
        self,
        input_nc,
        nc=32,
        n_blocks=3,
        use_dropout=False,
        padding_type="zero",
        image_size=32,
    ):
        super().__init__()

        backbone = [
            nn.Conv2d(
                input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(nc),
            nn.ReLU(True),
        ]
        for _ in range(n_blocks):
            backbone += [
                ResnetBlock(
                    nc,
                    padding_type=padding_type,
                    norm_layer=nn.BatchNorm2d,
                    use_dropout=use_dropout,
                    use_bias=False,
                ),
                nn.MaxPool2d(2, stride=2),
            ]
        self.backbone = nn.Sequential(*backbone)

        reduced_imsize = self._reduced_size(image_size, n_blocks)
        if reduced_imsize < 1:
            raise ValueError(
                "image_size={} is too small for n_blocks={}".format(
                    image_size, n_blocks
                )
            )
        self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)

    @staticmethod
    def _reduced_size(image_size, n_blocks):
        """Exact spatial size of the backbone output for a given input size."""
        # 3x3 conv, stride 2, padding 1: floor((s - 1) / 2) + 1 == ceil(s / 2).
        # (A plain s / 2**(n_blocks + 1) is wrong whenever this rounds up.)
        size = (int(image_size) + 1) // 2
        for _ in range(n_blocks):
            size //= 2  # MaxPool2d(2, stride=2) floors
        return size

    def forward(self, x):
        """Return affine matrices of shape (N, 2, 3) for the batch ``x``."""
        x = self.backbone(x)
        x = x.view(x.size(0), -1)
        x = torch.tanh(self.fc_loc(x))
        x = x.view(-1, 2, 2)
        # Zero translation; gradients flow into the 2x2 part through the
        # in-place slice assignment.
        theta = x.new_zeros(x.size(0), 2, 3)
        theta[:, :, :2] = x
        return theta
class FCN(nn.Module):
"""Fully convolutional network."""
def __init__(
self,
input_nc,
output_nc,
nc=32,
n_blocks=3,
norm_layer=nn.BatchNorm2d,
use_dropout=False,
padding_type="reflect",
gctx=True,
stn=False,
image_size=32,
):
super().__init__()
backbone = []
p = 0
if padding_type == "reflect":
backbone += [nn.ReflectionPad2d(1)]
elif padding_type == "replicate":
backbone += [nn.ReplicationPad2d(1)]
elif padding_type == "zero":
p = 1
else:
raise NotImplementedError
backbone += [
nn.Conv2d(
input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False
)
]
backbone += [norm_layer(nc)]
backbone += [nn.ReLU(True)]
for _ in range(n_blocks):
backbone += [
ResnetBlock(
nc,
padding_type=padding_type,
norm_layer=norm_layer,
use_dropout=use_dropout,
use_bias=False,
)
]
self.backbone = nn.Sequential(*backbone)
# global context fusion layer
self.gctx_fusion = None
if gctx:
self.gctx_fusion = nn.Sequential(
nn.Conv2d(
2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False
),
norm_layer(nc),
nn.ReLU(True),
)
self.regress = nn.Sequential(
nn.Conv2d(
nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True
),
nn.Tanh(),
)
self.locnet = None
if stn:
self.locnet = LocNet(
input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size
)
def init_loc_layer(self):
"""Initialize the weights/bias with identity transformation."""
if self.locnet is not None:
self.locnet.fc_loc.weight.data.zero_()
self.locnet.fc_loc.bias.data.copy_(
torch.tensor([1, 0, 0, 1], dtype=torch.float)
)
def stn(self, x):
"""Spatial transformer network."""
theta = self.locnet(x)
grid = F.affine_grid(theta, x.size())
return F.grid_sample(x, grid), theta
def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
"""
Args:
x (torch.Tensor): input mini-batch.
lmda (float): multiplier for perturbation.
return_p (bool): return perturbation.
return_stn_output (bool): return the output of stn.
"""
theta = None
if self.locnet is not None:
x, theta = self.stn(x)
input = x
x = self.backbone(x)
if self.gctx_fusion is not None:
c = F.adaptive_avg_pool2d(x, (1, 1))
c = c.expand_as(x)
x = torch.cat([x, c], 1)
x = self.gctx_fusion(x)
p = self.regress(x)
x_p = input + lmda*p
if return_stn_output:
return x_p, p, input
if return_p:
return x_p, p
return x_p
def _build_fcn(nc, stn=False, image_size=32):
    """Build the 3->3 channel FCN shared by the registered factories.

    Uses instance normalization, 3 residual blocks and global-context fusion
    (FCN's ``gctx`` default), then initializes weights from N(0, 0.02). When
    ``stn`` is True, the localization regressor is additionally reset to the
    identity transformation via ``init_loc_layer``. ``image_size`` only
    matters when ``stn`` is True (it sizes the localization network).
    """
    norm_layer = get_norm_layer(norm_type="instance")
    net = FCN(
        3,
        3,
        nc=nc,
        n_blocks=3,
        norm_layer=norm_layer,
        stn=stn,
        image_size=image_size
    )
    init_network_weights(net, init_type="normal", gain=0.02)
    if stn:
        net.init_loc_layer()
    return net


@NETWORK_REGISTRY.register()
def fcn_3x32_gctx(**kwargs):
    """FCN with 32 channels and global context. Extra kwargs are ignored."""
    return _build_fcn(32)


@NETWORK_REGISTRY.register()
def fcn_3x64_gctx(**kwargs):
    """FCN with 64 channels and global context. Extra kwargs are ignored."""
    return _build_fcn(64)


@NETWORK_REGISTRY.register()
def fcn_3x32_gctx_stn(image_size=32, **kwargs):
    """FCN (32 channels, global context) with an identity-initialized STN.

    Extra kwargs are ignored.
    """
    return _build_fcn(32, stn=True, image_size=image_size)


@NETWORK_REGISTRY.register()
def fcn_3x64_gctx_stn(image_size=224, **kwargs):
    """FCN (64 channels, global context) with an identity-initialized STN.

    Extra kwargs are ignored.
    """
    return _build_fcn(64, stn=True, image_size=image_size)