in glide_text2im/unet.py [0:0]
def forward(self, x, encoder_out=None):
b, c, *spatial = x.shape
qkv = self.qkv(self.norm(x).view(b, c, -1))
if encoder_out is not None:
encoder_out = self.encoder_kv(encoder_out)
h = self.attention(qkv, encoder_out)
else:
h = self.attention(qkv)
h = self.proj_out(h)
return x + h.reshape(b, c, *spatial)