in muse/modeling_transformer.py [0:0]
def forward(self, x, x_skip, **kwargs):
    for res_block in self.res_blocks:
        # pop the most recent skip (residual) hidden states and fuse them along the channel dim
        res_hidden_states = x_skip[-1]
        x_skip = x_skip[:-1]
        x = torch.cat([x, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:
            # recompute the res block's activations during backward to save memory
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
        else:
            x = res_block(x)

    if self.add_upsample:
        if x.shape[0] >= 64:
            # force a contiguous layout; nearest-neighbor interpolation can misbehave on large batches otherwise
            x = x.contiguous()
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.upsample_conv(x)

    return x
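
For context, here is a minimal self-contained sketch of the same pattern: the skip list is consumed last-entry-first, each residual block sees the channel-wise concatenation of x and one skip tensor, and the block is optionally gradient-checkpointed before a nearest-neighbor 2x upsample. ToyResBlock, ToyUpBlock, and the channel sizes are illustrative stand-ins, not the actual muse modules.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyResBlock(nn.Module):
    # Hypothetical stand-in for the real residual block.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        return F.relu(self.conv(x))

class ToyUpBlock(nn.Module):
    # Mirrors the structure of the forward above: concat skip, run res block, upsample.
    def __init__(self, ch, num_blocks=2, add_upsample=True):
        super().__init__()
        # Each res block sees x concatenated with one skip tensor, hence 2 * ch input channels.
        self.res_blocks = nn.ModuleList([ToyResBlock(2 * ch, ch) for _ in range(num_blocks)])
        self.add_upsample = add_upsample
        self.gradient_checkpointing = False
        if add_upsample:
            self.upsample_conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x, x_skip):
        for res_block in self.res_blocks:
            # consume the skip connections in reverse order
            res_hidden_states = x_skip[-1]
            x_skip = x_skip[:-1]
            x = torch.cat([x, res_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(res_block, x)
            else:
                x = res_block(x)
        if self.add_upsample:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
            x = self.upsample_conv(x)
        return x

# Usage: two skip tensors, consumed in reverse order; spatial size doubles at the end.
block = ToyUpBlock(ch=8)
x = torch.randn(1, 8, 16, 16)
skips = tuple(torch.randn(1, 8, 16, 16) for _ in range(2))
out = block(x, skips)
print(out.shape)  # torch.Size([1, 8, 32, 32])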