in models/modality_projector.py [0:0]
def pixel_shuffle(self, x):
    """Fold square spatial neighbourhoods of a token grid into the channel axis.

    The ``seq`` axis of ``x`` is interpreted as a row-major square grid of
    side ``sqrt(seq)``. Every ``scale_factor x scale_factor`` patch of that
    grid is merged into a single token whose features are the patch's
    embeddings concatenated in (row, column, channel) order.

    Args:
        x: 3-D tensor of shape ``(bsz, seq, embed_dim)``.

    Returns:
        Tensor of shape
        ``(bsz, seq // scale_factor**2, embed_dim * scale_factor**2)``.

    Raises:
        AssertionError: if ``seq`` is not a perfect square, or if its square
            root is not divisible by ``self.scale_factor``.
    """
    import math  # local import: keeps this block independent of file-level imports

    bsz, seq, embed_dim = x.size()
    scale = self.scale_factor

    # Exact integer square root; avoids truncating a float approximation.
    side = math.isqrt(seq)
    assert side * side == seq, (
        f"Sequence length must be a perfect square for pixel shuffle, got {seq}"
    )
    assert side % scale == 0, (
        f"Sequence root must be divisible by scale factor, "
        f"got root {side} and scale factor {scale}"
    )

    h_out = w_out = side // scale

    # (bsz, seq, dim) -> (bsz, h_out, scale, w_out, scale, dim):
    # grid row = h_out index * scale + first scale index, likewise for columns.
    x = x.reshape(bsz, h_out, scale, w_out, scale, embed_dim)
    # Bring the two intra-patch axes next to the channel axis so they can be
    # flattened into it.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.reshape(bsz, h_out * w_out, embed_dim * scale**2)