in cm/karras_diffusion.py [0:0]
def average_image_patches(x):
    """Replace every non-overlapping square patch of ``x`` with its mean.

    ``x`` is reshaped to ``(-1, 3, image_size, image_size)``. Each channel is
    split into ``patch_size x patch_size`` tiles, every pixel of a tile is set
    to that tile's mean, and the tiles are reassembled into the original
    spatial layout.

    ``image_size`` and ``patch_size`` are not defined in this function; they
    are resolved from an enclosing (or module) scope. NOTE(review): the
    reshapes below only succeed when ``image_size`` is divisible by
    ``patch_size`` and ``x`` holds a multiple of ``3 * image_size**2``
    elements.

    Args:
        x: tensor whose elements can be laid out as
            ``(-1, 3, image_size, image_size)``.

    Returns:
        A new tensor of shape ``(-1, 3, image_size, image_size)``. ``x`` itself
        is never modified.
    """
    patches_per_side = image_size // patch_size
    # (N, 3, H, W) -> (N, 3, row_tile, row_in_tile, col_tile, col_in_tile)
    # -> (N, 3, row_tile, col_tile, row_in_tile, col_in_tile)
    # -> (N, 3, num_tiles, pixels_per_tile)
    x_flatten = (
        x.reshape(-1, 3, image_size, image_size)
        .reshape(
            -1,
            3,
            patches_per_side,
            patch_size,
            patches_per_side,
            patch_size,
        )
        .permute(0, 1, 2, 4, 3, 5)
        .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
    )
    # Compute the means out of place: reshape() can return a *view* of x
    # (e.g. when patch_size is 1 or equals image_size the permute leaves the
    # memory layout unchanged), so an in-place write here could silently
    # modify the caller's tensor. contiguous() materialises the broadcast so
    # the result does not alias a stride-0 buffer either.
    x_flatten = (
        x_flatten.mean(dim=-1, keepdim=True).expand_as(x_flatten).contiguous()
    )
    # Invert the tiling: undo the permute, then merge tiles back into rows/cols.
    return (
        x_flatten.reshape(
            -1,
            3,
            patches_per_side,
            patches_per_side,
            patch_size,
            patch_size,
        )
        .permute(0, 1, 2, 4, 3, 5)
        .reshape(-1, 3, image_size, image_size)
    )