in src/resnet.py [0:0]
def forward(self, inputs, return_before_head=False):
    """Run one or more input batches through the backbone and the head.

    Consecutive entries of ``inputs`` that share the same per-sample shape
    (every dimension except dim 0) are concatenated along dim 0. Each group
    is passed through ``self._forward_backbone`` in a single call, and the
    result goes through ``self._forward_head``. The per-group outputs are
    concatenated along dim 0, so output rows follow the order of the inputs.

    Args:
        inputs: a tensor, or a list of tensors that are batched along dim 0.
            A non-list value is wrapped in a one-element list.
        return_before_head: if True, also return the backbone output.

    Returns:
        ``z`` (the head output), or ``(h, z)`` when ``return_before_head``
        is True, where ``h`` is the backbone output.

    Raises:
        ValueError: if ``inputs`` is an empty list.
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    if not inputs:
        # Without this guard the loops below never run and h/z are unbound.
        raise ValueError("forward() expects at least one input tensor")

    # Group consecutive inputs by their full per-sample shape. Grouping on
    # the last dimension alone would put crops that share a width but differ
    # in height/channels into one group, and torch.cat would then fail.
    groups = []
    for inp in inputs:
        if groups and groups[-1][0].shape[1:] == inp.shape[1:]:
            groups[-1].append(inp)
        else:
            groups.append([inp])

    # Collect per-group outputs and concatenate once at the end, rather than
    # re-concatenating the growing result on every iteration.
    hs, zs = [], []
    for group in groups:
        _h = self._forward_backbone(torch.cat(group))
        hs.append(_h)
        zs.append(self._forward_head(_h))

    # A single group needs no concatenation (avoids an extra copy).
    h = hs[0] if len(hs) == 1 else torch.cat(hs)
    z = zs[0] if len(zs) == 1 else torch.cat(zs)
    if return_before_head:
        return h, z
    return z