in convit.py [0:0]
def local_init(self, locality_strength=1.):
    """Initialise ``v`` to the identity and ``pos_proj`` to a square grid.

    ``self.v.weight`` is overwritten with a ``dim x dim`` identity matrix.
    The first ``k * k`` rows of ``self.pos_proj.weight`` (with
    ``k = int(sqrt(num_heads))``) are then filled so that row ``p`` holds
    ``[2 * (p // k - c), 2 * (p % k - c), -1]``, where ``c`` is the centre
    of the ``k x k`` grid. Finally the whole ``pos_proj`` weight is scaled
    in place by ``locality_strength``.

    Rows at index ``k * k`` and above (when ``num_heads`` is not a perfect
    square) are not assigned here; they are only scaled.

    NOTE(review): presumably columns 0/1 are the two relative-offset
    coordinates and column 2 weights a squared distance -- confirm against
    the code that builds the input of ``pos_proj``.

    Args:
        locality_strength: Multiplier applied to every ``pos_proj`` weight
            after the grid has been written.
    """
    # The value projection starts out as the identity map.
    identity = torch.eye(self.dim)
    self.v.weight.data.copy_(identity)

    # Step between neighbouring grid centres; kept fixed at 1.
    spacing = 1

    side = int(self.num_heads ** .5)
    # Even grids are centred between two cells (x.5); odd grids on the
    # middle cell.
    if side % 2 == 0:
        middle = (side - 1) / 2
    else:
        middle = side // 2

    proj = self.pos_proj.weight.data  # shares storage with the parameter
    for head in range(side * side):
        # head == col + side * row
        row, col = divmod(head, side)
        proj[head, 0] = 2 * (row - middle) * spacing
        proj[head, 1] = 2 * (col - middle) * spacing
        proj[head, 2] = -1

    self.pos_proj.weight.data *= locality_strength