def _upfirdn2d_cuda()

in stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.py


def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Fast CUDA implementation of `upfirdn2d()` using custom ops.
    """
    # Parse arguments.
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Lookup from cache.
    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
    if key in _upfirdn2d_cuda_cache:
        return _upfirdn2d_cuda_cache[key]

    # Forward op.
    class Upfirdn2dCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f):  # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4
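            # Treat a missing filter as the identity (1x1, all-ones) filter.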
            if f is None:
                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
            y = x
            if f.ndim == 2:
                # 2-D filter: one fused upsample -> FIR -> downsample pass.
                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
            else:
                # Separable 1-D filter: filter horizontally, then vertically,
                # splitting the total gain as sqrt(gain) per pass.
                y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
            ctx.save_for_backward(f)
            ctx.x_shape = x.shape
            return y

        @staticmethod
        def backward(ctx, dy):  # pylint: disable=arguments-differ
            f, = ctx.saved_tensors
            _, _, ih, iw = ctx.x_shape
            _, _, oh, ow = dy.shape
            fw, fh = _get_filter_size(f)
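            # Padding of the transposed op, chosen so that dx comes out at the
            # original input resolution (ih, iw).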
            p = [
                fw - padx0 - 1,
                iw * upx - ow * downx + padx0 - upx + 1,
                fh - pady0 - 1,
                ih * upy - oh * downy + pady0 - upy + 1,
            ]
            dx = None
            df = None

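            # The gradient w.r.t. x is itself an upfirdn2d with up/down swapped
            # and the filter flipped; the recursive call reuses the cache.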
            if ctx.needs_input_grad[0]:
                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)

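            # Gradients w.r.t. the filter are not supported by this op.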
            assert not ctx.needs_input_grad[1]
            return dx, df

    # Add to cache.
    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda
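
For reference, a minimal usage sketch. The tensor names, filter values, and shapes
below are illustrative rather than taken from the repository, and the example assumes
the compiled `upfirdn2d_plugin` has already been loaded (in the repository this is
handled by `_init()`; the public `upfirdn2d()` wrapper calls it and dispatches to this
path when `impl='cuda'`):

    import torch

    # Hypothetical example: 2x upsampling with a normalized separable 4-tap filter.
    x = torch.randn(1, 3, 64, 64, device='cuda')       # NCHW input
    f = torch.tensor([1., 3., 3., 1.], device='cuda')
    f = f / f.sum()                                     # 1-D filter -> separable path
    op = _upfirdn2d_cuda(up=2, down=1, padding=[2, 1, 2, 1], gain=4)
    y = op.apply(x, f)                                  # shape: (1, 3, 128, 128)

Because the returned class is cached per parameter tuple, repeated calls with the same
configuration (including the recursive call inside `backward()`) reuse the same
`Upfirdn2dCuda` class rather than defining a new one.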