in optimum/habana/transformers/gradient_checkpointing.py [0:0]
def backward(ctx, *args):
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "When use_reentrant=True, torch.utils.checkpoint is incompatible"
            " with .grad() or passing an `inputs` parameter to .backward()."
            " To resolve this error, you can either set use_reentrant=False,"
            " or call .backward() without passing the `inputs` argument."
        )
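    # Habana lazy mode: mark_step() launches the ops accumulated so far, creating
    # a graph boundary before the forward pass is recomputed below.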
    htcore.mark_step()
    # Copy the list to avoid modifying the original list.
    inputs = list(ctx.inputs)
    tensor_indices = ctx.tensor_indices
    tensors = ctx.saved_tensors

    # Fill in inputs with appropriate saved tensors.
    for i, idx in enumerate(tensor_indices):
        inputs[idx] = tensors[i]

    # Stash the surrounding rng state, and mimic the state that was
    # present at this time during forward. Restore the surrounding state
    # when we're done.
    rng_devices = []
    if ctx.preserve_rng_state and ctx.had_device_in_fwd:
        rng_devices = ctx.fwd_devices
    with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device):
        if ctx.preserve_rng_state:
            torch.set_rng_state(ctx.fwd_cpu_state)
            if ctx.had_device_in_fwd:
                set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
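        # Detach the filled-in inputs so the recomputed forward builds a fresh
        # autograd graph rooted at them; their .grad fields collect the
        # gradients returned at the end of this function.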
        detached_inputs = detach_variable(tuple(inputs))
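        # Re-enter the autocast contexts recorded during the original forward so
        # the recomputation runs with the same HPU/CPU mixed-precision settings.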
        with (
            torch.enable_grad(),
            torch.autocast(**ctx.hpu_autocast_kwargs),
            torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            outputs = ctx.run_function(*detached_inputs)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)

    # Run backward() only with the outputs that require grad.
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(outputs)):
        if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
            outputs_with_grad.append(outputs[i])
            args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
        raise RuntimeError("none of output has requires_grad=True, this checkpoint() is not necessary")
    torch.autograd.backward(outputs_with_grad, args_with_grad)
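    # Gather the gradients accumulated on the detached inputs; non-tensor
    # inputs get None in the corresponding position.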
    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)

    return (None, None) + grads
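
# Usage sketch (illustrative, not part of this module): the two leading `None`
# values returned above correspond to the two non-tensor forward arguments in
# the upstream torch.utils.checkpoint convention, i.e. a forward invoked as
# `CheckpointFunction.apply(run_function, preserve_rng_state, *inputs)`. The
# exact entry point exposed by optimum-habana may differ; `checkpoint_block`
# below is a hypothetical helper, not an existing API.
#
#     def checkpoint_block(block, *inputs):
#         # True -> preserve RNG state across the recomputation (assumption:
#         # same forward signature as torch's reentrant CheckpointFunction).
#         return CheckpointFunction.apply(block, True, *inputs)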