bool runImpl()

in torch/csrc/jit/runtime/interpreter.cpp [234:723]
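
Returns false when the outermost frame returns (or when an error was handled without rethrow), and true when execution suspended on an incomplete future in the WAIT case; a short sketch of the caller-side contract follows the listing.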


  bool runImpl(Stack& stack) {
    // If we have never run before, we may have to return the stack when we
    // suspend. Record where our portion of the stack starts so that we
    // return the right values.
    if (stack_start_ == -1) {
      TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
      stack_start_ = stack.size() - frames.back().function->n_inputs;
    } else {
      // During restarts the entire stack is our own, so there is nothing
      // below us to preserve.
      stack_start_ = 0;
    }

    TLSCurrentInterpreterGuard g(this);
    if (frames.back().pc == 0 && stack_start_ == 0) {
      checkAndStartRecordFunction(frames.back(), stack);
    }

#if defined(JIT_USE_COMPUTED_GOTO)
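    // When computed gotos are available, each opcode handler jumps directly
    // to the next handler's label via this table (see INST_DISPATCH) instead
    // of looping back through the switch.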
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
    static void* dispatch_table[] = {
#define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
        FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
#undef DISPATCH_TABLE_ENTRY
    };
#endif

    try {
      while (true) {
        Frame& frame = frames.back();
        Instruction inst = INST_FETCH(0);
        switch (inst.op) {
          case INST(ENTER): {
            INST_GUARD;
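            // ENTER is emitted for `with` statements: remember the context
            // manager object so that EXIT (or the exception cleanup in the
            // catch block below) can call its __exit__ method later.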
            const auto& obj = peek(stack, 0, 1);
            TORCH_INTERNAL_ASSERT(obj.isObject());
            entered_objects.push_back(obj);
          }
            INST_NEXT;
          case INST(EXIT): {
            INST_GUARD;
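            // Call __exit__(self, exc_type, exc_value, traceback) on the most
            // recently entered object; the three empty IValues stand in for
            // the exception arguments since we are exiting normally.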
            auto obj = entered_objects.back().toObject();
            auto& f = obj->type()->getMethod("__exit__");
            push(stack, std::move(obj));
            entered_objects.pop_back();
            push(stack, IValue());
            push(stack, IValue());
            push(stack, IValue());
            callFunction(f, stack);
            continue;
          }
          case INST(OP): {
            INST_GUARD;
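            // Run a registered operator by index; debug builds additionally
            // check that it consumed and produced the expected number of
            // stack entries.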
#ifndef NDEBUG
            size_t init_size = stack.size();
#endif
            frame.function->operator_table_[inst.X](stack);
#ifndef NDEBUG
            frame.function->assert_stack_size(inst.X, init_size, stack.size());
#endif
          }
            INST_NEXT;
          case INST(OPN): {
            INST_GUARD;
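            // Variadic form of OP: push the argument count (inst.N) so the
            // operator knows how many inputs to pop.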
            stack.push_back(inst.N);
#ifndef NDEBUG
            size_t init_size = stack.size();
#endif
            frame.function->operator_table_[inst.X](stack);
#ifndef NDEBUG
            frame.function->assert_stack_size(inst.X, init_size, stack.size());
#endif
          }
            INST_NEXT;
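          // LOAD copies a register onto the stack; MOVE transfers it and
          // leaves the register empty.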
          case INST(LOAD): {
            INST_GUARD;
            stack.emplace_back(reg(inst.X));
          }
            INST_NEXT;
          case INST(MOVE): {
            INST_GUARD;
            stack.emplace_back(std::move(reg(inst.X)));
          }
            INST_NEXT;
          case INST(STORE): {
            INST_GUARD;
            reg(inst.X) = pop(stack);
          }
            INST_NEXT;
          case INST(STOREN): {
            INST_GUARD;
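            // Pop inst.N values into registers X..X+N-1; iterating downward
            // puts the top of the stack into the highest register, so the
            // registers end up in push order.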
            for (size_t i = inst.N; i > 0; --i) {
              reg(inst.X + i - 1) = pop(stack);
            }
          }
            INST_NEXT;
          case INST(DROP): {
            INST_GUARD;
            stack.pop_back();
          }
            INST_NEXT;
          case INST(DROPR): {
            INST_GUARD;
            reg(inst.X) = IValue();
          }
            INST_NEXT;
          case INST(LOADC): {
            INST_GUARD;
            stack.emplace_back(frame.function->constant_table_[inst.X]);
          }
            INST_NEXT;
          case INST(GET_ATTR): {
            INST_GUARD;
            const auto& userObj = stack.back().toObjectRef();
            stack.back() = userObj.getSlot(inst.X);
          }
            INST_NEXT;
          case INST(SET_ATTR): {
            INST_GUARD;
            auto v = pop(stack);
            auto& userObj = stack.back().toObjectRef();
            userObj.setSlot(inst.X, std::move(v));
            stack.pop_back();
          }
            INST_NEXT;
          case INST(JF): {
            INST_GUARD;
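            // Jump-if-false: fall through to the next instruction when the
            // popped condition is true, otherwise jump by inst.X instructions
            // relative to the current one.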
            if (pop(stack).toBool()) {
              inst = INST_FETCH(1);
            } else {
              inst = INST_FETCH(inst.X);
            }
          }
            INST_DISPATCH;
          case INST(JMP): {
            INST_GUARD;
            inst = INST_FETCH(inst.X);
          }
            INST_DISPATCH;
          case INST(LOOP): {
            INST_GUARD;
            // stack: iteration_count, max_iter, cond, loop_carried_deps...
            auto fr = stack.end() - (inst.N + 1);
            int64_t trip_count = fr[0].toInt();
            int64_t max_trip_count = fr[1].toInt();
            bool cond = fr[2].toBool();
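            // While the loop continues, hand the current trip count to the
            // body and bump the stored counter; once it finishes, shift the
            // loop-carried values down over the three bookkeeping slots and
            // jump past the body.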
            if (trip_count < max_trip_count && cond) {
              fr[2] = trip_count;
              fr[0] = trip_count + 1;
              inst = INST_FETCH(1);
            } else {
              size_t n_loop_carried = inst.N - 2;
              for (const auto i : c10::irange(n_loop_carried)) {
                fr[i] = std::move(fr[i + 3]);
              }
              drop(stack, 3); // iteration_count, max_iter, cond
              inst = INST_FETCH(inst.X);
            }
          }
            INST_DISPATCH;
          case INST(CALL): {
            INST_GUARD;
            Function* fn = frame.function->function_table_[inst.X];
            callFunction(*fn, stack);
            continue;
          }
          case INST(INTERFACE_CALL): {
            INST_GUARD;
            // Note the hash table lookup needed to find the function. This
            // could be optimized further if necessary, e.g. by caching parts
            // of the hashing computation or by storing the offset when the
            // object is turned into an interface.

            // Consider passing
            // `frames.back().function->remaining_bailout_depth_` into
            // `get_executor().getPlanFor()` to propagate the caller's depth
            // restrictions onto its children. While that strategy could
            // reduce the number of compilations for overly dynamic callers,
            // we might miss opportunities where a caller is dynamic but a
            // callee gets stable arguments.
            Function& function =
                peek(stack, 0, inst.N)
                    .toObject()
                    ->type()
                    ->getMethod(
                        frame.function->constant_table_[inst.X].toStringRef());
            callFunction(function, stack);
            continue;
          }
          case INST(RET): {
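            // Returning from a nested call just pops the frame and resumes
            // the caller; returning from the outermost frame also completes
            // future_ (if one was requested) with the outputs and ends the
            // run.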
            if (frames.size() > 1) {
              leaveFrame();
              continue;
            }
            if (future_) {
              auto num_outputs = frames.back().function->n_outputs;
              if (num_outputs == 1) {
                future_->markCompleted(stack.back());
              } else {
                future_->markCompleted(
                    c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
              }
            }
            // destroy the last frame and call RecordFunction's end callbacks
            leaveFrame();
            return false;
          }
          case INST(WAIT): {
            INST_GUARD;
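            // WAIT suspends the interpreter when the future on top of the
            // stack is not yet complete: our portion of the stack is handed
            // to a callback that re-enqueues this interpreter once the future
            // finishes. If the future is already complete, just replace it
            // with its value and keep going.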
            auto future = stack.back().toFuture();
            if (!future->completed()) {
              getOrCreateFuture();

              // The callback needs to be a struct rather than a lambda so
              // that we can move the stack to the other thread.
              struct Callback {
                Callback(
                    c10::intrusive_ptr<InterpreterStateImpl> state,
                    Stack stack)
                    : stateImpl_(std::move(state)),
                      state_(stateImpl_),
                      stack_(std::move(stack)) {
                  dist_autograd_context_id_ = getDistAutogradContextId();
                  state_ = InterpreterState(stateImpl_);
                }
                void operator()(c10::ivalue::Future& /* unused */) {
                  stateImpl_->taskLauncher_(InterpreterContinuation(
                      state_,
                      std::move(stack_),
                      dist_autograd_context_id_,
                      std::move(tls_state_)));
                }

               private:
                c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
                InterpreterState state_;
                Stack stack_;
                int64_t dist_autograd_context_id_;
                // preserve the original ThreadLocalState
                at::ThreadLocalState tls_state_;
              };

              // We are suspending, so the stack must be reset to where it was
              // when we started. If everything on the stack is ours
              // (stack_start_ == 0), we can avoid a true copy by swapping,
              // which leaves the original stack empty.
              Stack copied;
              if (stack_start_ == 0) {
                copied.swap(stack);
              } else {
                copied.insert(
                    copied.begin(),
                    std::make_move_iterator(stack.begin() + stack_start_),
                    std::make_move_iterator(stack.end()));
                stack.resize(stack_start_);
              }
              // frame.pc still points at this WAIT instruction, so when the
              // interpreter is restored it re-executes WAIT and takes the
              // completed-future path below.
              future->addCallback(
                  Callback(intrusive_from_this(), std::move(copied)));

              return true;
            }
            stack.pop_back();
            stack.emplace_back(future->value());
          }
            INST_NEXT;
          case INST(PROFILE_OP): {
            INST_GUARD;
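            // Invoke the profiling callback registered for this instruction,
            // tagging it with a lazily generated id for the current frame.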
            auto& frame_id_ref = frame.id;
            if (!frame_id_ref.has_value()) {
              frame_id_ref = Frame::genId();
            }
            const auto& callback =
                frame.function->profile_function_table_[inst.X];
            push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
            callback(stack);
          }
            INST_NEXT;
          case INST(FAIL_GUARD): {
            INST_GUARD;
            // patch FAIL_GUARD back to GUARD
            GRAPH_DEBUG(
                "Bailout ", inst.X, " triggered via bailout_requests_!");
            frame.function->instructions_[frame.pc].op = GUARD;
            push(stack, false);
          }
            INST_NEXT;
          case INST(TYPECHECK): {
            INST_GUARD;
            int num_inputs = inst.N, i = 0;
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
            // Check every input's shape against profiled (expected) shape.
            for (i = 0; i < num_inputs; i++) {
              auto& input = peek(stack, i, num_inputs);
              auto& t = input.toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X + i];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() && !expected_type->matchTensor(t)) {
                push(stack, false);
                break;
              }
            }
            if (i == num_inputs) {
              push(stack, true);
            }
          }
            INST_NEXT;
          case INST(GUARD): {
            INST_GUARD;
            if (!stack.back().isTensor()) {
              // stack.back() is an Uninitialized IValue and this is a guard
              // on a block output. Uninitialized IValues are never used, so
              // it is safe to pass this guard check.
              push(stack, true);
            } else {
              auto& t = stack.back().toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() &&
                  !frames.back().symbols2dims.bindSymbolicShapes(
                      t.sizes(), expected_type->symbolic_sizes())) {
                push(stack, false);
              } else {
                push(stack, expected_type->matchTensor(t));
              }
            }
          }
            INST_NEXT;
          case INST(TAIL_CALL): {
            INST_GUARD;
            GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
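            // A tail call replaces the current frame instead of stacking a
            // new one: slide the callee's inputs down to this frame's base
            // pointer, drop everything else the frame owned, pop the frame,
            // and then call the callee with one less bailout level available.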
            frame.function->function_table_[inst.X]->ensure_defined();
            size_t remaining_bailout_depth =
                frame.function->remaining_bailout_depth_ > 0
                ? frame.function->remaining_bailout_depth_ - 1
                : 0;
            auto& f = *frame.function->function_table_[inst.X];
            size_t num_inputs = f.num_inputs();
            size_t base_pointer = frame.base_pointer;
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
            size_t inputs_start = stack.size() - num_inputs;
            for (const auto i : c10::irange(num_inputs)) {
              stack.at(base_pointer + i) =
                  std::move(stack.at(inputs_start + i));
            }
            stack.resize(base_pointer + num_inputs);
            leaveFrame();

            callFunction(f, stack, remaining_bailout_depth, false);
            continue;
          }
          case INST(LIST_UNPACK): {
            INST_GUARD;
            listUnpack(stack, inst.X);
          }
            INST_NEXT;
          case INST(TUPLE_CONSTRUCT): {
            INST_GUARD;
            tupleConstruct(stack, inst.X);
          }
            INST_NEXT;
          case INST(TUPLE_SLICE): {
            INST_GUARD;
            tupleSlice(stack, inst.X, inst.X + inst.N);
          }
            INST_NEXT;
          case INST(NAMED_TUPLE_CONSTRUCT): {
            INST_GUARD;
            namedTupleConstruct(
                stack,
                frame.function->type_table_[inst.X]->expect<TupleType>(),
                inst.N);
          }
            INST_NEXT;
          case INST(LIST_CONSTRUCT): {
            INST_GUARD;
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<ListType>();
            listConstruct(stack, type, inst.N);
          }
            INST_NEXT;
          case INST(DICT_CONSTRUCT): {
            INST_GUARD;
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<DictType>();
            dictConstruct(stack, type, inst.N);
          }
            INST_NEXT;
          case INST(CREATE_OBJECT): {
            INST_GUARD;
            auto type =
                frame.function->type_table_[inst.X]->expect<ClassType>();
            createObject(stack, type);
          }
            INST_NEXT;
          case INST(ISINSTANCE): {
            INST_GUARD;
            at::ArrayRef<TypePtr> types(
                &frame.function->type_table_[inst.X],
                &frame.function->type_table_[inst.X] + inst.N);
            isinstance(stack, types);
          }
            INST_NEXT;
          case INST(FORK): {
            INST_GUARD;
            // Move inputs to a separate stack
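            // Run the forked function on its own InterpreterState via the
            // task launcher, leaving a Future for its result on our stack.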
            auto& forked_fn =
                toGraphFunction(*frame.function->function_table_[inst.X]);
            InterpreterState forked_interpreter(
                forked_fn.get_executor()
                    .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
                    .code,
                taskLauncher_);
            InterpreterContinuation continuation(
                forked_interpreter,
                Stack(stack.end() - inst.N, stack.end()),
                getDistAutogradContextId());
            drop(stack, inst.N);
            push(stack, forked_interpreter.getFuture());
            taskLauncher_(std::move(continuation));
          }
            INST_NEXT;
          case INST(WARN): {
            INST_GUARD;
            // Keep track of which WARN instructions have already executed;
            // each WARN should only fire once, matching Python's default
            // warning behavior.
            bool need_warn = true;
            if (inst.X != -1) {
              need_warn = warned_nodes_.insert(inst.X);
            }

            Node* node =
                frames.back().function->instructions_source_.at(frame.pc);
            auto range = node->sourceRange().source();
            if (range->filename()) {
              drop(stack, 1);
              const auto& msg = stack.back().toStringRef();
              if (need_warn) {
                auto line = range->starting_line_no() +
                    range->lineno_for_offset(node->sourceRange().start());
                c10::SourceLocation location{
                    "", range->filename()->c_str(), uint32_t(line)};
                // Send the warning to the warning handler with the
                // "verbatim" flag, which ensures the handler prints the
                // warning message as configured.
                c10::Warning::warn(location, msg, /*verbatim=*/true);
              }
              stack.pop_back();
            } else {
              const auto& msg = stack.back().toStringRef();
              if (need_warn) {
                TORCH_WARN(msg);
              }
              stack.pop_back();
            }
          }
            INST_NEXT;
        }
      }
    } catch (std::exception& e) {
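      // Unwind any `with` blocks that are still open: call __exit__ on each
      // entered object in reverse order, passing empty IValues for the
      // exception arguments, and swallow any error __exit__ itself raises.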
      for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
           it != end;
           ++it) {
        auto& f = it->toObject()->type()->getMethod("__exit__");
        Stack stack;
        push(stack, *it);
        push(stack, IValue());
        push(stack, IValue());
        push(stack, IValue());
        try {
          f.run(stack);
        } catch (std::exception& _) {
          // TODO(T98048876): Handle `_` correctly.
        }
      }
      if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
        if (future_) {
          future_->setError(std::current_exception());
          return false;
        }
        throw;
      }
      bool is_jit_exception = dynamic_cast<JITException*>(&e);
      // Janky af.  See https://github.com/pytorch/pytorch/issues/54612
      auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
      handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error);
      return false;
    }
  }
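
A minimal sketch of the caller-side contract, assuming a hypothetical wrapper named run around this method (the real run/runAsync wrappers in interpreter.cpp are not part of this excerpt and may differ): runImpl returns false once the outermost frame has returned, and true when it suspended on an incomplete future in the WAIT case, in which case the results must be read from future_ once it completes.

  // Hypothetical driver, for illustration only.
  void run(Stack& stack) {
    // The frame is popped by the time runImpl returns, so capture the output
    // arity up front.
    auto num_outputs = frames.front().function->n_outputs;
    if (runImpl(stack)) {
      // Suspended on an incomplete future (WAIT): block until the future
      // created by getOrCreateFuture() completes, then unpack its value back
      // onto the caller's stack, mirroring how RET packages the outputs.
      future_->wait();
      if (num_outputs == 1) {
        push(stack, future_->value());
      } else {
        auto tuple = future_->value().toTuple();
        for (const auto& value : tuple->elements()) {
          push(stack, value);
        }
      }
    }
  }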