in luckmatter/recon_multilayer.py [0:0]
def compute_Hs(net1, output1, net2, output2):
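    """Compute per-layer relation matrices between two networks.

    Interface notes (inferred from usage in this file, not guaranteed):
    `output1`/`output2` are forward-pass dicts with "hs" (per-layer
    activations) and "pre_bns" (pre-BatchNorm activations); `torch` and
    the `BN` Jacobian helper are imported/defined elsewhere in this module.

    Returns (Hs, betas): lists of [input_dim_net1, input_dim_net2]
    matrices, averaged over the batch and ordered bottom-to-top.
    """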
    # Compute H given the current batch.
    sz1 = net1.sizes
    sz2 = net2.sizes
    bs = output1["hs"][0].size(0)
    assert sz1[-1] == sz2[-1], "the output size of both networks should be the same: %d vs %d" % (sz1[-1], sz2[-1])
    # At the top layer, H is the identity for every sample in the batch.
    H = torch.eye(sz1[-1]).cuda().repeat(bs, 1, 1)
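
    # Hs/betas collect, per layer, the batch-averaged relation matrices
    # (moved to CPU); they are built top-down and reversed before returning.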
    Hs = []
    betas = []

    # Then compute the remaining relation matrices recursively, walking the
    # layers from top to bottom; j indexes linear layers from the bottom.
    j = len(output1["hs"])
    pre_bns1 = output1["pre_bns"][::-1]
    pre_bns2 = output2["pre_bns"][::-1]
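
    # Each iteration implements the recursion
    #     beta = W1^T @ H @ W2,    H_new = D1 * (beta * D2)
    # where D1/D2 are the ReLU gates (pre-activation > 0), optionally
    # passed through the BatchNorm Jacobians (forwardJ / backwardJ).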
    for pre_bn1, pre_bn2 in zip(pre_bns1, pre_bns2):
        # W: of size [output, input]
        W1t = net1.from_bottom_linear(j).t()
        W2 = net2.from_bottom_linear(j)

        # beta: [bs, input_dim_net1, input_dim_net2].  Matrix multiply
        # broadcasts over the batch dimension, so this is equivalent to
        # beta[i] = W1t @ H[i] @ W2 for each sample i.
        beta = W1t @ H @ W2
        betas.append(beta.mean(0).cpu())
        # ReLU gate of net2 at this layer: [bs, input_dim_net2].
        gate2 = (pre_bn2 > 0).float()
        if net2.has_bn:
            # Push the gate through net2's BatchNorm Jacobian.
            bn2 = BN(net2.from_bottom_bn(j - 1), pre_bn2)
            gate2 = bn2.forwardJ(gate2)

        AA = beta * gate2[:, None, :]
        if net1.has_bn:
            # pre_bn: [bs, input_dim]
            bn1 = BN(net1.from_bottom_bn(j - 1), pre_bn1)
            # Back-propagate each column of AA through net1's BatchNorm
            # Jacobian.
            for k in range(AA.size(2)):
                AA[:, :, k] = bn1.backwardJ(AA[:, :, k])

        # ReLU gate of net1 at this layer: [bs, input_dim_net1].
        gate1 = (pre_bn1 > 0).float()
        H = gate1[:, :, None] * AA
        Hs.append(H.mean(0).cpu())
        j -= 1

    # Reverse so the lists are ordered from the bottom layer to the top.
    return Hs[::-1], betas[::-1]
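
# A minimal usage sketch (hypothetical: it assumes the forward pass of the
# models in this file returns the dict with "hs" and "pre_bns" consumed
# above, e.g. for a teacher/student pair evaluated on the same batch x):
#
#   output_t = teacher(x)
#   output_s = student(x)
#   Hs, betas = compute_Hs(teacher, output_t, student, output_s)
#   # Hs[0] relates the two networks' lowest hidden layers.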