def forward()

in models/gpvit/gpvit_adapter.py [0:0]


    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, patch_resolution,
                mask_indices=None, cluster_size_ratio=None, *, count_latency=False):
        """Run one injector -> blocks -> extractor interaction step.

        ``x`` is updated by ``self.injector`` (using ``c`` as ``feat``),
        rearranged by ``self.chunk_feat``, passed through every callable in
        ``blocks``, restored by ``self.recover_feat``, and then used as the
        ``feat`` input of ``self.extractor`` followed by each of
        ``self.extra_extractors`` (if not None), which successively update ``c``.

        Args:
            x: Input passed to the injector as ``query``. Must support
                ``.contiguous()`` and ``.size(0)`` (presumably a torch tensor
                -- TODO confirm layout against callers).
            c: Passed to the injector as ``feat`` and updated by the
                extractor(s) as ``query``.
            blocks: Iterable of callables ``blk(x, resolution) -> x`` applied
                in order to the chunked features.
            deform_inputs1: Indexable whose entries 0, 1, 2 are the injector's
                ``reference_points``, ``spatial_shapes`` and ``level_start_index``.
            deform_inputs2: Same layout, used for the extractor(s).
            patch_resolution: Pair ``(H, W)``. Given to ``chunk_feat``; its
                unpacked values are also forwarded to the extractor(s) as
                ``H`` and ``W``.
            mask_indices: Passed through to ``chunk_feat`` and ``recover_feat``.
            cluster_size_ratio: Passed through to ``chunk_feat``.
            count_latency: If True, print the elapsed time in milliseconds of
                each of the five stages (injector, chunk, blocks, recover,
                extractors), each followed by a tab. CUDA is synchronized
                before every timestamp when it is available.

        Returns:
            Tuple ``(x, c)``: the output of ``recover_feat`` and the ``c``
            returned by the last extractor.
        """
        H, W = patch_resolution
        x = x.contiguous()

        timestamps = []

        def mark():
            # Record a stage boundary. Synchronizing first makes queued CUDA
            # kernels count toward the stage that launched them. The guard
            # keeps timing usable on CPU-only machines.
            if count_latency:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                timestamps.append(time.perf_counter())

        mark()
        x = self.injector(query=x,
                          reference_points=deform_inputs1[0],
                          feat=c,
                          spatial_shapes=deform_inputs1[1],
                          level_start_index=deform_inputs1[2])
        mark()

        # Keep the post-injector features; recover_feat receives them
        # alongside the processed chunks.
        x0 = x
        x, chunk_resolution = self.chunk_feat(
            x, mask_indices, patch_resolution, cluster_size_ratio)
        mark()

        # Skip the blocks when chunking leaves nothing along dim 0.
        # NOTE(review): presumably an empty selection via mask_indices -- confirm.
        if x.size(0) > 0:
            for blk in blocks:
                x = blk(x, chunk_resolution)
        mark()

        x = self.recover_feat(x, mask_indices, x0)
        mark()

        # The extractors use the pre-chunking H, W unpacked above, not the
        # resolution returned by chunk_feat.
        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(query=c,
                          reference_points=deform_inputs2[0],
                          feat=x,
                          spatial_shapes=deform_inputs2[1],
                          level_start_index=deform_inputs2[2],
                          H=H, W=W)
        mark()

        if count_latency:
            print(''.join(f'{(end - start) * 1000:.2f}ms\t'
                          for start, end in zip(timestamps, timestamps[1:])))

        return x, c