in models/gpvit/gpvit_adapter.py [0:0]
def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, patch_resolution,
            mask_indices=None, cluster_size_ratio=None):
    """Run one injector -> blocks -> extractor interaction step.

    The steps, in order:

    1. ``self.injector`` produces a new ``x`` from ``x`` (query) and ``c``
       (feat).
    2. ``self.chunk_feat`` is applied to ``x``; it returns new features and
       a new patch resolution. The pre-chunk features are kept as ``x0``.
    3. Every entry of ``blocks`` is called as ``blk(x, patch_resolution)``
       with the resolution returned by ``chunk_feat``. This step is skipped
       when ``x.size(0)`` is 0.
    4. ``self.recover_feat`` is called with the block output,
       ``mask_indices`` and ``x0``.
    5. ``self.extractor`` and then every entry of ``self.extra_extractors``
       (when it is not ``None``) produce a new ``c`` from ``c`` (query) and
       the recovered ``x`` (feat).

    Args:
        x: Features used as the injector query. Must provide
            ``contiguous()`` and ``size()``.
        c: Features used as the injector ``feat`` and the extractor query.
        blocks: Iterable of callables invoked as ``blk(x, patch_resolution)``.
        deform_inputs1: Indexable whose items 0, 1 and 2 are passed to the
            injector as ``reference_points``, ``spatial_shapes`` and
            ``level_start_index``.
        deform_inputs2: Same layout as ``deform_inputs1``, passed to the
            extractor(s).
        patch_resolution: ``(H, W)`` pair. ``H`` and ``W`` are captured
            before chunking and are the values handed to the extractor(s).
        mask_indices: Forwarded unchanged to ``self.chunk_feat`` and
            ``self.recover_feat``.
        cluster_size_ratio: Forwarded unchanged to ``self.chunk_feat``.

    Returns:
        Tuple ``(x, c)``: the output of ``self.recover_feat`` and the output
        of the last extractor that ran.
    """
    # Debug switch: set to True to print the latency of each stage
    # (injector, chunk, blocks, recover, extractors) on every call.
    COUNT_LATENCY = False
    timestamps = []

    def _mark():
        """Record a stage boundary; does nothing unless profiling is on."""
        if not COUNT_LATENCY:
            return
        if torch.cuda.is_available():
            # Wait for queued GPU work so the timestamp covers finished
            # computation; skipped on CPU-only machines where the call fails.
            torch.cuda.synchronize()
        # perf_counter is the clock intended for measuring short intervals.
        timestamps.append(time.perf_counter())

    # Keep the original resolution: the extractors below receive this H/W,
    # not the resolution returned by chunk_feat.
    H, W = patch_resolution
    x = x.contiguous()

    _mark()
    x = self.injector(query=x,
                      reference_points=deform_inputs1[0],
                      feat=c,
                      spatial_shapes=deform_inputs1[1],
                      level_start_index=deform_inputs1[2])
    _mark()

    x0 = x  # pre-chunk features, needed by recover_feat
    x, patch_resolution = self.chunk_feat(
        x, mask_indices, patch_resolution, cluster_size_ratio)
    _mark()

    # Skip the blocks entirely when chunking produced an empty leading dim.
    if x.size(0) > 0:
        for blk in blocks:
            x = blk(x, patch_resolution)
    _mark()

    x = self.recover_feat(x, mask_indices, x0)
    _mark()

    # All extractors share the same deformable-attention inputs.
    extractor_kwargs = dict(reference_points=deform_inputs2[0],
                            spatial_shapes=deform_inputs2[1],
                            level_start_index=deform_inputs2[2],
                            H=H, W=W)
    c = self.extractor(query=c, feat=x, **extractor_kwargs)
    if self.extra_extractors is not None:
        for extractor in self.extra_extractors:
            c = extractor(query=c, feat=x, **extractor_kwargs)
    _mark()

    if COUNT_LATENCY:
        # One tab-terminated "<ms>ms" field per pair of adjacent marks.
        print(''.join(
            f'{(end - start) * 1000:.2f}ms\t'
            for start, end in zip(timestamps, timestamps[1:])))
    return x, c