in Classification/models/rfnorm.py [0:0]
def forward(self, x, output_size=None):
    if x.numel() == 0:
        return x
    x = x.contiguous()
    N, C, H, W = x.shape
    B = self.block
    if self.norm_type == 'l1norm':
        x_norm = x.abs().mean(dim=(1, 2, 3), keepdim=True)
        x = x / (x_norm + self.eps)
    elif self.norm_type == 'layernorm':
        x = self.layernorm(x)
    elif self.norm_type == 'rfnorm':  # receptive field normalization
        win_size = self.rf_size  # alternatively: (self.kernel_size[0] - 1) * self.stride[0] + 1
        # Per-window pixel counts, so windows clipped at the borders average correctly.
        ones = torch.ones(1, 1, H, W, dtype=x.dtype, device=x.device)
        M = box_filter(ones, win_size)
        # Window mean and variance (averaged over channels) via Var[x] = E[x^2] - E[x]^2.
        x_mean = box_filter(x, win_size).mean(dim=1, keepdim=True) / M
        x2_mean = box_filter(x ** 2, win_size).mean(dim=1, keepdim=True) / M
        var = torch.clamp(x2_mean - x_mean ** 2, min=0.)
        std = var.sqrt() + self.rf_eps
        rf_a = 1 / std
        rf_b = -x_mean / std * self.weight.sum(dim=(1, 2, 3)).view(1, -1, 1, 1)
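        # Folding sketch: conv_transpose2d is linear, so per-receptive-field
        # normalization (x - mean) / std can be deferred to the output side,
        # treating std as locally constant within a window:
        #   conv_t((x - mean) / std) ~ conv_t(x) * rf_a + rf_b
        # with rf_a = 1 / std and rf_b = -mean / std scaled by the per-channel
        # kernel sum; both maps are applied after the transposed conv below.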
    if self.training:
        self.counter += 1
        # 1. im2col: N x cols x pixels -> N*pixels x cols
        if self.kernel_size[0] > 1:
            # The adjoint of a conv operation is a full correlation operation,
            # hence the enlarged padding; note this `padding` is currently
            # unused by the unfold call below, which still passes self.padding.
            padding = (self.padding[0] + self.kernel_size[0] - 1,
                       self.padding[1] + self.kernel_size[1] - 1)
            X = torch.nn.functional.unfold(
                x, self.kernel_size, self.dilation, self.padding,
                self.sampling_stride).transpose(1, 2).contiguous()
            if self.norm_type == 'rfnorm':
                # Standardize each unfolded patch: zero mean, unit std per column.
                X_std = X.std(dim=-1, keepdim=True) + self.rf_eps
                X = (X - X.mean(dim=-1, keepdim=True)) / X_std
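                # With rfnorm the covariance statistics are thus estimated from
                # standardized patches, staying consistent with the normalized
                # signal the layer effectively operates on.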
        else:
            # 1x1 kernel: no unfolding needed; subsample pixels channel-wise.
            X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]
        if self.groups == 1:
            # (C//B*N*pixels, k*k*B)
            X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
        else:
            X = X.view(-1, X.shape[-1])
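        # At this point X is 2-D: one row per sampled spatial location (and per
        # C // B slice when groups == 1), one column per feature to be whitened.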
        if self.sync:
            process_group = self.process_group
            world_size = 1
            if not self.process_group:
                process_group = torch.distributed.group.WORLD
            # 2. calculate mean, cov, cov_isqrt from raw moments
            X_mean = X.mean(0)
            if self.groups == 1:
                M = X.shape[0]
                XX_mean = X.t() @ X / M
            else:
                # Mirror the non-sync branch: view as (groups, M, num_features)
                # before taking per-group second moments; the flat 2-D view has
                # no dim 2 to transpose.
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                M = X.shape[1]
                XX_mean = X.transpose(1, 2) @ X / M
            if torch.distributed.is_initialized():
                world_size = torch.distributed.get_world_size(process_group)
                # sync-once implementation: one collective for both statistics
                sync_data = torch.cat([X_mean.view(-1), XX_mean.view(-1)], dim=0)
                sync_data_list = [torch.empty_like(sync_data) for k in range(world_size)]
                sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                sync_data = torch.stack(sync_data_list).mean(0)
                X_mean = sync_data[:X_mean.numel()].view(X_mean.shape)
                XX_mean = sync_data[X_mean.numel():].view(XX_mean.shape)
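                # Design note: concatenating X_mean and XX_mean into one tensor
                # needs a single all_gather per step instead of two collectives;
                # diffdist's all_gather is differentiable, so gradients flow
                # through the averaged statistics.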
            if self.groups == 1:
                # cov = E[X^T X] - E[X] E[X]^T
                cov = XX_mean - X_mean.unsqueeze(1) @ X_mean.unsqueeze(0)
                Id = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device)
                cov_isqrt = isqrt_newton_schulz_autograd(cov + self.eps * Id, self.n_iter)
            else:
                # X_mean is flat (groups * num_features,); view it per group
                # for the outer product.
                X_mean_g = X_mean.view(self.groups, self.num_features)
                cov = XX_mean - X_mean_g.unsqueeze(2) @ X_mean_g.unsqueeze(1)
                Id = torch.eye(self.num_features, dtype=cov.dtype, device=cov.device).expand(
                    self.groups, self.num_features, self.num_features)
                cov_isqrt = isqrt_newton_schulz_autograd_batch(cov + self.eps * Id, self.n_iter)
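            # isqrt_newton_schulz_autograd computes cov^(-1/2) with a fixed
            # number of Newton-Schulz iterations: differentiable and batched on
            # GPU, avoiding an explicit eigendecomposition per step.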
        else:
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)
            if self.groups == 1:
                # cov = X.t() @ X / X.shape[0] + self.eps * I, fused below
                Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                cov = torch.addmm(beta=self.eps, input=Id, alpha=1. / X.shape[0], mat1=X.t(), mat2=X)
                cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
            else:
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(
                    self.groups, self.num_features, self.num_features)
                cov = torch.baddbmm(beta=self.eps, input=Id, alpha=1. / X.shape[1], mat1=X.transpose(1, 2), mat2=X)
                cov_isqrt = isqrt_newton_schulz_autograd_batch(cov, self.n_iter)
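            # addmm / baddbmm fuse the regularizer with the covariance in one
            # kernel: cov = self.eps * Id + (1 / M) * X^T @ X.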
        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)
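        # Note: the EMA is taken over cov_isqrt itself rather than cov, which
        # keeps eval cheap but makes the running inverse square root an
        # approximation (an average of inverse square roots, not the inverse
        # square root of an averaged covariance).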
    else:
        X_mean = self.running_mean
        cov_isqrt = self.running_cov_isqrt
    # 3. X * deconv * conv = X * (deconv * conv)
    # Use a flipped kernel so conv_transpose2d realizes the adjoint correlation.
    weight = torch.flip(self.weight, [2, 3])
    if self.groups == 1:
        w = weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(
            -1, self.num_features) @ cov_isqrt
        b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
        w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
    else:
        w = self.weight.view(C // B, -1, self.num_features) @ cov_isqrt
        b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(self.bias.shape)
        w = w.view(self.weight.shape)
    w = torch.flip(w.view(weight.shape), [2, 3])
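    # Folding sketch: since whitening is linear,
    #   ((X - X_mean) @ cov_isqrt) @ W = X @ (cov_isqrt @ W) - X_mean @ (cov_isqrt @ W),
    # so cov_isqrt is absorbed into w above and the mean term into the bias b.
    # The flip / flip-back pair keeps the whitening applied in the same patch
    # orientation in which the unfold statistics were estimated.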
    output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size)
    if self.norm_type == 'rfnorm':
        x = F.conv_transpose2d(
            x, w, None, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
        if rf_a.shape[-2:] != x.shape[-2:]:
            # Stride > 1 changes the spatial size; resize the per-pixel affine
            # maps to the output resolution.
            rf_a = F.interpolate(rf_a, size=x.shape[-2:], mode='bilinear')
            rf_b = F.interpolate(rf_b, size=x.shape[-2:], mode='bilinear')
        # Apply the deferred receptive-field normalization plus bias.
        x = x * rf_a + rf_b + b.view(1, -1, 1, 1)
    else:
        x = F.conv_transpose2d(
            x, w, b, self.stride, self.padding,
            output_padding, self.groups, self.dilation)
    return x
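
# Usage sketch (hypothetical; the class name and constructor arguments below
# are assumptions, not the module's confirmed API -- the layer behaves like a
# ConvTranspose2d with whitening and optional receptive-field normalization):
#
#   layer = RFNormConvTranspose(in_channels=64, out_channels=64, kernel_size=3,
#                               stride=2, padding=1, norm_type='rfnorm')
#   layer.train()
#   y = layer(torch.randn(8, 64, 32, 32))   # whitening stats updated
#   layer.eval()
#   y = layer(torch.randn(8, 64, 32, 32))   # uses running_mean / running_cov_isqrt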