evaluations/inception_v3.py (303 lines of code) (raw):

# Ported from the model here:
# https://github.com/NVlabs/stylegan3/blob/407db86e6fe432540a22515310188288687858fa/metrics/frechet_inception_distance.py#L22
#
# I have verified that the spatial features and output features are correct
# within a mean absolute error of ~3e-5.

import collections

import torch


class Conv2dLayer(torch.nn.Module):
    """Bias-free 2-D convolution, then inference-mode batch norm, then ReLU.

    The batch normalization uses the stored ``mean``/``var`` statistics with
    ``beta`` as the additive bias and no multiplicative scale (``eps=1e-3``).
    All four tensors are registered as parameters and initialised to zeros.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kh: Kernel height.
        kw: Kernel width.
        stride: Convolution stride, forwarded to ``conv2d``.
        padding: Convolution padding, forwarded to ``conv2d`` (an int or a
            ``[pad_h, pad_w]`` pair).
    """

    def __init__(self, in_channels, out_channels, kh, kw, stride=1, padding=0):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.weight = torch.nn.Parameter(torch.zeros(out_channels, in_channels, kh, kw))
        self.beta = torch.nn.Parameter(torch.zeros(out_channels))
        self.mean = torch.nn.Parameter(torch.zeros(out_channels))
        self.var = torch.nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU to ``x``."""
        # The kernel is cast to the activation dtype so the same parameters
        # serve both the fp32 and the fp16 path.
        x = torch.nn.functional.conv2d(
            x, self.weight.to(x.dtype), stride=self.stride, padding=self.padding
        )
        # `training` is left at its default (False): the stored statistics
        # are used as-is and never updated.
        x = torch.nn.functional.batch_norm(
            x, running_mean=self.mean, running_var=self.var, bias=self.beta, eps=1e-3
        )
        x = torch.nn.functional.relu(x)
        return x


# ----------------------------------------------------------------------------


class InceptionA(torch.nn.Module):
    """Four parallel branches at unchanged resolution, concatenated on dim 1.

    Output channels: ``64 + 64 + 96 + tmp_channels``.

    Args:
        in_channels: Number of input channels.
        tmp_channels: Output channels of the average-pool branch.
    """

    def __init__(self, in_channels, tmp_channels):
        super().__init__()
        self.conv = Conv2dLayer(in_channels, 64, kh=1, kw=1)
        self.tower = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, 48, kh=1, kw=1)),
                    ("conv_1", Conv2dLayer(48, 64, kh=5, kw=5, padding=2)),
                ]
            )
        )
        self.tower_1 = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)),
                    ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)),
                    ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, padding=1)),
                ]
            )
        )
        self.tower_2 = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    (
                        "pool",
                        torch.nn.AvgPool2d(
                            kernel_size=3, stride=1, padding=1, count_include_pad=False
                        ),
                    ),
                    ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)),
                ]
            )
        )

    def forward(self, x):
        """Run every branch on ``x`` and concatenate along the channel dim."""
        return torch.cat(
            [
                self.conv(x).contiguous(),
                self.tower(x).contiguous(),
                self.tower_1(x).contiguous(),
                self.tower_2(x).contiguous(),
            ],
            dim=1,
        )


# ----------------------------------------------------------------------------


class InceptionB(torch.nn.Module):
    """Stride-2 reduction block: two conv branches plus a max-pool branch.

    Output channels: ``384 + 96 + in_channels``.

    Args:
        in_channels: Number of input channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = Conv2dLayer(in_channels, 384, kh=3, kw=3, stride=2)
        self.tower = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, 64, kh=1, kw=1)),
                    ("conv_1", Conv2dLayer(64, 96, kh=3, kw=3, padding=1)),
                    ("conv_2", Conv2dLayer(96, 96, kh=3, kw=3, stride=2)),
                ]
            )
        )
        self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate along the channel dim."""
        return torch.cat(
            [
                self.conv(x).contiguous(),
                self.tower(x).contiguous(),
                self.pool(x).contiguous(),
            ],
            dim=1,
        )


# ----------------------------------------------------------------------------


class InceptionC(torch.nn.Module):
    """Four branches built from factorised 1x7 / 7x1 convolutions.

    Resolution is unchanged; output channels: ``4 * 192 = 768``.

    Args:
        in_channels: Number of input channels.
        tmp_channels: Width of the intermediate 1x7 / 7x1 convolutions.
    """

    def __init__(self, in_channels, tmp_channels):
        super().__init__()
        self.conv = Conv2dLayer(in_channels, 192, kh=1, kw=1)
        self.tower = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)),
                    (
                        "conv_1",
                        Conv2dLayer(
                            tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3]
                        ),
                    ),
                    (
                        "conv_2",
                        Conv2dLayer(tmp_channels, 192, kh=7, kw=1, padding=[3, 0]),
                    ),
                ]
            )
        )
        self.tower_1 = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, tmp_channels, kh=1, kw=1)),
                    (
                        "conv_1",
                        Conv2dLayer(
                            tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0]
                        ),
                    ),
                    (
                        "conv_2",
                        Conv2dLayer(
                            tmp_channels, tmp_channels, kh=1, kw=7, padding=[0, 3]
                        ),
                    ),
                    (
                        "conv_3",
                        Conv2dLayer(
                            tmp_channels, tmp_channels, kh=7, kw=1, padding=[3, 0]
                        ),
                    ),
                    (
                        "conv_4",
                        Conv2dLayer(tmp_channels, 192, kh=1, kw=7, padding=[0, 3]),
                    ),
                ]
            )
        )
        self.tower_2 = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    (
                        "pool",
                        torch.nn.AvgPool2d(
                            kernel_size=3, stride=1, padding=1, count_include_pad=False
                        ),
                    ),
                    ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)),
                ]
            )
        )

    def forward(self, x):
        """Run every branch on ``x`` and concatenate along the channel dim."""
        return torch.cat(
            [
                self.conv(x).contiguous(),
                self.tower(x).contiguous(),
                self.tower_1(x).contiguous(),
                self.tower_2(x).contiguous(),
            ],
            dim=1,
        )


# ----------------------------------------------------------------------------


class InceptionD(torch.nn.Module):
    """Stride-2 reduction block: two conv towers plus a max-pool branch.

    Output channels: ``320 + 192 + in_channels``.

    Args:
        in_channels: Number of input channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.tower = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)),
                    ("conv_1", Conv2dLayer(192, 320, kh=3, kw=3, stride=2)),
                ]
            )
        )
        self.tower_1 = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(in_channels, 192, kh=1, kw=1)),
                    ("conv_1", Conv2dLayer(192, 192, kh=1, kw=7, padding=[0, 3])),
                    ("conv_2", Conv2dLayer(192, 192, kh=7, kw=1, padding=[3, 0])),
                    ("conv_3", Conv2dLayer(192, 192, kh=3, kw=3, stride=2)),
                ]
            )
        )
        self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate along the channel dim."""
        return torch.cat(
            [
                self.tower(x).contiguous(),
                self.tower_1(x).contiguous(),
                self.pool(x).contiguous(),
            ],
            dim=1,
        )


# ----------------------------------------------------------------------------


class InceptionE(torch.nn.Module):
    """Final-stage block with split 1x3 / 3x1 convolution heads.

    Resolution is unchanged; output channels:
    ``320 + 2 * 384 + 2 * 384 + 192 = 2048``.

    Args:
        in_channels: Number of input channels.
        use_avg_pool: If true the pooling branch uses a 3x3 average pool
            (padding excluded from the average); otherwise a 3x3 max pool.
    """

    def __init__(self, in_channels, use_avg_pool):
        super().__init__()
        self.conv = Conv2dLayer(in_channels, 320, kh=1, kw=1)
        self.tower_conv = Conv2dLayer(in_channels, 384, kh=1, kw=1)
        self.tower_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1])
        self.tower_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0])
        self.tower_1_conv = Conv2dLayer(in_channels, 448, kh=1, kw=1)
        self.tower_1_conv_1 = Conv2dLayer(448, 384, kh=3, kw=3, padding=1)
        self.tower_1_mixed_conv = Conv2dLayer(384, 384, kh=1, kw=3, padding=[0, 1])
        self.tower_1_mixed_conv_1 = Conv2dLayer(384, 384, kh=3, kw=1, padding=[1, 0])
        if use_avg_pool:
            self.tower_2_pool = torch.nn.AvgPool2d(
                kernel_size=3, stride=1, padding=1, count_include_pad=False
            )
        else:
            self.tower_2_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.tower_2_conv = Conv2dLayer(in_channels, 192, kh=1, kw=1)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate along the channel dim."""
        # Each shared stem feeds both a 1x3 and a 3x1 head.
        a = self.tower_conv(x)
        b = self.tower_1_conv_1(self.tower_1_conv(x))
        return torch.cat(
            [
                self.conv(x).contiguous(),
                self.tower_mixed_conv(a).contiguous(),
                self.tower_mixed_conv_1(a).contiguous(),
                self.tower_1_mixed_conv(b).contiguous(),
                self.tower_1_mixed_conv_1(b).contiguous(),
                self.tower_2_conv(self.tower_2_pool(x)).contiguous(),
            ],
            dim=1,
        )


# ----------------------------------------------------------------------------


class InceptionV3(torch.nn.Module):
    """Inception-v3 feature extractor with a 1008-way output layer.

    ``layers`` is a named ``Sequential`` of the stem convolutions, the eleven
    ``mixed*`` blocks and a final 8x8 average pool; ``output`` is the linear
    classifier applied to the 2048-dimensional pooled features.
    """

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            collections.OrderedDict(
                [
                    ("conv", Conv2dLayer(3, 32, kh=3, kw=3, stride=2)),
                    ("conv_1", Conv2dLayer(32, 32, kh=3, kw=3)),
                    ("conv_2", Conv2dLayer(32, 64, kh=3, kw=3, padding=1)),
                    ("pool0", torch.nn.MaxPool2d(kernel_size=3, stride=2)),
                    ("conv_3", Conv2dLayer(64, 80, kh=1, kw=1)),
                    ("conv_4", Conv2dLayer(80, 192, kh=3, kw=3)),
                    ("pool1", torch.nn.MaxPool2d(kernel_size=3, stride=2)),
                    ("mixed", InceptionA(192, tmp_channels=32)),
                    ("mixed_1", InceptionA(256, tmp_channels=64)),
                    ("mixed_2", InceptionA(288, tmp_channels=64)),
                    ("mixed_3", InceptionB(288)),
                    ("mixed_4", InceptionC(768, tmp_channels=128)),
                    ("mixed_5", InceptionC(768, tmp_channels=160)),
                    ("mixed_6", InceptionC(768, tmp_channels=160)),
                    ("mixed_7", InceptionC(768, tmp_channels=192)),
                    ("mixed_8", InceptionD(768)),
                    ("mixed_9", InceptionE(1280, use_avg_pool=True)),
                    ("mixed_10", InceptionE(2048, use_avg_pool=False)),
                    ("pool2", torch.nn.AvgPool2d(kernel_size=8)),
                ]
            )
        )
        self.output = torch.nn.Linear(2048, 1008)

    def forward(
        self,
        img,
        return_features: bool = True,
        use_fp16: bool = False,
        no_output_bias: bool = False,
    ):
        """Extract Inception features, or class probabilities, from ``img``.

        Args:
            img: Image batch of shape ``[N, 3, H, W]``. The values are shifted
                and scaled as ``(x - 128) / 128`` after resizing, i.e. they
                are expected in the ``[0, 255]`` range.
            return_features: If true, return ``(features, spatial_features)``;
                otherwise return softmax probabilities.
            use_fp16: Run the network on a float16 copy of the input instead
                of float32.
            no_output_bias: Only used when ``return_features`` is false; drop
                the bias of the output layer before the softmax.

        Returns:
            If ``return_features``: a tuple ``(features, spatial_features)``
            where ``features`` is ``[N, 2048]`` float32 and
            ``spatial_features`` is ``[N, 2023]`` in the compute dtype.
            Otherwise a ``[N, 1008]`` tensor of probabilities.

        Raises:
            AssertionError: If ``img`` does not have exactly 3 channels.
        """
        batch_size, channels, height, width = img.shape  # [NCHW]
        assert channels == 3, f"expected 3 input channels, got {channels}"

        # Cast to float.
        x = img.to(torch.float16 if use_fp16 else torch.float32)

        # Emulate tf.image.resize_bilinear(x, [299, 299]), including the funky alignment.
        new_width, new_height = 299, 299
        theta = torch.eye(2, 3, device=x.device)
        theta[0, 2] += theta[0, 0] / width - theta[0, 0] / new_width
        theta[1, 2] += theta[1, 1] / height - theta[1, 1] / new_height
        theta = theta.to(x.dtype).unsqueeze(0).repeat([batch_size, 1, 1])
        grid = torch.nn.functional.affine_grid(
            theta, [batch_size, channels, new_height, new_width], align_corners=False
        )
        x = torch.nn.functional.grid_sample(
            x, grid, mode="bilinear", padding_mode="border", align_corners=False
        )

        # Scale dynamic range from [0,255] to [-1,1[.
        # In-place is safe: `x` is the fresh tensor produced by grid_sample,
        # not the caller's `img`.
        x -= 128
        x /= 128

        # Main layers. Everything before `mixed_6` is shared by both outputs.
        intermediate = self.layers[:-6](x)
        features = self.layers[-6:](intermediate).reshape(-1, 2048).to(torch.float32)

        if not return_features:
            # Output layer. The spatial features are not needed on this path,
            # so they are not computed.
            return self.acts_to_probs(features, no_output_bias=no_output_bias)

        # First 7 channels of the 1x1 branch of `mixed_6` on its 17x17 grid,
        # flattened channel-last: 17 * 17 * 7 = 2023 values per image.
        spatial_features = (
            self.layers[-6]
            .conv(intermediate)[:, :7]
            .permute(0, 2, 3, 1)
            .reshape(-1, 2023)
        )
        return features, spatial_features

    def acts_to_probs(self, features, no_output_bias: bool = False):
        """Map ``[N, 2048]`` features to ``[N, 1008]`` softmax probabilities.

        Args:
            features: Pooled features as returned by ``forward``.
            no_output_bias: If true, apply only the weight matrix of the
                output layer and skip its bias.
        """
        if no_output_bias:
            logits = torch.nn.functional.linear(features, self.output.weight)
        else:
            logits = self.output(features)
        probs = torch.nn.functional.softmax(logits, dim=1)
        return probs

    def create_softmax_model(self):
        """Return a standalone ``SoftmaxModel`` holding a copy of the output weights."""
        return SoftmaxModel(self.output.weight)


class SoftmaxModel(torch.nn.Module):
    """Bias-free linear layer followed by a softmax over dim 1.

    Args:
        weight: Weight matrix; it is detached and cloned, so the new module
            does not share storage with the tensor passed in.
    """

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = torch.nn.Parameter(weight.detach().clone())

    def forward(self, x):
        """Return ``softmax(x @ weight.T, dim=1)``."""
        logits = torch.nn.functional.linear(x, self.weight)
        probs = torch.nn.functional.softmax(logits, dim=1)
        return probs