def decode()

in src/model.py [0:0]


    def decode(self, enc_outputs, y):
        """Decode encoder features into an image, conditioned on attributes.

        Starting from the deepest encoder output, each decoder layer ``i``
        receives the channel-wise (dim 1) concatenation of:

        * the previous decoder output,
        * the attribute vector ``y`` broadcast over that output's spatial
          dimensions (dims 2 and 3), and
        * for ``0 < i <= self.n_skip``, the encoder output
          ``enc_outputs[-1 - i]`` (skip connection).

        Args:
            enc_outputs: sequence of ``self.n_layers + 1`` tensors. Dim 0 of
                element 0 gives the batch size; the last element is the
                decoder's starting point.
            y: tensor of shape ``(bs, self.n_attr)``. Presumably the
                attribute labels/values; its dtype must be concatenable with
                the feature maps (TODO confirm against callers).

        Returns:
            list of ``self.n_layers + 1`` tensors: ``enc_outputs[-1]``
            followed by each decoder layer's output. The last one has shape
            ``(bs, self.img_fm, self.img_sz, self.img_sz)``.

        Raises:
            AssertionError: if input or output shapes do not match the model
                configuration (these checks are skipped under ``python -O``).
        """
        bs = enc_outputs[0].size(0)
        assert len(enc_outputs) == self.n_layers + 1, (
            "expected %i encoder outputs, got %i"
            % (self.n_layers + 1, len(enc_outputs))
        )
        assert y.size() == (bs, self.n_attr), (
            "expected attributes of shape %s, got %s"
            % ((bs, self.n_attr), tuple(y.size()))
        )

        dec_outputs = [enc_outputs[-1]]
        # (bs, n_attr) -> (bs, n_attr, 1, 1) so it can be broadcast spatially.
        y = y.unsqueeze(2).unsqueeze(3)
        for i, layer in enumerate(self.dec_layers):
            prev = dec_outputs[-1]
            # Take height and width separately rather than assuming a square
            # map, so the attribute map always matches the tensor it is
            # concatenated with.
            height, width = prev.size(2), prev.size(3)
            layer_inputs = [prev, y.expand(bs, self.n_attr, height, width)]
            # Skip connection from the mirrored encoder layer.
            if 0 < i <= self.n_skip:
                layer_inputs.append(enc_outputs[-1 - i])
            dec_outputs.append(layer(torch.cat(layer_inputs, 1)))

        assert len(dec_outputs) == self.n_layers + 1
        assert dec_outputs[-1].size() == (bs, self.img_fm, self.img_sz, self.img_sz)
        return dec_outputs