in timm/models/visformer.py [0:0]
def forward_features(self, x):
    """Run the optional stem, the three embed/stage groups and the final norm.

    Every stage goes through ``checkpoint_seq`` when
    ``self.grad_checkpointing`` is truthy and the module is not being
    scripted; otherwise the stage module is invoked directly.

    Args:
        x: Input passed to the stem (when present) or straight to
            ``patch_embed1``.

    Returns:
        The output of ``self.norm`` applied to the stage-3 result.
    """
    if self.stem is not None:
        x = self.stem(x)

    # Stage 1: the patch embedding always runs; the positional term is optional.
    x = self.patch_embed1(x)
    if self.pos_embed1 is not None:
        x = x + self.pos_embed1
        x = self.pos_drop(x)
    # The condition is deliberately kept inline (not hoisted into a local) so
    # it is evaluated in exactly the same form at every stage.
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.stage1, x)
    else:
        x = self.stage1(x)

    # Stage 2: the patch embedding may be absent.
    # NOTE(review): the positional embedding is treated as nested under the
    # patch-embedding guard -- confirm against the upstream file's indentation.
    if self.patch_embed2 is not None:
        x = self.patch_embed2(x)
        if self.pos_embed2 is not None:
            x = x + self.pos_embed2
            x = self.pos_drop(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.stage2, x)
    else:
        x = self.stage2(x)

    # Stage 3: same shape as stage 2.
    if self.patch_embed3 is not None:
        x = self.patch_embed3(x)
        if self.pos_embed3 is not None:
            x = x + self.pos_embed3
            x = self.pos_drop(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.stage3, x)
    else:
        x = self.stage3(x)

    return self.norm(x)