in src/model.py [0:0]
def decode(self, enc_outputs, y):
    """Run every decoder layer, conditioning each one on the attribute tensor ``y``.

    Each decoder layer receives a single tensor built by concatenating along
    dim 1 (presumably the channel axis of 4-D ``(batch, C, H, W)`` maps --
    confirm against the layer definitions):

    * the previous decoder output (the first layer starts from
      ``enc_outputs[-1]``);
    * ``y`` broadcast over dims 2 and 3 of that previous output;
    * for layers with index ``1 <= i <= self.n_skip``, the skip connection
      ``enc_outputs[-1 - i]``.

    Args:
        enc_outputs: sequence of ``self.n_layers + 1`` tensors. The batch size
            is read from ``enc_outputs[0].size(0)``, and decoding starts from
            ``enc_outputs[-1]``.
        y: tensor of shape ``(bs, self.n_attr)``.

    Returns:
        List of ``self.n_layers + 1`` tensors: ``enc_outputs[-1]`` followed by
        the output of each layer in ``self.dec_layers``. The last one has
        shape ``(bs, self.img_fm, self.img_sz, self.img_sz)``.

    Raises:
        AssertionError: if ``enc_outputs`` or ``y`` has the wrong length/shape,
            if ``self.dec_layers`` does not hold ``self.n_layers`` layers, or if
            the final output has the wrong shape. These are ``assert``
            statements, so they are skipped under ``python -O``.
    """
    bs = enc_outputs[0].size(0)
    assert len(enc_outputs) == self.n_layers + 1
    assert y.size() == (bs, self.n_attr)
    dec_outputs = [enc_outputs[-1]]
    # (bs, n_attr) -> (bs, n_attr, 1, 1) so it can be expanded over dims 2 and 3.
    y = y.unsqueeze(2).unsqueeze(3)
    for i, layer in enumerate(self.dec_layers):
        prev = dec_outputs[-1]
        # Use both spatial sizes so non-square feature maps are supported;
        # for square maps this matches expanding to (size, size).
        height, width = prev.size(2), prev.size(3)
        # attributes, broadcast to the previous output's spatial size
        layer_inputs = [prev, y.expand(bs, self.n_attr, height, width)]
        # skip connection: the encoder output i steps before the last one
        if 0 < i <= self.n_skip:
            layer_inputs.append(enc_outputs[-1 - i])
        dec_outputs.append(layer(torch.cat(layer_inputs, 1)))
    # One output per decoder layer plus the starting tensor, i.e. this also
    # checks that len(self.dec_layers) == self.n_layers.
    assert len(dec_outputs) == self.n_layers + 1
    assert dec_outputs[-1].size() == (bs, self.img_fm, self.img_sz, self.img_sz)
    return dec_outputs