def fuse()

in model/paper_txt2pi.py [0:0]


    def fuse(self, inputs, cell, inv, wiki, task):
        """Fuse per-cell grid features with wiki/task attention into one vector per (T, B) item.

        Args:
            inputs: dict-like; reads ``'rel_pos'`` (reshaped to (T*B, H, W, -1))
                and ``'wiki_len'`` (flattened to a 1-D long tensor, presumably
                per-sequence lengths for ``run_attn`` -- confirm with callers).
            cell: 5-D tensor of shape (T, B, H, W, demb). T and B are presumably
                time and batch -- TODO confirm against callers.
            inv: passed through unchanged to every ``film*`` layer.
            wiki: pair ``(wiki1, wiki2)``. ``wiki1`` is attended once, conditioned
                on ``task``. ``wiki2`` is re-attended before every ``film*`` layer,
                conditioned on the summary produced by the previous stage.
            task: condition for the ``wiki1`` attention; also passed to every
                ``film*`` layer.

        Returns:
            ``self.fc`` applied to a (T*B, C) tensor. That tensor is the ``film5``
            feature map max-pooled over its two spatial dimensions.
        """
        T, B, H, W, demb = cell.size()
        n_items = T * B
        tb = torch.flatten(cell, 0, 1)  # (T*B, H, W, demb)
        # reshape (not view) so a rel_pos that is non-contiguous across T/B is accepted.
        # For contiguous input this is the same zero-copy view as before.
        pos = inputs['rel_pos'].float().reshape(n_items, H, W, -1).transpose(1, 3)  # (T*B, P, W, H)

        wiki1, wiki2 = wiki
        wiki_lens = inputs['wiki_len'].reshape(-1).long()
        # Task-conditioned attention over wiki1; this single result goes to every film layer.
        wiki_attn, _ = self.run_attn(wiki1, wiki_lens, cond=task)

        c0 = tb.transpose(1, 3)  # (T*B, demb, W, H)
        # Initial conditioning vector: cell features + positions, max-pooled over H then W.
        summary = self.c0_trans(torch.cat([c0, pos], dim=1).max(3)[0].max(2)[0])

        # Each stage attends over wiki2 using the previous stage's summary, then applies
        # the next film layer. The call order matches the original unrolled chain.
        prev, cur = None, c0
        for film in (self.film1, self.film2, self.film3, self.film4):
            attn, _ = self.run_attn(wiki2, wiki_lens, cond=summary)
            prev = cur
            cur, summary = film(cur, attn, inv, task, pos, wiki_attn)

        # After the loop, cur is c4 and prev is c3. film5 receives their sum (skip connection).
        attn, _ = self.run_attn(wiki2, wiki_lens, cond=summary)
        c5, _ = self.film5(cur + prev, attn, inv, task, pos, wiki_attn)

        conv_out = c5.max(3)[0].max(2)[0]  # max-pool over both spatial dims
        flat = conv_out.reshape(n_items, -1)  # (T*B, C)
        return self.fc(flat)  # (T*B, drep)