in cm/karras_diffusion.py [0:0]
def replacement(x0, x1):
    """Mix two image batches coefficient-wise in a per-patch basis ``Q``.

    Each input is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches per channel. Each flattened patch is projected onto the columns
    of ``Q``, giving one coefficient vector per patch. The mixed vector takes
    coefficient 0 from ``x0`` and coefficients 1.. from ``x1``. It is then
    mapped back to pixel space and the patches are reassembled into images.

    ``image_size``, ``patch_size`` and ``Q`` are free names resolved from the
    enclosing scope (not visible here). NOTE(review): ``Q`` is presumably an
    orthogonal (patch_size**2, patch_size**2) matrix, so that the back
    projection inverts the forward one; confirm at its definition.

    Args:
        x0: tensor reshapeable to (-1, 3, image_size, image_size). It
            supplies coefficient 0 of every patch.
        x1: tensor reshapeable the same way. It supplies every remaining
            coefficient.

    Returns:
        Tensor of shape (-1, 3, image_size, image_size).
    """
    n_side = image_size // patch_size

    def to_coefficients(images):
        # (..) -> (B, 3, H, W) -> (B, 3, n_side, p, n_side, p)
        grid = images.reshape(-1, 3, image_size, image_size).reshape(
            -1, 3, n_side, patch_size, n_side, patch_size
        )
        # Bring the two intra-patch axes together, then flatten each patch:
        # (B, 3, n_side, n_side, p, p) -> (B, 3, num_patches, p*p)
        patches = grid.permute(0, 1, 2, 4, 3, 5).reshape(
            -1, 3, image_size**2 // patch_size**2, patch_size**2
        )
        # Coefficients w.r.t. Q's columns: out[..., e] = sum_d patch[d] * Q[d, e]
        return th.einsum("bcnd,de->bcne", patches, Q)

    coeffs_keep = to_coefficients(x0)
    coeffs_fill = to_coefficients(x1)

    # Coefficient 0 comes from x0; all others come from x1.
    mixed = coeffs_keep.new_zeros(coeffs_keep.shape)
    mixed[..., 0] = coeffs_keep[..., 0]
    mixed[..., 1:] = coeffs_fill[..., 1:]

    # Back to pixel space: contracting over e with Q[d, e] multiplies by Q^T.
    pixels = th.einsum("bcne,de->bcnd", mixed, Q)
    # Undo the patch flattening: (B, 3, n_side, n_side, p, p) -> (B, 3, H, W)
    tiles = pixels.reshape(-1, 3, n_side, n_side, patch_size, patch_size)
    return tiles.permute(0, 1, 2, 4, 3, 5).reshape(-1, 3, image_size, image_size)