in functorch/csrc/DynamicLayer.cpp [540:649]
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  auto cur_level = getDynamicLayerStack().back().layerId();
  auto cur_key = getDynamicLayerStack().back().key();
  optional<bool> prev_grad_mode = getDynamicLayerStack().back().prevGradMode();
  if (cur_key == DispatchKey::Autograd) {
    TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
  }
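  // Unwrap: if a tensor is a TensorWrapper created for the current (top)
  // layer, strip that wrapper so the layers below see the underlying value.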
  auto unwrap = [&](const Tensor& tensor) {
    if (!tensor.defined()) {
      return tensor;
    }
    auto* maybe_tensor_wrapper = maybeGetTensorWrapper(tensor);
    if (!maybe_tensor_wrapper) {
      return tensor;
    }
    auto tensor_wrapper_level = maybe_tensor_wrapper->level().value();
    TORCH_INTERNAL_ASSERT(tensor_wrapper_level <= cur_level);
    if (tensor_wrapper_level == cur_level) {
      return maybe_tensor_wrapper->value();
    }
    return tensor;
  };
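  // Wrap: tag outputs with a TensorWrapper for the current level so later
  // layers know which transform produced them. Level 1 is the bottom of the
  // dynamic layer stack, so wrapping is skipped there.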
  auto wrap = [&](const Tensor& tensor) {
    if (!tensor.defined()) {
      return tensor;
    }
    if (cur_level == 1) {
      return tensor;
    }
    // if (c10::show_dispatch_trace_enabled()) {
    //   std::cout << "wrap " << cur_level << std::endl;
    // }
    return makeTensorWrapper(tensor, cur_level);
  };
  // TODO: we only need to do the following (marked with !) for in-place functions
  // that modify sizes or strides. There aren't many of them.
  // If this is the Autograd dispatch key:
  // 1. (!) Put a copy of all of the args onto the stack
  // 2. Unwrap all the args in the copy set
  // 3. Call the operator
  // 4. Wrap the output
  // 5. (!) refreshMetadata for all the args in the original set
  // 6. (!) Pop those args off.
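  // Illustrative stack layout (e.g. an op with two tensor args a, b and one
  // return out):
  //   before:        [..., a, b]
  //   after step 1:  [..., a, b, a, b]
  //   after step 2:  [..., a, b, unwrap(a), unwrap(b)]
  //   after step 3:  [..., a, b, out]        (the op consumed the copies)
  //   after step 4:  [..., a, b, wrap(out)]
  //   after step 6:  [..., wrap(out)]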
  // Step 1 & 2
  if (cur_key == DispatchKey::Autograd) {
    auto args_size = op.schema().arguments().size();
    // Step 1
    auto front = stack->size() - args_size;
    for (const auto arg_idx : c10::irange(0, args_size)) {
      stack->push_back((*stack)[front + arg_idx]);
    }
    // Step 2
    foreachTensorInplace(*stack, stack->size() - args_size, stack->size(), unwrap);
  }
  // Pop the top layer; it gets pushed back in the dtor.
  WithoutTop guard;
  // "reset exclude set"
  // TODO: still a problem with composability and AutoNonVariableTypeGuard.
  // Users cannot use torch.no_grad, otherwise there will be problems.
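  // Save the current TLS dispatch key set (save_guard restores it on scope
  // exit), start from an empty set, and re-include only the DynamicLayer
  // front/back keys so the redispatch below still routes through the
  // remaining (lower) dynamic layers.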
  SaveLocalDispatchKeySet save_guard;
  auto keyset = c10::impl::PODLocalDispatchKeySet();
  c10::impl::_force_tls_local_dispatch_key_set(keyset);
  setDynamicLayerFrontBackKeysIncluded(true);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
  if (c10::show_dispatch_trace_enabled()) {
    dump_local_tls();
  }
#endif
  // Re-dispatch
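  // With the top layer popped and the TLS reset, the call goes to the next
  // dynamic layer down (or to the regular kernels if this was the last layer).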
  if (cur_key == DispatchKey::Autograd && *prev_grad_mode == false) {
    // See NOTE [grad and vjp interaction with no_grad]
    c10::AutoGradMode guard(*prev_grad_mode);
    op.callBoxed(stack);
  } else {
    op.callBoxed(stack);
  }
  // Step 4, 5, 6
  if (cur_key == DispatchKey::Autograd) {
    // Step 4
    auto ret_size = op.schema().returns().size();
    foreachTensorInplace(*stack, stack->size() - ret_size, stack->size(), wrap);
    // Step 5
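    // In-place ops may have changed sizes/strides of the underlying tensors;
    // refresh the original wrappers' metadata to match (see TODO above).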
    auto args_size = op.schema().arguments().size();
    auto args_front = stack->size() - args_size - ret_size;
    for (const auto arg_idx : c10::irange(0, args_size)) {
      auto& ivalue = (*stack)[args_front + arg_idx];
      if (!ivalue.isTensor()) {
        continue;
      }
      auto maybe_tensor_wrapper = maybeGetTensorWrapper(ivalue.toTensor());
      if (!maybe_tensor_wrapper) {
        continue;
      }
      maybe_tensor_wrapper->refreshMetadata();
    }
    // Step 6
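    // Erase the original args, which sit just below the returns, so only the
    // wrapped returns remain on the stack.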
    stack->erase(stack->end() - (args_size + ret_size), stack->end() - ret_size);
  }
}