def forward_train()

in timm/models/volo.py [0:0]


    def forward_train(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Tuple[int, int, int, int]]]:
        """Forward pass for training with mix token (token labeling) support.

        A separate forward fn for training with mix_token (if a train script
        supports it). Combining multiple modes in a single forward with
        different return types is torchscript hell.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Depends on ``self.aux_head`` and ``self.training``:

            * ``aux_head`` is None: the pooled output ``x_cls`` (mean over
              tokens for ``global_pool == 'avg'``, first token for
              ``'token'``, otherwise the un-pooled normed tokens).
            * ``aux_head`` set, eval mode: a single tensor,
              ``x_cls + 0.5 * max_over_tokens(x_aux)``.
            * ``aux_head`` set, training mode: tuple
              ``(x_cls, x_aux, (bbx1, bby1, bbx2, bby2))``. The bbox is the
              region swapped by mix token, or ``(0, 0, 0, 0)`` when
              ``self.mix_token`` is off.
        """
        x = self.patch_embed(x)
        x = x.permute(0, 2, 3, 1)  # B,C,H,W-> B,H,W,C

        # Mix token is only applied in training mode. Evaluate the flag once so
        # the forward swap and its later reversal are guarded identically.
        use_mix_token = bool(self.mix_token and self.training)

        # Defaults for the non-mixing path; keeps every name bound on all paths.
        bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
        patch_h, patch_w = 0, 0

        # mix token, see token labeling for details.
        if use_mix_token:
            lam = torch.distributions.Beta(self.beta, self.beta).sample()
            # Grid size after dividing by pooling_scale; the aux outputs are
            # reshaped onto this grid below when the mix is reversed.
            patch_h = x.shape[1] // self.pooling_scale
            patch_w = x.shape[2] // self.pooling_scale
            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
            # The bbox indexes the (patch_h, patch_w) grid; scale it up to
            # index the current, finer token grid.
            sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
            sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
            # Swap the region with the batch-reversed sample. Write into a
            # clone so the source tensor is not modified in place.
            mixed = x.clone()
            mixed[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
            x = mixed

        # step2: tokens learning in the two stages
        x = self.forward_tokens(x)

        # step3: post network, apply class attention or not
        if self.post_network is not None:
            x = self.forward_cls(x)
        x = self.norm(x)

        if self.global_pool == 'avg':
            x_cls = x.mean(dim=1)
        elif self.global_pool == 'token':
            x_cls = x[:, 0]
        else:
            x_cls = x

        if self.aux_head is None:
            return x_cls

        # generate classes in all feature tokens (everything after index 0), see token labeling
        x_aux = self.aux_head(x[:, 1:])
        if not self.training:
            # Eval: fold the per-token predictions into the main output.
            return x_cls + 0.5 * x_aux.max(dim=1).values

        # Training from here on, so use_mix_token reduces to self.mix_token.
        if use_mix_token:  # reverse "mix token", see token labeling for details.
            batch, num_out = x_aux.shape[0], x_aux.shape[-1]
            x_aux = x_aux.reshape(batch, patch_h, patch_w, num_out)
            unmixed = x_aux.clone()
            unmixed[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
            x_aux = unmixed.reshape(batch, patch_h * patch_w, num_out)

        # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
        return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)