def forward()

in src/wide_resnet.py [0:0]


    def forward(self, inputs, return_before_head=False):
        """Run the backbone and head over one or more input batches.

        ``inputs`` may be a single tensor or a list/tuple of tensors. Consecutive
        tensors whose last dimension has the same size are concatenated along
        dim 0 and passed through ``self._forward_backbone`` and
        ``self._forward_head`` together, so each run of equally-sized inputs
        costs a single forward pass. Per-run outputs are concatenated along
        dim 0, preserving input order.

        Args:
            inputs: a tensor, or a non-empty list/tuple of tensors.
            return_before_head: if True, also return the backbone output.

        Returns:
            The head output ``z``, or ``(h, z)`` when ``return_before_head``
            is True, where ``h`` is the backbone output.

        Raises:
            ValueError: if ``inputs`` is an empty list/tuple.
        """
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        if not inputs:
            # Without this the loop below never runs and h/z are unbound.
            raise ValueError("forward() requires at least one input tensor")

        # Group consecutive inputs by the size of their last dimension; the
        # cumulative sum of the run lengths gives each group's end index.
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in inputs]),
            return_counts=True,
        )[1], 0)

        hs, zs = [], []
        start_idx = 0
        for end_idx in idx_crops.tolist():
            _h = self._forward_backbone(torch.cat(inputs[start_idx:end_idx]))
            hs.append(_h)
            zs.append(self._forward_head(_h))
            start_idx = end_idx

        # Concatenate once rather than re-copying the accumulated tensor
        # for every group.
        h, z = torch.cat(hs), torch.cat(zs)

        if return_before_head:
            return h, z
        return z