in generation/models/networks.py [0:0]
def forward(self, input, inst, is_background=False):
    """Encode ``input`` and pool the encoded features.

    The input is passed through ``self.model`` and then ``self.fc1``.
    The resulting feature map is pooled in one of two ways:

    * ``is_background=True``: global average over the two trailing
      (spatial) dimensions, returning one vector per image.
    * ``is_background=False``: instance-wise average pooling. For every
      image and every instance id found in ``inst``, each of the first
      ``self.output_nc`` channels is replaced, inside that instance's
      region, by the channel's mean over the region.

    Args:
        input: Batch fed to ``self.model``; its first dimension is the
            batch size.
        inst: Instance-id map. Assumed to be ``(B, 1, H, W)`` with the
            same ``H``/``W`` as the feature map -- TODO confirm against
            the caller. Ids are truncated to ``int`` before matching, so
            pixels holding a non-integral id match nothing.
        is_background: Selects global pooling instead of instance-wise
            pooling.

    Returns:
        ``(B, C)`` tensor when ``is_background`` is true; otherwise a
        tensor shaped like the feature map. Positions that are never
        written (unmatched pixels, channels beyond ``self.output_nc``)
        stay zero.
    """
    outputs = self.model(input)
    outputs_mu = self.fc1(outputs)

    if is_background:
        # (B, C, H, W) -> (B, C)
        return outputs_mu.mean(dim=3).mean(dim=2)

    # Instance-wise average pooling.
    outputs_mean_mu = torch.zeros_like(outputs_mu)
    # Single host copy of the id map; ids are truncated to int here.
    inst_np = inst.cpu().numpy().astype(int)
    num_channels = self.output_nc

    for b in range(input.size(0)):
        # Basic-indexing views, so writes below land in outputs_mean_mu.
        src = outputs_mu[b, :num_channels]
        dst = outputs_mean_mu[b, :num_channels]
        # Only the ids present in this image can select any pixel, so
        # there is no need to scan ids that belong to other images.
        for inst_id in np.unique(inst_np[b]):
            mask = inst[b, 0] == int(inst_id)  # (H, W)
            if not mask.any():
                # Happens only when truncation changed the id value.
                continue
            # A spatial boolean mask selects the region in every channel
            # at once: (C, n) -> per-channel mean (C, 1), broadcast back.
            dst[:, mask] = src[:, mask].mean(dim=1, keepdim=True)

    return outputs_mean_mu