in torch/csrc/jit/runtime/interpreter.cpp [234:723]
// Executes instructions from the frame on top of `frames`, using `stack` as
// the operand stack, until one of the exits below is reached.
//
// Returns:
//   - true  when a WAIT instruction finds its future not yet completed; the
//           live part of `stack` is moved into a callback registered on that
//           future, which later hands an InterpreterContinuation to
//           taskLauncher_.
//   - false when RET executes in the outermost frame, or when a
//           std::exception is caught and either stored on future_ or passed
//           to handleError.
// Throws: re-throws the caught exception when
//   FLAGS_torch_jit_enable_rethrow_caught_exception is set and there is no
//   future_ to receive it.
//
// INST, INST_GUARD, INST_FETCH, INST_NEXT and INST_DISPATCH are macros defined
// outside this function; comments below describe only the statements that are
// visible here.
bool runImpl(Stack& stack) {
// if we have never run before, then we might have to return the
// stack when we suspend, record where it starts so we return the right
// stack
if (stack_start_ == -1) {
// The caller's stack must hold at least this function's inputs; everything
// below them belongs to the caller.
TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
stack_start_ = stack.size() - frames.back().function->n_inputs;
} else {
// during restarts, all of the stack is always our own, so we leave
// nothing
stack_start_ = 0;
}
// Guard object that lives for the rest of this call.
// NOTE(review): presumably publishes `this` as the thread-local current
// interpreter - the guard type is defined elsewhere, confirm there.
TLSCurrentInterpreterGuard g(this);
// Only start the record function when the top frame is at its first
// instruction and the whole stack is ours.
if (frames.back().pc == 0 && stack_start_ == 0) {
checkAndStartRecordFunction(frames.back(), stack);
}
// When computed goto is enabled, build a table holding the address of one
// `label_<op>` per opcode enumerated by FORALL_OPCODES.
#if defined(JIT_USE_COMPUTED_GOTO)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
static void* dispatch_table[] = {
#define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
#undef DISPATCH_TABLE_ENTRY
};
#endif
try {
while (true) {
// Re-read the top frame on every iteration: cases that call callFunction or
// leaveFrame `continue` back to here after those calls.
Frame& frame = frames.back();
Instruction inst = INST_FETCH(0);
switch (inst.op) {
// ENTER: remember the object returned by peek(stack, 0, 1) in
// entered_objects; the stack itself is not modified here.
case INST(ENTER): {
INST_GUARD;
const auto& obj = peek(stack, 0, 1);
TORCH_INTERNAL_ASSERT(obj.isObject());
entered_objects.push_back(obj);
}
INST_NEXT;
// EXIT: take the most recently entered object, push it followed by three
// default-constructed IValues, and call its `__exit__` method.
case INST(EXIT): {
INST_GUARD;
auto obj = entered_objects.back().toObject();
// Look the method up before `obj` is moved onto the stack.
auto& f = obj->type()->getMethod("__exit__");
push(stack, std::move(obj));
entered_objects.pop_back();
push(stack, IValue());
push(stack, IValue());
push(stack, IValue());
callFunction(f, stack);
continue;
}
// OP: run entry inst.X of the function's operator table on the stack.
// In builds without NDEBUG the stack size before and after is handed to
// assert_stack_size.
case INST(OP): {
INST_GUARD;
#ifndef NDEBUG
size_t init_size = stack.size();
#endif
frame.function->operator_table_[inst.X](stack);
#ifndef NDEBUG
frame.function->assert_stack_size(inst.X, init_size, stack.size());
#endif
}
INST_NEXT;
// OPN: same as OP, but inst.N is pushed onto the stack first.
// Note that init_size is sampled after that push.
case INST(OPN): {
INST_GUARD;
stack.push_back(inst.N);
#ifndef NDEBUG
size_t init_size = stack.size();
#endif
frame.function->operator_table_[inst.X](stack);
#ifndef NDEBUG
frame.function->assert_stack_size(inst.X, init_size, stack.size());
#endif
}
INST_NEXT;
// LOAD: push a copy of register inst.X.
case INST(LOAD): {
INST_GUARD;
stack.emplace_back(reg(inst.X));
}
INST_NEXT;
// MOVE: push register inst.X by moving out of it.
case INST(MOVE): {
INST_GUARD;
stack.emplace_back(std::move(reg(inst.X)));
}
INST_NEXT;
// STORE: pop one value into register inst.X.
case INST(STORE): {
INST_GUARD;
reg(inst.X) = pop(stack);
}
INST_NEXT;
// STOREN: pop inst.N values into registers inst.X .. inst.X + inst.N - 1.
// The first value popped goes to the highest register.
case INST(STOREN): {
INST_GUARD;
// Counts down from N to 1 so the unsigned index never goes below zero.
for (size_t i = inst.N; i > 0; --i) {
reg(inst.X + i - 1) = pop(stack);
}
}
INST_NEXT;
// DROP: discard the last stack entry.
case INST(DROP): {
INST_GUARD;
stack.pop_back();
}
INST_NEXT;
// DROPR: reset register inst.X to a default-constructed IValue.
case INST(DROPR): {
INST_GUARD;
reg(inst.X) = IValue();
}
INST_NEXT;
// LOADC: push entry inst.X of the function's constant table.
case INST(LOADC): {
INST_GUARD;
stack.emplace_back(frame.function->constant_table_[inst.X]);
}
INST_NEXT;
// GET_ATTR: replace the object at the back of the stack with its slot
// inst.X.
case INST(GET_ATTR): {
INST_GUARD;
const auto& userObj = stack.back().toObjectRef();
stack.back() = userObj.getSlot(inst.X);
}
INST_NEXT;
// SET_ATTR: pop the value, store it into slot inst.X of the object that is
// then at the back of the stack, and remove that object as well.
case INST(SET_ATTR): {
INST_GUARD;
auto v = pop(stack);
auto& userObj = stack.back().toObjectRef();
userObj.setSlot(inst.X, std::move(v));
stack.pop_back();
}
INST_NEXT;
// JF: pop a bool; when it is true fetch with offset 1, otherwise fetch with
// offset inst.X (i.e. the inst.X target is taken on false).
case INST(JF): {
INST_GUARD;
if (pop(stack).toBool()) {
inst = INST_FETCH(1);
} else {
inst = INST_FETCH(inst.X);
}
}
INST_DISPATCH;
// JMP: unconditionally fetch with offset inst.X.
case INST(JMP): {
INST_GUARD;
inst = INST_FETCH(inst.X);
}
INST_DISPATCH;
case INST(LOOP): {
INST_GUARD;
// stack: iteration_count, max_iter, cond, loop_carried_deps...
// `fr` addresses the inst.N + 1 trailing stack entries described above.
auto fr = stack.end() - (inst.N + 1);
int64_t trip_count = fr[0].toInt();
int64_t max_trip_count = fr[1].toInt();
bool cond = fr[2].toBool();
if (trip_count < max_trip_count && cond) {
// Keep looping: the cond slot is overwritten with the current trip count
// and the counter slot is advanced by one.
fr[2] = trip_count;
fr[0] = trip_count + 1;
inst = INST_FETCH(1);
} else {
// Leave the loop: shift the loop-carried values down over the three
// bookkeeping slots, then drop the three now-stale trailing entries.
size_t n_loop_carried = inst.N - 2;
for (const auto i : c10::irange(n_loop_carried)) {
fr[i] = std::move(fr[i + 3]);
}
drop(stack, 3); // iteration_count, max_iter, cond
inst = INST_FETCH(inst.X);
}
}
INST_DISPATCH;
// CALL: call entry inst.X of the function table, then restart the loop so
// `frame` is re-read.
case INST(CALL): {
INST_GUARD;
Function* fn = frame.function->function_table_[inst.X];
callFunction(*fn, stack);
continue;
}
// INTERFACE_CALL: resolve the method by name at run time. The name is the
// string constant at index inst.X; the receiver is peek(stack, 0, inst.N).
case INST(INTERFACE_CALL): {
INST_GUARD;
// note the hash table lookup to find the function
// this can be more optimized if necessary, caching parts
// of the hashing computation or storing the offset when
// the object is turned into an interface
// consider passing
// `frames.back().function->remaining_bailout_depth_` into
// `get_executor().getPlanFor()` to propagate caller's depth
// restrictions onto children while this strategy has a potential to
// reduce the number of compilations for too dynamic callers we
// might miss opportunities where a caller is dynamic but a callee
// gets stable arguments
Function& function =
peek(stack, 0, inst.N)
.toObject()
->type()
->getMethod(
frame.function->constant_table_[inst.X].toStringRef());
callFunction(function, stack);
continue;
}
// RET: return from the current frame. Note there is no INST_GUARD here.
case INST(RET): {
// Nested call: leave the callee frame and keep interpreting the caller.
if (frames.size() > 1) {
leaveFrame();
continue;
}
// Outermost frame: if a future exists, complete it with the single output,
// or with a tuple built from the last n_outputs stack entries.
if (future_) {
auto num_outputs = frames.back().function->n_outputs;
if (num_outputs == 1) {
future_->markCompleted(stack.back());
} else {
future_->markCompleted(
c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
}
}
// destroy the last frame and call RecordFunction's end callbacks
leaveFrame();
return false;
}
// WAIT: the back of the stack holds a future. If it is already completed it
// is replaced by its value; otherwise execution is suspended (returns true).
case INST(WAIT): {
INST_GUARD;
auto future = stack.back().toFuture();
if (!future->completed()) {
getOrCreateFuture();
// callback needs to be a struct rather than a lambda so that
// we can move the stack to the other thread
struct Callback {
Callback(
c10::intrusive_ptr<InterpreterStateImpl> state,
Stack stack)
: stateImpl_(std::move(state)),
state_(stateImpl_),
stack_(std::move(stack)) {
// Capture the dist autograd context id at construction time.
dist_autograd_context_id_ = getDistAutogradContextId();
// NOTE(review): state_ was already constructed from stateImpl_ in the
// member-initializer list; this assignment appears redundant - confirm
// before removing.
state_ = InterpreterState(stateImpl_);
}
// Invoked with the future; hands a continuation carrying the saved stack,
// context id and thread-local state to the task launcher.
void operator()(c10::ivalue::Future& /* unused */) {
stateImpl_->taskLauncher_(InterpreterContinuation(
state_,
std::move(stack_),
dist_autograd_context_id_,
std::move(tls_state_)));
}
private:
c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
InterpreterState state_;
Stack stack_;
int64_t dist_autograd_context_id_;
// preserve the original ThreadLocalState
at::ThreadLocalState tls_state_;
};
// we are suspending, so we need to reset the stack to where we
// started if it started empty, except for the inputs we can avoid
// a true copy by swapping, which leaves the original stack empty.
Stack copied;
if (stack_start_ == 0) {
// The whole stack is ours: take it wholesale.
copied.swap(stack);
} else {
// Move only the entries from stack_start_ onward and shrink the caller's
// stack back to the part that was there before we started.
copied.insert(
copied.begin(),
std::make_move_iterator(stack.begin() + stack_start_),
std::make_move_iterator(stack.end()));
stack.resize(stack_start_);
}
// save pc into the frame so we continue here when restored
// NOTE(review): no explicit pc store is visible in this case; presumably
// handled by INST_GUARD - confirm against the macro definition.
future->addCallback(
Callback(intrusive_from_this(), std::move(copied)));
return true;
}
// Future already completed: swap it for its value.
stack.pop_back();
stack.emplace_back(future->value());
}
INST_NEXT;
// PROFILE_OP: give the frame an id on first use, push that id as an int64
// and run profiling callback inst.X on the stack.
case INST(PROFILE_OP): {
INST_GUARD;
auto& frame_id_ref = frame.id;
if (!frame_id_ref.has_value()) {
frame_id_ref = Frame::genId();
}
const auto& callback =
frame.function->profile_function_table_[inst.X];
push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
callback(stack);
}
INST_NEXT;
// FAIL_GUARD: rewrite the instruction at the current pc to GUARD and push
// false as this guard's result.
case INST(FAIL_GUARD): {
INST_GUARD;
// patch FAIL_GUARD back to GUARD
GRAPH_DEBUG(
"Bailout ", inst.X, " triggered via bailout_requests_!");
frame.function->instructions_[frame.pc].op = GUARD;
push(stack, false);
}
INST_NEXT;
// TYPECHECK: check the last inst.N stack entries (as tensors) against the
// types at type_table_[inst.X ...] and push exactly one bool: false at the
// first defined tensor that does not match, true if none fails.
case INST(TYPECHECK): {
INST_GUARD;
// `i` is declared outside the loop so it can be inspected afterwards.
int num_inputs = inst.N, i = 0;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
// Check every input's shape against profiled (expected) shape.
for (i = 0; i < num_inputs; i++) {
auto& input = peek(stack, i, num_inputs);
auto& t = input.toTensor();
const TypePtr& expected = frame.function->type_table_[inst.X + i];
auto* expected_type = expected->castRaw<TensorType>();
// Undefined tensors are not checked.
if (t.defined() && !expected_type->matchTensor(t)) {
push(stack, false);
break;
}
}
// The loop ran to completion, so no mismatch was pushed.
if (i == num_inputs) {
push(stack, true);
}
}
INST_NEXT;
// GUARD: push a bool telling whether the value at the back of the stack
// satisfies the expected tensor type at type_table_[inst.X]. The guarded
// value itself stays on the stack.
case INST(GUARD): {
INST_GUARD;
if (!stack.back().isTensor()) {
// stack.back() is an Uninitialized IValue and this is a guard
// on a block output. Uninitialized IValues are never used
// so it's safe to pass this guard check
push(stack, true);
} else {
auto& t = stack.back().toTensor();
const TypePtr& expected = frame.function->type_table_[inst.X];
auto* expected_type = expected->castRaw<TensorType>();
// For a defined tensor, binding its sizes to the expected symbolic sizes
// must succeed first; only then is matchTensor consulted.
if (t.defined() &&
!frames.back().symbols2dims.bindSymbolicShapes(
t.sizes(), expected_type->symbolic_sizes())) {
push(stack, false);
} else {
push(stack, expected_type->matchTensor(t));
}
}
}
INST_NEXT;
// TAIL_CALL: replace the current frame with a call to function_table_[inst.X].
case INST(TAIL_CALL): {
INST_GUARD;
GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
frame.function->function_table_[inst.X]->ensure_defined();
// One less bailout level for the callee, never going below zero.
size_t remaining_bailout_depth =
frame.function->remaining_bailout_depth_ > 0
? frame.function->remaining_bailout_depth_ - 1
: 0;
auto& f = *frame.function->function_table_[inst.X];
size_t num_inputs = f.num_inputs();
size_t base_pointer = frame.base_pointer;
TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
size_t inputs_start = stack.size() - num_inputs;
// Slide the callee's inputs down to the current frame's base pointer and
// cut the stack right after them. Everything needed from `frame` was read
// above, before leaveFrame() is called.
for (const auto i : c10::irange(num_inputs)) {
stack.at(base_pointer + i) =
std::move(stack.at(inputs_start + i));
}
stack.resize(base_pointer + num_inputs);
leaveFrame();
// NOTE(review): the meaning of the trailing `false` argument is defined by
// callFunction, which is outside this view.
callFunction(f, stack, remaining_bailout_depth, false);
continue;
}
// The cases below each forward to a helper with operands taken from the
// instruction and, where needed, a type from the function's type table.
case INST(LIST_UNPACK): {
INST_GUARD;
listUnpack(stack, inst.X);
}
INST_NEXT;
case INST(TUPLE_CONSTRUCT): {
INST_GUARD;
tupleConstruct(stack, inst.X);
}
INST_NEXT;
// TUPLE_SLICE: the slice bounds passed are inst.X and inst.X + inst.N.
case INST(TUPLE_SLICE): {
INST_GUARD;
tupleSlice(stack, inst.X, inst.X + inst.N);
}
INST_NEXT;
case INST(NAMED_TUPLE_CONSTRUCT): {
INST_GUARD;
namedTupleConstruct(
stack,
frame.function->type_table_[inst.X]->expect<TupleType>(),
inst.N);
}
INST_NEXT;
case INST(LIST_CONSTRUCT): {
INST_GUARD;
const auto& type =
frame.function->type_table_[inst.X]->expectRef<ListType>();
listConstruct(stack, type, inst.N);
}
INST_NEXT;
case INST(DICT_CONSTRUCT): {
INST_GUARD;
const auto& type =
frame.function->type_table_[inst.X]->expectRef<DictType>();
dictConstruct(stack, type, inst.N);
}
INST_NEXT;
case INST(CREATE_OBJECT): {
INST_GUARD;
auto type =
frame.function->type_table_[inst.X]->expect<ClassType>();
createObject(stack, type);
}
INST_NEXT;
// ISINSTANCE: the candidate types are the inst.N consecutive type-table
// entries starting at index inst.X, viewed without copying.
case INST(ISINSTANCE): {
INST_GUARD;
at::ArrayRef<TypePtr> types(
&frame.function->type_table_[inst.X],
&frame.function->type_table_[inst.X] + inst.N);
isinstance(stack, types);
}
INST_NEXT;
// FORK: hand function_table_[inst.X] to a new interpreter via the task
// launcher and leave that interpreter's future on the stack.
case INST(FORK): {
INST_GUARD;
// Move inputs to a separate stack
auto& forked_fn =
toGraphFunction(*frame.function->function_table_[inst.X]);
// The plan is requested while the inputs are still on `stack`.
InterpreterState forked_interpreter(
forked_fn.get_executor()
.getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
.code,
taskLauncher_);
// The continuation gets its own stack built from the last inst.N entries.
InterpreterContinuation continuation(
forked_interpreter,
Stack(stack.end() - inst.N, stack.end()),
getDistAutogradContextId());
// Replace the inputs on our stack with the forked interpreter's future
// before launching the continuation.
drop(stack, inst.N);
push(stack, forked_interpreter.getFuture());
taskLauncher_(std::move(continuation));
}
INST_NEXT;
case INST(WARN): {
INST_GUARD;
// Keeps track of which WARN instruction has been executed before,
// we only want to execute each WARN once to match default Python
// warning behavior.
bool need_warn = true;
// inst.X == -1 means "always warn"; otherwise the result of inserting
// inst.X into warned_nodes_ decides.
if (inst.X != -1) {
need_warn = warned_nodes_.insert(inst.X);
}
// The node recorded for the current pc supplies the source location.
Node* node =
frames.back().function->instructions_source_.at(frame.pc);
auto range = node->sourceRange().source();
if (range->filename()) {
// A filename is known: discard the last stack entry, then read the message
// that sits beneath it.
drop(stack, 1);
const auto& msg = stack.back().toStringRef();
if (need_warn) {
auto line = range->starting_line_no() +
range->lineno_for_offset(node->sourceRange().start());
c10::SourceLocation location{
"", range->filename()->c_str(), uint32_t(line)};
// Sends the warning to the warning handler with the
// "verbatim" flag. This flag ensures the warning handler
// will print the exception as configured.
c10::Warning::warn(location, msg, /*verbatim=*/true);
}
// The message is popped whether or not the warning was emitted.
stack.pop_back();
} else {
// NOTE(review): the branch above drops one extra stack entry before reading
// the message, while this branch reads the back of the stack directly and
// pops only once - confirm the asymmetry is intended.
const auto& msg = stack.back().toStringRef();
if (need_warn) {
TORCH_WARN(msg);
}
stack.pop_back();
}
}
INST_NEXT;
}
}
} catch (std::exception& e) {
// Call `__exit__` on every object still recorded in entered_objects, most
// recently entered first. Each call gets a fresh local stack holding the
// object and three default-constructed IValues.
for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
it != end;
++it) {
auto& f = it->toObject()->type()->getMethod("__exit__");
Stack stack;
push(stack, *it);
push(stack, IValue());
push(stack, IValue());
push(stack, IValue());
try {
f.run(stack);
} catch (std::exception& _) {
// An exception thrown by `__exit__` is deliberately ignored so the
// remaining objects are still processed.
// TODO(T98048876): Handle `_` correctly.
}
}
// With the rethrow flag set, the original exception is stored on future_
// when one exists, otherwise it is re-thrown to the caller unchanged.
if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
if (future_) {
future_->setError(std::current_exception());
return false;
}
throw;
}
// Otherwise classify the exception and hand it to handleError.
// is_jit_exception is true exactly when `e` is a JITException.
bool is_jit_exception = dynamic_cast<JITException*>(&e);
// Janky af. See https://github.com/pytorch/pytorch/issues/54612
// not_implemented_error is null unless `e` is a c10::NotImplementedError.
auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error);
return false;
}
}