in stylegan2_ada_pytorch/torch_utils/ops/bias_act.py [0:0]
def _bias_act_cuda(dim=1, act="linear", alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

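    # The two autograd.Function subclasses below close over (dim, act, alpha,
    # gain, clamp) and `spec`, so each distinct argument combination gets its
    # own specialized op, which is then memoized in `_bias_act_cuda_cache`.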
    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = (
                torch.channels_last
                if x.ndim > 2 and x.stride()[1] == 1
                else torch.contiguous_format
            )
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != "linear" or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(
                    x,
                    b,
                    _null_tensor,
                    _null_tensor,
                    _null_tensor,
                    0,
                    dim,
                    spec.cuda_idx,
                    alpha,
                    gain,
                    clamp,
                )
            ctx.save_for_backward(
                x if "x" in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if "x" in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if "y" in spec.ref else _null_tensor,
            )
            return y

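        # Gradient w.r.t. the input and bias. The fused CUDA gradient kernel
        # is wrapped in BiasActCudaGrad (defined below) so that the gradient
        # op can itself be differentiated when second-order grads are needed.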
        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != "linear" or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = (
                torch.channels_last
                if dy.ndim > 2 and dy.stride()[1] == 1
                else torch.contiguous_format
            )
            dx = _plugin.bias_act(
                dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp
            )
            ctx.save_for_backward(dy if spec.has_2nd_grad else _null_tensor, x, b, y)
            return dx

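        # Second-order gradients: differentiate the gradient op itself. The
        # kernel's grad mode 2 is only invoked for activation functions whose
        # spec declares has_2nd_grad.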
        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (
                ctx.needs_input_grad[1] or ctx.needs_input_grad[2]
            ):
                d_x = _plugin.bias_act(
                    d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp
                )

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda
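

# Usage sketch (illustrative; not part of the upstream file). In the full
# module, `bias_act(x, b, act=..., impl='cuda')` resolves to the class cached
# above and calls its `.apply()`. The helper name below is made up for
# illustration; it assumes the CUDA plugin has been compiled and a GPU is
# available.
def _example_bias_act_cuda_usage():
    x = torch.randn([8, 512, 4, 4], device="cuda")  # NCHW activations
    b = torch.randn([512], device="cuda")           # one bias per channel (dim=1)
    op = _bias_act_cuda(dim=1, act="lrelu")         # specialized, cached autograd.Function
    return op.apply(x, b)                           # fused bias + leaky ReLU + default gain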