def compute_Hs()

in luckmatter/recon_multilayer.py
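
compute_Hs walks the two networks from the output layer down. At each layer it
forms beta, the current H pushed through both weight matrices, then applies each
network's ReLU gates (and batch-norm Jacobians, when present) to obtain the H one
layer below. Schematically, reading the recursion off the code, with $D_k$ the
diagonal ReLU-gate matrix of network $k$ and $J_k$ its batch-norm Jacobian
(identity when `has_bn` is false), each sample is updated top-down as

$$\beta_\ell = W_1^\top H_{\ell+1} W_2, \qquad H_\ell = D_1 J_1 \,\beta_\ell\, J_2 D_2,$$

with H initialized to the identity at the output layer. The returned lists hold
the batch means of $H_\ell$ and $\beta_\ell$ in bottom-up order.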


import torch

def compute_Hs(net1, output1, net2, output2):
    # Compute the per-layer H matrices given the current batch.
    # `BN` is the batch-norm Jacobian helper defined elsewhere in
    # recon_multilayer.py (a hedged interface sketch appears after the function).
    sz1 = net1.sizes
    sz2 = net2.sizes
    
    bs = output1["hs"][0].size(0)  # batch size, read off the first hidden activation
    
    assert sz1[-1] == sz2[-1], "the output sizes of both networks should be the same: %d vs %d" % (sz1[-1], sz2[-1])

    # Top-layer H is the identity for every sample: output units are matched one-to-one.
    H = torch.eye(sz1[-1]).cuda().repeat(bs, 1, 1)

    Hs = []
    betas = []

    # Compute the remaining per-layer matrices recursively, from the top layer down.
    # j indexes linear layers counted from the bottom; the pre_bns lists are
    # reversed so the traversal matches that top-down order.
    j = len(output1["hs"])
    pre_bns1 = output1["pre_bns"][::-1]
    pre_bns2 = output2["pre_bns"][::-1]

    for pre_bn1, pre_bn2 in zip(pre_bns1, pre_bns2):
        # W: of size [output, input]
        W1t = net1.from_bottom_linear(j).t()
        W2 = net2.from_bottom_linear(j)

        # [bs, input_dim_net1, input_dim_net2]
        beta = torch.cuda.FloatTensor(bs, W1t.size(0), W2.size(1))
        for i in range(bs):
            beta[i, :, :] = W1t @ H[i, :, :] @ W2
        # Vectorized equivalent: beta = W1t @ H @ W2 (matmul broadcasts over the batch dim).

        betas.append(beta.mean(0).cpu())

        # ReLU gate of net2: 1 where the pre-activation is positive.
        gate2 = (pre_bn2 > 0).float()
        if net2.has_bn:
            # Push the gate through net2's batch-norm Jacobian.
            bn2 = BN(net2.from_bottom_bn(j - 1), pre_bn2)
            gate2 = bn2.forwardJ(gate2)

        # Gate beta's columns (the net2 side); the net1 side is handled below.
        AA = beta * gate2[:, None, :]

        if net1.has_bn:
            # pre_bn: [bs, input_dim]
            bn1 = BN(net1.from_bottom_bn(j - 1), pre_bn1)
            # Pull each column of AA back through net1's batch-norm Jacobian.
            for k in range(AA.size(2)):
                AA[:,:,k] = bn1.backwardJ(AA[:,:,k])
        # ReLU gate of net1, applied to the rows.
        gate1 = (pre_bn1 > 0).float()

        H = gate1[:, :, None] * AA
        Hs.append(H.mean(0).cpu())
        j -= 1

    # Reverse so both lists run bottom-up, matching layer order.
    return Hs[::-1], betas[::-1]
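
The `BN` helper is not shown in this excerpt. Below is a minimal sketch of the
interface the code appears to rely on, under the assumption that batch
statistics are treated as constants, so the Jacobian reduces to a per-feature
scaling; the actual class in recon_multilayer.py may differ.

class BN:
    # Hedged sketch, not the repo's implementation: assumes the BN Jacobian
    # is diagonal (batch statistics treated as constants).
    def __init__(self, bn_layer, pre_bn):
        # bn_layer: an nn.BatchNorm1d-like module; pre_bn: [bs, dim] inputs to BN.
        var = pre_bn.var(dim=0, unbiased=False) + bn_layer.eps
        self.scale = bn_layer.weight / var.sqrt()  # [dim] diagonal of the Jacobian

    def forwardJ(self, g):
        # Forward Jacobian product: per-feature scaling of g [bs, dim].
        return g * self.scale[None, :]

    def backwardJ(self, g):
        # With a diagonal Jacobian, the backward product is the same scaling.
        return g * self.scale[None, :]

A usage sketch follows; the names `teacher` and `student`, the input shape, and
the structure of the output dicts (containing "hs" and "pre_bns" lists, as the
function indexes them) are assumptions for illustration, not confirmed API.

# Hypothetical call; `teacher` and `student` stand in for the two networks.
x = torch.randn(64, 20).cuda()
output_t = teacher(x)   # assumed to return {"hs": [...], "pre_bns": [...]}
output_s = student(x)
Hs, betas = compute_Hs(teacher, output_t, student, output_s)
# Hs[l][a, b]: batch-mean coupling between unit a of teacher layer l and
# unit b of student layer l; betas[l] is the same quantity before gating.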