in src/resnet50.py [0:0]
def forward(self, inputs):
    """Run the backbone once per run of same-sized inputs, then the head.

    Consecutive entries of ``inputs`` whose last dimension has the same size
    are concatenated along dim 0 and passed through ``self.forward_backbone``
    in a single call. The per-group outputs are concatenated along dim 0, in
    input order, and the result is passed to ``self.forward_head``.

    Args:
        inputs: A tensor, or a list of tensors. A non-list value is wrapped
            in a one-element list. Presumably these are multi-resolution
            image crops ordered by resolution -- confirm against the caller.

    Returns:
        Whatever ``self.forward_head`` returns for the concatenated backbone
        outputs.

    Raises:
        ValueError: If ``inputs`` is an empty list.
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    if not inputs:
        # Without this guard the loop below never runs and the method would
        # fail later with an opaque UnboundLocalError.
        raise ValueError("forward() expects at least one input tensor")

    # Lengths of the runs of equal last-dimension size; their cumulative sum
    # gives the exclusive end index of each group within ``inputs``.
    last_dims = torch.tensor([inp.shape[-1] for inp in inputs])
    _, run_lengths = torch.unique_consecutive(last_dims, return_counts=True)
    group_ends = torch.cumsum(run_lengths, 0).tolist()

    # Only move batches to the GPU when one exists, so the module also works
    # on CPU-only hosts; behaviour is unchanged when CUDA is available.
    use_cuda = torch.cuda.is_available()

    features = []
    start_idx = 0
    for end_idx in group_ends:
        batch = torch.cat(inputs[start_idx:end_idx])
        if use_cuda:
            batch = batch.cuda(non_blocking=True)
        features.append(self.forward_backbone(batch))
        start_idx = end_idx

    # Concatenate once at the end instead of re-copying the accumulated
    # output on every iteration; a single group is passed through as-is.
    output = features[0] if len(features) == 1 else torch.cat(features)
    return self.forward_head(output)