in model/paper_txt2pi.py [0:0]
def fuse(self, inputs, cell, inv, wiki, task):
    """Fuse grid features with wiki-text attention through the ``film1``..``film5`` stages.

    The ``film*`` submodules are presumably FiLM-style conditioning layers
    (NOTE(review): inferred from the name only -- confirm).

    Args:
        inputs: mapping read for ``'rel_pos'`` (cast to float and viewed as
            ``(T*B, H, W, -1)``) and ``'wiki_len'`` (flattened, cast to long).
        cell: 5-D tensor unpacked as ``(T, B, H, W, demb)``.
        inv: forwarded unchanged to every ``film*`` stage.
        wiki: pair ``(wiki1, wiki2)``. ``wiki1`` is attended once, conditioned
            on ``task``. ``wiki2`` is re-attended before every stage,
            conditioned on the previous stage's second output.
        task: condition for the ``wiki1`` attention; also forwarded to every stage.

    Returns:
        ``self.fc`` applied to the last stage's features after max-pooling
        over the two spatial axes, viewed as ``(T*B, -1)``.
    """
    T, B, H, W, demb = cell.shape
    n_items = T * B

    # transpose(1, 3) turns (N, H, W, C) into channel-first (N, C, W, H);
    # the same is applied to both tensors so their spatial axes line up.
    grid = cell.flatten(0, 1).transpose(1, 3)  # (T*B, demb, W, H)
    pos = inputs['rel_pos'].float().view(n_items, H, W, -1).transpose(1, 3)

    wiki1, wiki2 = wiki
    wiki_lens = inputs['wiki_len'].view(-1).long()
    # Task-conditioned summary of wiki1, handed to every stage.
    wiki_attn, _ = self.run_attn(wiki1, wiki_lens, cond=task)

    def spatial_max(t):
        # Max over dim 3, then dim 2 (the two spatial axes after transpose).
        return t.max(dim=3).values.max(dim=2).values

    # Initial state that conditions the first wiki2 attention.
    state = self.c0_trans(spatial_max(torch.cat([grid, pos], dim=1)))

    stages = (self.film1, self.film2, self.film3, self.film4, self.film5)
    feats = [grid]
    for idx, film in enumerate(stages):
        attended, _ = self.run_attn(wiki2, wiki_lens, cond=state)
        # The last stage receives a residual sum of the two preceding outputs.
        x = feats[-1] + feats[-2] if idx == len(stages) - 1 else feats[-1]
        out, state = film(x, attended, inv, task, pos, wiki_attn)
        feats.append(out)

    pooled = spatial_max(feats[-1])
    return self.fc(pooled.view(n_items, -1))