in patchconvnet_models.py [0:0]
def forward_features(self, x):
    """Run both token stages and return the normalized class-token features.

    Steps, in order:

    1. ``self.patch_embed`` maps the input to a token sequence.
    2. ``self.cls_token`` is expanded to the batch size, giving ``cls_tokens``
       of shape ``(B, K, D)`` where ``K``/``D`` are the last two dims of
       ``self.cls_token``.
    3. Every module in ``self.blocks`` is applied, in order, to the patch
       tokens only.
    4. Every module in ``self.blocks_token_only`` is called as
       ``blk(patch_tokens, cls_tokens)``; its result replaces ``cls_tokens``
       (the patch-token variable is not reassigned in this stage).
    5. The class tokens are prepended to the patch tokens along dim 1 and
       ``self.norm`` is applied to the whole concatenated sequence.

    Args:
        x: Input batch; its first dimension is taken as the batch size ``B``.
            NOTE(review): presumably an image batch consumed by
            ``patch_embed`` — confirm against callers.

    Returns:
        If ``self.multiclass`` is false: token 0 of the normalized sequence
        (``tokens[:, 0]``). Otherwise: the first ``self.num_classes`` tokens,
        reshaped to ``(B, self.num_classes, -1)``.
    """
    batch_size = x.shape[0]
    # NOTE(review): torch.cat below requires patch_embed to return a 3-D tensor
    # whose last dim matches cls_token's — presumably (B, N, D); confirm.
    tokens = self.patch_embed(x)
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)

    # Stage 1: only the patch tokens are processed; class tokens are untouched.
    for blk in self.blocks:
        tokens = blk(tokens)

    # Stage 2: patch tokens are passed in as inputs, only cls_tokens is updated.
    for blk in self.blocks_token_only:
        cls_tokens = blk(tokens, cls_tokens)

    # Class tokens go first, so the leading positions select them below.
    tokens = torch.cat((cls_tokens, tokens), dim=1)
    tokens = self.norm(tokens)

    if not self.multiclass:
        return tokens[:, 0]
    # NOTE(review): presumably cls_token holds num_classes tokens in this mode,
    # so this slice covers exactly the class tokens — confirm.
    return tokens[:, :self.num_classes].reshape(batch_size, self.num_classes, -1)