in models/base.py [0:0]
def slide_inference(self, img, img_meta, rescale):
    """Run inference by sliding an overlapping window over the image.

    The image is tiled into ``test_cfg.crop_size`` windows whose start
    positions are ``test_cfg.stride`` apart. A window that would run past
    the bottom/right edge is shifted back so it ends on the edge. Each
    window is zero-padded up to the crop size, which only has an effect
    when the image is smaller than the crop. The padded window is passed
    once through ``self.encode_decode``, and the logits of all windows
    covering a pixel are averaged.

    Args:
        img (Tensor): Input batch, indexed as ``(batch, channels, H, W)``.
        img_meta: Indexable per-image meta. Only
            ``img_meta[0]['ori_shape']`` is read, and only when ``rescale``
            is true.
        rescale (bool): If true, resize the averaged logits to
            ``img_meta[0]['ori_shape'][:2]`` with bilinear interpolation
            via the module-level ``resize`` helper.

    Returns:
        Tensor: Averaged logits of shape ``(batch, num_classes, H, W)``, or
        with spatial size ``ori_shape[:2]`` when ``rescale`` is true.
    """
    h_stride, w_stride = self.test_cfg.stride
    h_crop, w_crop = self.test_cfg.crop_size
    batch_size, _, h_img, w_img = img.size()
    num_classes = self.num_classes
    # ceil((size - crop) / stride) + 1 windows per axis, or a single window
    # when the image is not larger than the crop.
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    # Number of windows that covered each pixel, used to average overlaps.
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            # Pull the window back from the edge so it stays crop-sized
            # whenever the image is large enough.
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            pad_img = crop_img.new_zeros(
                (crop_img.size(0), crop_img.size(1), h_crop, w_crop))
            pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img
            # Run the model exactly once per window. encode_decode may
            # return either the logits tensor or a tuple/list whose first
            # element is the logits. len() cannot distinguish these,
            # because a tensor's length is its batch size.
            seg_out = self.encode_decode(pad_img)
            if isinstance(seg_out, (tuple, list)):
                pad_seg_logit = seg_out[0]
            else:
                pad_seg_logit = seg_out
            preds[:, :, y1:y2,
                  x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1]
            count_mat[:, :, y1:y2, x1:x2] += 1
    # Invariant of the grid computation above: every pixel is covered.
    assert (count_mat == 0).sum() == 0
    preds = preds / count_mat
    if rescale:
        preds = resize(
            preds,
            size=img_meta[0]['ori_shape'][:2],
            mode='bilinear',
            align_corners=self.align_corners,
            warning=False)
    return preds