in timm/models/volo.py [0:0]
def forward_train(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Tuple[int, int, int, int]]]:
    """Forward pass for training with mix token support.

    A separate forward fn for training with mix_token (if a train script supports).
    Combining multiple modes in a single forward with different return types is torchscript hell.

    Args:
        x: Input tensor of shape (B, C, H, W).

    Returns:
        One of three forms, depending on module state:

        * ``aux_head`` is None: the class output tensor only.
        * ``aux_head`` is set and the module is not training: a single tensor,
          the class output plus 0.5 times the max over dim 1 of the aux output.
        * ``aux_head`` is set and the module is training: tuple of
          (class_token, aux_tokens, bbox). ``bbox`` is ``(0, 0, 0, 0)`` when
          mix token was not applied.
    """
    x = self.patch_embed(x)
    x = x.permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C

    # Mix token is only active while training; evaluated once and reused below for the
    # reverse step so the two halves of the operation cannot disagree.
    use_mix_token = self.mix_token and self.training

    # step1: mix token, see token labeling for details.
    if use_mix_token:
        lam = torch.distributions.Beta(self.beta, self.beta).sample()
        patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
        # presumably rand_bbox returns box corners in pooled-grid units: they are multiplied by
        # pooling_scale for the token map here and used unscaled on the (patch_h, patch_w)
        # aux grid below -- verify against rand_bbox.
        bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
        sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
        sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
        # Replace the box region of every sample with the same region taken from the
        # batch-reversed tensor; work on a clone so `x` is not modified in place.
        temp_x = x.clone()
        temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
        x = temp_x
    else:
        bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0

    # step2: tokens learning in the two stages
    x = self.forward_tokens(x)

    # step3: post network, apply class attention or not
    if self.post_network is not None:
        x = self.forward_cls(x)
    x = self.norm(x)

    # NOTE(review): x_cls is taken directly from the normalized tokens; no classifier head is
    # applied in this function. The eval path below adds it to the aux_head output, which needs
    # matching last dimensions -- confirm this is what callers expect.
    if self.global_pool == 'avg':
        x_cls = x.mean(dim=1)
    elif self.global_pool == 'token':
        x_cls = x[:, 0]
    else:
        x_cls = x

    if self.aux_head is None:
        return x_cls

    # generate classes in all feature tokens (index 0 along dim 1 is skipped), see token labeling
    x_aux = self.aux_head(x[:, 1:])
    if not self.training:
        return x_cls + 0.5 * x_aux.max(1)[0]

    # Training is guaranteed from here on, so `use_mix_token` is equivalent to re-checking
    # `self.mix_token and self.training`.
    if use_mix_token:  # reverse "mix token", see token labeling for details.
        x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
        temp_x = x_aux.clone()
        temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
        x_aux = temp_x
        x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])

    # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
    return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)