in generation/models/networks.py [0:0]
def forward_and_reparameterize(self, input, inst, bk_z):
    """Encode `input`, pool the latent code per instance, and scatter a
    reparameterized sample of each instance code back onto its pixels.

    Pipeline: encoder -> two heads (fc1 -> mu map, fc2 -> logvar map),
    then for every unique label in `inst` the mu/logvar are average-pooled
    over that instance's pixels, reparameterized as one vector, and
    broadcast back to the instance mask.

    Args:
        input: image batch, shape (B, C_in, H, W).
        inst: integer-valued instance-label map aligned with `input`,
            shape (B, 1, H, W).  Label 0 is treated as background when
            `bk_z` is supplied.
        bk_z: pre-encoded background code, shape (B, output_nc), or None
            to encode the background like any other instance.
            (assumed shape — the original carried a "CHECK bk_z shape" note)

    Returns:
        outputs_mean: (B, output_nc, H, W) map of broadcast codes.
        batch_mu, batch_logvar: (B, output_nc, n_labels) pooled stats.
        inst_list: the unique labels that were pooled (background removed
            when it was handled via `bk_z`).
    """
    outputs = self.model(input)
    outputs_mu = self.fc1(outputs)
    outputs_logvar = self.fc2(outputs)
    # Result map, filled instance by instance below.
    outputs_mean = torch.zeros_like(outputs_mu)
    inst_list = np.unique(inst.cpu().numpy().astype(int))
    batch_size = input.size(0)

    def scatter_code(code_vec, indices, b):
        # Broadcast one latent vector (length output_nc) onto every pixel
        # of an instance.  `indices` is the n x 4 (batch, chan, y, x)
        # matrix from nonzero(); the channel column is all zeros because
        # `inst` has a single channel, so offsetting it by j addresses
        # channel j.  Done per-channel because this advanced-indexing
        # pattern writes one channel at a time.
        for j in range(self.output_nc):
            target = outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]]
            outputs_mean[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]] = \
                code_vec[j].expand_as(target)

    # Background handling: when a pre-encoded code is supplied and label 0
    # exists, paste bk_z directly (no pooling, no reparameterization) and
    # drop label 0 from the list of labels pooled below.
    if (bk_z is not None) and (0 in inst_list):
        inst_list = inst_list[1:]
        for b in range(batch_size):
            indices = (inst[b:b + 1] == 0).nonzero()
            if indices.nelement() == 0:  # image b has no background pixels
                continue
            scatter_code(bk_z[b], indices, b)

    # Pooled statistics: (B, output_nc, n_labels).  Allocate on the same
    # device/dtype as the encoder output instead of hard-coding .cuda(),
    # so the module also runs on CPU (identical behavior on GPU).
    batch_mu = torch.zeros(batch_size, outputs_mu.size(1), len(inst_list),
                           device=outputs_mu.device, dtype=outputs_mu.dtype)
    batch_logvar = torch.zeros_like(batch_mu)

    for count_i, i in enumerate(inst_list):
        for b in range(batch_size):
            indices = (inst[b:b + 1] == int(i)).nonzero()
            if indices.nelement() == 0:  # image b does not contain label i
                continue
            # Average-pool mu/logvar over this instance's pixels, one
            # channel at a time (the per-channel index offset prevents
            # pooling the whole vector in a single indexing expression).
            for j in range(self.output_nc):
                batch_mu[b, j, count_i] = torch.mean(
                    outputs_mu[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]])
                batch_logvar[b, j, count_i] = torch.mean(
                    outputs_logvar[indices[:, 0] + b, indices[:, 1] + j, indices[:, 2], indices[:, 3]])
            # Reparameterization trick on the entire code vector at once;
            # .clone() guards against reparameterize mutating the pooled
            # stats in place.
            reparam_mean = self.reparameterize(batch_mu[b, :, count_i].clone(),
                                               batch_logvar[b, :, count_i].clone())
            scatter_code(reparam_mean, indices, b)

    return outputs_mean, batch_mu, batch_logvar, inst_list