in src/controlnet_aux/zoe/zoedepth/models/depth_model.py [0:0]
def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor:
"""
Inference interface for the model with padding augmentation
Padding augmentation fixes the boundary artifacts in the output depth map.
Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image.
This augmentation pads the input image and crops the prediction back to the original size / view.
Note: This augmentation is not required for the models trained with 'avoid_boundary'=True.
Args:
x (torch.Tensor): input tensor of shape (b, c, h, w)
pad_input (bool, optional): whether to pad the input or not. Defaults to True.
fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3.
fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3.
upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'.
padding_mode (str, optional): padding mode. Defaults to "reflect".
Returns:
torch.Tensor: output tensor of shape (b, 1, h, w)
"""
# assert x is nchw and c = 3
assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())
assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1])
if pad_input:
assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0"
pad_h = int(np.sqrt(x.shape[2]/2) * fh)
pad_w = int(np.sqrt(x.shape[3]/2) * fw)
padding = [pad_w, pad_w]
if pad_h > 0:
padding += [pad_h, pad_h]
x = F.pad(x, padding, mode=padding_mode, **kwargs)
out = self._infer(x)
if out.shape[-2:] != x.shape[-2:]:
out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
if pad_input:
# crop to the original size, handling the case where pad_h and pad_w is 0
if pad_h > 0:
out = out[:, :, pad_h:-pad_h,:]
if pad_w > 0:
out = out[:, :, :, pad_w:-pad_w]
return out