def _bias_act_cuda()

in stylegan2_ada_pytorch/torch_utils/ops/bias_act.py


def _bias_act_cuda(dim=1, act="linear", alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
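    # A negative clamp value tells the CUDA kernel that clamping is disabled.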
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b):  # pylint: disable=arguments-differ
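            # Match the input's memory layout: use channels-last when x is already
            # strided that way, so the kernel works on contiguous memory either way.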
            ctx.memory_format = (
                torch.channels_last
                if x.ndim > 2 and x.stride()[1] == 1
                else torch.contiguous_format
            )
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
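            # Launch the fused CUDA kernel only when there is actual work to do;
            # otherwise y stays an alias of x. grad=0 selects the forward mode, and
            # the _null_tensor placeholders stand in for the reference and
            # incoming-gradient tensors used only by the gradient modes.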
            if act != "linear" or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(
                    x,
                    b,
                    _null_tensor,
                    _null_tensor,
                    _null_tensor,
                    0,
                    dim,
                    spec.cuda_idx,
                    alpha,
                    gain,
                    clamp,
                )
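            # Save only what the backward pass will actually reference: x/b when the
            # gradient formula needs x (or a second-order gradient exists), y when the
            # gradient formula is expressed in terms of y. Unused slots hold an empty
            # placeholder so no tensor is kept alive unnecessarily.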
            ctx.save_for_backward(
                x if "x" in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if "x" in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if "y" in spec.ref else _null_tensor,
            )
            return y

        @staticmethod
        def backward(ctx, dy):  # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
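                # For a plain linear op with unit gain and no clamp the gradient is
                # just dy; anything else goes through the dedicated gradient op below.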
                dx = dy
                if act != "linear" or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
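                # Bias gradient: reduce dx over every dimension except the bias dimension.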
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y):  # pylint: disable=arguments-differ
            ctx.memory_format = (
                torch.channels_last
                if dy.ndim > 2 and dy.stride()[1] == 1
                else torch.contiguous_format
            )
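            # grad=1 selects the first-order gradient mode: the kernel evaluates dy
            # times the activation's derivative (with gain and clamp folded in).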
            dx = _plugin.bias_act(
                dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp
            )
            ctx.save_for_backward(dy if spec.has_2nd_grad else _null_tensor, x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx):  # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
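                # The first-order gradient is linear in dy, so its derivative with
                # respect to dy is obtained by applying the same gradient op to d_dx.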
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (
                ctx.needs_input_grad[1] or ctx.needs_input_grad[2]
            ):
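                # grad=2 selects the second-order mode, which combines d_dx with the
                # saved dy and the activation's second derivative.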
                d_x = _plugin.bias_act(
                    d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp
                )

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda
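
Usage sketch (not part of the file): the snippet below assumes a CUDA device and a
successfully compiled `bias_act` plugin. In the repository this class is normally
selected and applied by the public `bias_act()` wrapper in the same module rather
than used directly; the tensor shapes here are illustrative only.

import torch

x = torch.randn(8, 512, 4, 4, device="cuda", requires_grad=True)  # NCHW activations
b = torch.randn(512, device="cuda", requires_grad=True)           # per-channel bias

fn = _bias_act_cuda(dim=1, act="lrelu")  # cached torch.autograd.Function subclass
y = fn.apply(x, b)                       # fused bias + leaky ReLU (+ default gain)
y.sum().backward()                       # gradients flow through the custom backward op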