in timm/models/coat.py [0:0]
def forward_features(self, x0):
    """Run the four serial stages and, when present, the parallel blocks.

    Each serial stage embeds its input with ``patch_embedN``, adds a class
    token via ``insert_cls``, runs ``serial_blocksN`` and then converts the
    result (after ``remove_cls``) into a ``(B, C, H, W)`` map that feeds the
    next stage's patch embedding.

    Args:
        x0: Input batch; only ``x0.shape[0]`` (the batch size) is read here
            before it is handed to ``self.patch_embed1``.

    Returns:
        One of three things, depending on configuration:

        * ``dict`` mapping the names in ``self.out_features`` (any of
          ``'x1_nocls'`` .. ``'x4_nocls'``) to ``(B, C, H, W)`` feature maps,
          when ``self.return_interm_layers`` is true and the module is not
          being scripted.
        * a single tensor ``self.norm4(x4)`` when ``self.parallel_blocks``
          is ``None`` (serial-only model).
        * the list ``[norm2(x2), norm3(x3), norm4(x4)]`` otherwise.
    """
    B = x0.shape[0]

    # Serial blocks 1.
    x1 = self.patch_embed1(x0)
    H1, W1 = self.patch_embed1.grid_size
    # NOTE(review): insert_cls presumably prepends the class token and
    # remove_cls strips it again -- both are defined elsewhere in the file.
    x1 = insert_cls(x1, self.cls_token1)
    for blk in self.serial_blocks1:
        x1 = blk(x1, size=(H1, W1))
    # (B, H1*W1, C) -> (B, C, H1, W1): input for the next patch embedding.
    x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()

    # Serial blocks 2.
    x2 = self.patch_embed2(x1_nocls)
    H2, W2 = self.patch_embed2.grid_size
    x2 = insert_cls(x2, self.cls_token2)
    for blk in self.serial_blocks2:
        x2 = blk(x2, size=(H2, W2))
    x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()

    # Serial blocks 3.
    x3 = self.patch_embed3(x2_nocls)
    H3, W3 = self.patch_embed3.grid_size
    x3 = insert_cls(x3, self.cls_token3)
    for blk in self.serial_blocks3:
        x3 = blk(x3, size=(H3, W3))
    x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()

    # Serial blocks 4.
    x4 = self.patch_embed4(x3_nocls)
    H4, W4 = self.patch_embed4.grid_size
    x4 = insert_cls(x4, self.cls_token4)
    for blk in self.serial_blocks4:
        x4 = blk(x4, size=(H4, W4))
    # Unlike stages 1-3, no later stage consumes a stage-4 feature map, so it
    # is built lazily below, only in the branches that actually return it.
    # This avoids a reshape/permute/contiguous copy on every other path.

    # Only serial blocks: Early return.
    if self.parallel_blocks is None:
        # Keep the `torch.jit.is_scripting()` guard exactly as written: the
        # dict-returning branch must not be compiled when scripting.
        if not torch.jit.is_scripting() and self.return_interm_layers:
            # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
            feat_out = {}
            if 'x1_nocls' in self.out_features:
                feat_out['x1_nocls'] = x1_nocls
            if 'x2_nocls' in self.out_features:
                feat_out['x2_nocls'] = x2_nocls
            if 'x3_nocls' in self.out_features:
                feat_out['x3_nocls'] = x3_nocls
            if 'x4_nocls' in self.out_features:
                x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x4_nocls'] = x4_nocls
            return feat_out
        else:
            # Return features for classification.
            x4 = self.norm4(x4)
            return x4

    # Parallel blocks.
    for blk in self.parallel_blocks:
        # cpe2..cpe4 are applied to streams 2-4 before every parallel block;
        # stream 1 is passed through to the block unchanged.
        x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
        x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])

    if not torch.jit.is_scripting() and self.return_interm_layers:
        # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
        # The maps are rebuilt here because the parallel blocks updated x1..x4.
        feat_out = {}
        if 'x1_nocls' in self.out_features:
            x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
            feat_out['x1_nocls'] = x1_nocls
        if 'x2_nocls' in self.out_features:
            x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
            feat_out['x2_nocls'] = x2_nocls
        if 'x3_nocls' in self.out_features:
            x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
            feat_out['x3_nocls'] = x3_nocls
        if 'x4_nocls' in self.out_features:
            x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
            feat_out['x4_nocls'] = x4_nocls
        return feat_out
    else:
        # Stream 1 is not normalized or returned on this path.
        x2 = self.norm2(x2)
        x3 = self.norm3(x3)
        x4 = self.norm4(x4)
        return [x2, x3, x4]