Status NeuronModel::compute()

in runtime/model.cc [420:799]
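
NeuronModel::compute runs one inference request against a Neuron engine. It validates the incoming tensors against the shapes recorded in the NodeDef attributes, allocates the output tensors, and then either performs a single inference call or, when the runtime batch size differs from the batch size the model expects, pads the batch up to a multiple of the expected batch size and shards the work across the engine's thread pool. Shared memory is used for input/output transfer whenever the shared memory allocator is available and all inputs are non-empty.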


Status NeuronModel::compute(OpKernelContext* ctx, const NodeDef& node_def) {
  uint64 start_time = Env::Default()->NowMicros();
#define VLOG_TIME(msg) VLOG_TIME_BASE(start_time, 1, msg);

  std::vector<Tensor> input_tensors(ctx->num_inputs());
  for (auto idx = 0; idx < ctx->num_inputs(); ++idx) {
    input_tensors.at(idx) = ctx->input(idx);
  }

  // pre-initialize
  TF_RETURN_IF_ERROR(pre_initialize(node_def));

  SharedMemoryAllocator* shm_allocator =
      NeuronEngineManager::GetNeuronEngineManager().get_shm_allocator();
  const google::protobuf::Map<std::string, AttrValue>& attr = node_def.attr();
  AttrList& input_names = attr.at("input_names").list();
  AttrList& output_names = attr.at("output_names").list();
  AttrList& output_shapes = attr.at("output_shapes").list();
  TFNN_ASSERT((int)input_tensors.size() == input_names.s_size(),
              errors::InvalidArgument("incorrect number of input tensors"));

  // lambda deciding whether shared memory can be used for I/O
  auto UseShmForIO = [&](const std::vector<Tensor>& input_tensors) {
    bool use_shm = can_use_shm_ && shm_allocator->is_valid();
    for (const Tensor& tensor : input_tensors) {
      use_shm &= tensor.NumElements() != 0;
    }
    return use_shm;
  };

  // decide whether dynamic batch size handling is needed
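  // batch_size is the batch dimension observed on the runtime input tensors;
  // k_batch_size is the batch dimension from the "input_shapes" attribute,
  // i.e. the shape the Neuron model expects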
  int64_t batch_size = UNINIT_BATCH_SIZE;
  int64_t k_batch_size = UNINIT_BATCH_SIZE;
  std::vector<bool> is_batch_inputs(input_tensors.size());
  std::vector<bool> is_batch_outputs(ctx->num_outputs());
  AttrList& input_batch_axis = attr.at("input_batch_axis").list();
  AttrList& output_batch_axis = attr.at("output_batch_axis").list();
  bool use_dynamic_batch_size = false;
  int64 input_copy_cost_per_unit = 0;
  if (allow_dynamic_batch_size_) {
    AttrList& input_shapes = attr.at("input_shapes").list();
    for (size_t idx = 0; idx < input_tensors.size(); ++idx) {
      bool is_batch_tensor = false;
      const Tensor& in_tensor = input_tensors.at(idx);
      TensorShape shape(in_tensor.shape());
      TensorShape k_shape(input_shapes.shape(idx));
      input_copy_cost_per_unit += k_shape.num_elements();
      if (TF_PREDICT_TRUE(0 == input_batch_axis.i(idx))) {
        TFNN_ASSERT(
            shape.dims() > 0,
            errors::InvalidArgument("no batch-dimension found on input tensor ",
                                    input_names.s(idx), " with shape ",
                                    shape.DebugString()));
        if (TF_PREDICT_TRUE(UNINIT_BATCH_SIZE == batch_size)) {
          batch_size = shape.dim_size(0);
          k_batch_size = k_shape.dim_size(0);
          TFNN_ASSERT(
              batch_size > 0,
              errors::Internal(
                  "incorrect internal batch size inferred from input tensor ",
                  input_names.s(idx), " with shape ", shape.DebugString()));
        } else {
          TFNN_ASSERT(
              batch_size == shape.dim_size(0),
              errors::InvalidArgument(
                  "incorrect batch size found on input tensor ",
                  input_names.s(idx), ", tensor shape ", shape.DebugString(),
                  ", internal batch size ", batch_size));
        }
        shape.RemoveDim(0);
        k_shape.RemoveDim(0);
        is_batch_tensor = batch_size != k_batch_size;
        use_dynamic_batch_size |= is_batch_tensor;
      }
      TFNN_ASSERT(
          shape == k_shape,
          errors::InvalidArgument(
              "incorrect shape found on input tensor ", input_names.s(idx),
              ", inference time shape ", in_tensor.shape().DebugString(),
              ", expected shape ", input_shapes.shape(idx).DebugString()));
      is_batch_inputs[idx] = is_batch_tensor;
    }
    for (auto idx = 0; idx < output_names.s_size(); ++idx) {
      bool is_batch_tensor = false;
      if (TF_PREDICT_TRUE(0 == output_batch_axis.i(idx))) {
        TensorShape k_shape(output_shapes.shape(idx));
        TFNN_ASSERT(k_shape.dims() > 0,
                    errors::InvalidArgument(
                        "no batch-dimension found on output tensor ",
                        output_names.s(idx), " with Neuron shape ",
                        k_shape.DebugString()));
        TFNN_ASSERT(
            k_batch_size == k_shape.dim_size(0),
            errors::InvalidArgument(
                "incorrect batch size found on output tensor ",
                output_names.s(idx), ", Neuron tensor shape ",
                k_shape.DebugString(), ", Neuron batch size ", k_batch_size));
        is_batch_tensor = batch_size != k_shape.dim_size(0);
        use_dynamic_batch_size |= is_batch_tensor;
      }
      is_batch_outputs[idx] = is_batch_tensor;
    }
  }
  TFNN_ASSERT(ctx->num_outputs() == output_names.s_size(),
              errors::InvalidArgument("incorrect number of output tensors"));

  // allocate output tensors
  std::vector<Tensor*> output_tensors(ctx->num_outputs());
  int64_t pad_batch_size = 0;
  if (use_dynamic_batch_size) {
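    // round batch_size up to the next multiple of k_batch_size so the batch
    // splits into full-size shards; e.g. batch_size=5 with k_batch_size=4
    // gives pad_batch_size=8 (two shards, the last one partially padded)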
    pad_batch_size = ((batch_size - 1) / k_batch_size + 1) * k_batch_size;
    VLOG(1) << "batch_size=" << batch_size << ", k_batch_size=" << k_batch_size
            << ", pad_batch_size=" << pad_batch_size;
    for (auto idx = 0; idx < ctx->num_outputs(); ++idx) {
      Tensor* batch_out_tensor = nullptr;
      TensorShape shape(output_shapes.shape(idx));
      if (TF_PREDICT_TRUE(is_batch_outputs[idx])) {
        shape.set_dim(0, batch_size);
      }
      TF_RETURN_IF_ERROR(ctx->allocate_output(idx, shape, &batch_out_tensor));
      output_tensors[idx] = batch_out_tensor;
    }
  } else {
    bool use_shm = UseShmForIO(input_tensors);
    TF_RETURN_IF_ERROR(allocate_outputs(ctx, use_shm, output_shapes,
                                        &output_tensors));
  }

  // initialize the model
  RIE_IGNORE_ABORTED(initialize(node_def, ctx->session_handle()));

  // an extra unary gRPC call is needed to re-establish the channel in case
  // gRPC error 14 (UNAVAILABLE) was seen, as start_model_unsafe may not
  // issue a gRPC start call
  RIE_IGNORE_ABORTED(neuron_engine_->start_ping());

  // keep a shared pointer so that RuntimeSession outlives shared memory buffers
  std::shared_ptr<RuntimeSession> session_alive = neuron_engine_->get_session();

  // get thread pool associated with the engine
  thread::ThreadPool* thread_pool = neuron_engine_->get_thread_pool();

  // run inference
  if (use_dynamic_batch_size) {
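    // end_start is the number of real (non-padded) rows in the last shard;
    // rows [end_start, k_batch_size) of that shard are zero-filled padding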
    int64 end_start = k_batch_size - (pad_batch_size - batch_size);
    bool run_profiler_in_shard = false;
    Status status_sd;
#define SHARD_LOG_ERROR(status_sd, ...)                            \
  {                                                                \
    Status _status = (__VA_ARGS__);                                \
    if (TF_PREDICT_FALSE(!_status.ok())) {                         \
      LOG(ERROR) << "shard error code " << _status.code()          \
                 << ", error message " << _status.error_message(); \
      status_sd = _status;                                         \
      return;                                                      \
    }                                                              \
  }
#define SHARD_LOG_IGNORE_ABORTED(status_sd, ...)                      \
  {                                                                   \
    Status _status(__VA_ARGS__);                                      \
    if (TF_PREDICT_FALSE(                                             \
            !(_status.ok() ||                                         \
              _status.code() == tensorflow::error::Code::ABORTED))) { \
      LOG(ERROR) << "shard error code " << _status.code()             \
                 << ", error message " << _status.error_message();    \
      status_sd = _status;                                            \
      return;                                                         \
    }                                                                 \
  }
#define SHARD_VLOG_TIME(msg) VLOG_TIME_BASE(start_time, 2, msg);
    auto ShardFunc = [&](int64 dim0_start, int64 dim0_limit) {
      SHARD_VLOG_TIME("entering shard");
      if (TF_PREDICT_FALSE(dim0_limit - dim0_start != k_batch_size)) {
        status_sd =
            errors::Internal("illegal shard ", dim0_start, ":", dim0_limit);
        return;
      }
      VLOG(2) << "Sharding " << dim0_start << " to " << dim0_limit;
      std::vector<Tensor> sliced_inputs(input_tensors.size());
      for (size_t idx = 0; idx < input_tensors.size(); ++idx) {
        const Tensor& in_tensor = input_tensors.at(idx);
        if (TF_PREDICT_TRUE(is_batch_inputs[idx])) {
          if (TF_PREDICT_FALSE(dim0_limit > batch_size)) {
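            // final, partially-filled shard: build a full k_batch_size slice,
            // zero the padded tail, then copy in the remaining real rows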
            TensorShape ps_shape(in_tensor.shape());
            ps_shape.set_dim(0, k_batch_size);
            Tensor pad_end_slice(in_tensor.dtype(), ps_shape);
            Tensor zero_slice = pad_end_slice.Slice(end_start, k_batch_size);
            SHARD_LOG_ERROR(status_sd, tensor_memset(&zero_slice, 0));
            Tensor end_slice = in_tensor.Slice(dim0_start, batch_size);
            SHARD_LOG_ERROR(status_sd, tensor_copy(&pad_end_slice, end_slice));
            sliced_inputs[idx] = pad_end_slice;
          } else {
            sliced_inputs[idx] = in_tensor.Slice(dim0_start, dim0_limit);
          }
        } else {
          sliced_inputs[idx] = in_tensor;
        }
      }
      SHARD_LOG_ERROR(status_sd, check_input_tensors(sliced_inputs, node_def));
      int64 end_limit = dim0_limit < batch_size ? dim0_limit : batch_size;
      std::vector<Tensor> sliced_outputs(output_tensors.size());
      for (size_t idx = 0; idx < sliced_outputs.size(); ++idx) {
        Tensor* out_tensor = output_tensors.at(idx);
        if (TF_PREDICT_TRUE(is_batch_outputs[idx])) {
          sliced_outputs[idx] = out_tensor->Slice(dim0_start, end_limit);
        } else {
          sliced_outputs[idx] = *out_tensor;
        }
      }
      std::vector<Tensor*> output_ptrs(sliced_outputs.size());
      for (size_t idx = 0; idx < output_ptrs.size(); ++idx) {
        output_ptrs[idx] = &sliced_outputs.at(idx);
      }
      RuntimeIO runtime_io;
      std::vector<Tensor> input_shm_tensors;
      std::vector<Tensor> output_shm_tensors;
      bool use_shm = UseShmForIO(sliced_inputs);
      if (TF_PREDICT_TRUE(use_shm)) {
        input_shm_tensors.resize(sliced_inputs.size());
        for (size_t idx = 0; idx < sliced_inputs.size(); ++idx) {
          const Tensor& tensor = sliced_inputs.at(idx);
          TensorShape shape = tensor.shape();
          DataType dtype = tensor.dtype();
          Tensor& shm_tensor = input_shm_tensors.at(idx);
          SHARD_LOG_ERROR(status_sd,
                          allocate_temp(ctx, dtype, shape, &shm_tensor));
        }
        output_shm_tensors.resize(sliced_outputs.size());
        for (size_t idx = 0; idx < output_shm_tensors.size(); ++idx) {
          const Tensor& tensor = sliced_outputs.at(idx);
          TensorShape shape(tensor.shape());
          if (TF_PREDICT_TRUE(is_batch_outputs[idx])) {
            if (TF_PREDICT_FALSE(dim0_limit > batch_size)) {
              shape.set_dim(0, k_batch_size);
            }
          }
          DataType dtype(tensor.dtype());
          Tensor& shm_tensor = output_shm_tensors.at(idx);
          SHARD_LOG_ERROR(status_sd,
                          allocate_temp(ctx, dtype, shape, &shm_tensor));
        }
      }
      std::vector<Tensor*> output_shm_ptrs;
      for (Tensor& shm_tensor : output_shm_tensors) {
        output_shm_ptrs.push_back(&shm_tensor);
      }
      SHARD_LOG_IGNORE_ABORTED(
          status_sd, setup_runtime_io(&runtime_io, node_def, input_shm_tensors,
                                      output_shm_ptrs, nn_id_, shm_allocator,
                                      use_shm));

      // copy input tensors with optional input_shuffles
      SHARD_VLOG_TIME("in shard before input copy");
      std::vector<bool> need_copy_inputs(sliced_inputs.size(), true);
      if (k_batch_size > 1 && runtime_io.use_shm()) {
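        // when writing into shared memory, parallelize the host-side input
        // copy over the batch dimension on the h2d transfer thread pool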
        auto CopyInputShardFunc = [&](int64 dim0_start, int64 dim0_limit) {
          std::vector<Tensor> input_slices(sliced_inputs.size());
          for (size_t i = 0; i < input_slices.size(); ++i) {
            if (TF_PREDICT_TRUE(is_batch_inputs[i])) {
              input_slices[i] = sliced_inputs[i].Slice(dim0_start, dim0_limit);
            } else {
              input_slices[i] = input_tensors[i];
            }
          }
          std::vector<Tensor> input_shm_slices(sliced_inputs.size());
          for (size_t i = 0; i < input_shm_slices.size(); ++i) {
            Tensor& shm_tensor = input_shm_tensors.at(i);
            if (TF_PREDICT_TRUE(is_batch_inputs[i])) {
              input_shm_slices[i] = shm_tensor.Slice(dim0_start, dim0_limit);
            } else {
              input_shm_slices[i] = shm_tensor;
            }
          }
          SHARD_LOG_IGNORE_ABORTED(
              status_sd, copy_input_tensors_with_shuffle(
                             ctx, node_def, nullptr, input_slices,
                             need_copy_inputs, &runtime_io, &input_shm_slices));
        };
        h2d_transfer_pool_.ParallelFor(k_batch_size, input_copy_cost_per_unit,
                                       std::move(CopyInputShardFunc));
      } else {
        SHARD_LOG_IGNORE_ABORTED(
            status_sd, copy_input_tensors_with_shuffle(
                           ctx, node_def, &h2d_transfer_pool_, sliced_inputs,
                           need_copy_inputs, &runtime_io, &input_shm_tensors));
      }

      // run inference
      SHARD_VLOG_TIME("in shard before infer");
      if (TF_PREDICT_FALSE(run_profiler_in_shard)) {
        VLOG(1) << "enabling profiler in shard";
        SHARD_LOG_IGNORE_ABORTED(
            status_sd,
            neuron_engine_->infer_with_profiling(&runtime_io, &profile_));
      } else {
        SHARD_LOG_IGNORE_ABORTED(status_sd, neuron_engine_->infer(&runtime_io));
      }
      SHARD_VLOG_TIME("in shard after infer");
      SHARD_LOG_IGNORE_ABORTED(
          status_sd, runtime_io.finish(&output_ptrs, output_shm_tensors,
                                       &h2d_transfer_pool_));
      SHARD_VLOG_TIME("in shard exit");
    };
#undef SHARD_LOG_IGNORE_ABORTED
#undef SHARD_LOG_ERROR
#undef SHARD_VLOG_TIME
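    // when profiling is enabled, run one shard serially with the profiler
    // attached before dispatching the parallel run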
    if (TF_PREDICT_FALSE(profile_.enabled_)) {
      run_profiler_in_shard = true;
      ShardFunc(0, k_batch_size);
      run_profiler_in_shard = false;
      RIE_IGNORE_ABORTED(status_sd);
    }
    VLOG_TIME("before sharding");
#if TF_VERSION_LESS_THAN(2, 0)
    thread_pool->TransformRangeConcurrently(k_batch_size, pad_batch_size,
                                            std::move(ShardFunc));
#else
    auto strategy = thread::ThreadPool::SchedulingStrategy::kFixedBlockSize;
    int64 cost_per_unit = estimated_cost_;
    auto params = thread::ThreadPool::SchedulingParams(strategy, cost_per_unit,
                                                       k_batch_size);
    thread_pool->ParallelFor(pad_batch_size, params, std::move(ShardFunc));
#endif
    RIE_IGNORE_ABORTED(status_sd);
  } else {
    TF_RETURN_IF_ERROR(check_input_tensors(input_tensors, node_def));
    std::vector<bool> need_copy_inputs(input_tensors.size(), true);
    bool need_input_shuffles = attr.count(kInputShuffles);
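    // inputs that already live in shared memory can be fed to the runtime
    // directly and skip the host-side copy, unless input shuffles are needed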
    if (TF_PREDICT_TRUE(shm_allocator->is_valid() && !need_input_shuffles)) {
      for (size_t idx = 0; idx < need_copy_inputs.size(); ++idx) {
        const Tensor& tensor = input_tensors.at(idx);
        need_copy_inputs[idx] = !shm_allocator->is_shm_tensor(tensor);
        VLOG(1) << "input " << idx << " need copy " << need_copy_inputs[idx];
      }
    }
    RuntimeIO runtime_io;
    std::vector<Tensor> input_shm_tensors;
    std::vector<Tensor> output_shm_tensors;
    bool use_shm = UseShmForIO(input_tensors);
    if (TF_PREDICT_TRUE(use_shm)) {
      input_shm_tensors.resize(input_tensors.size());
      for (size_t idx = 0; idx < input_shm_tensors.size(); ++idx) {
        const Tensor& tensor = input_tensors.at(idx);
        if (need_copy_inputs.at(idx)) {
          TensorShape shape = tensor.shape();
          DataType dtype = tensor.dtype();
          Tensor& shm_tensor = input_shm_tensors.at(idx);
          TF_RETURN_IF_ERROR(allocate_temp(ctx, dtype, shape, &shm_tensor));
        } else {
          input_shm_tensors[idx] = tensor;
        }
      }
    }
    RIE_IGNORE_ABORTED(setup_runtime_io(&runtime_io, node_def,
                                        input_shm_tensors, output_tensors,
                                        nn_id_, shm_allocator, use_shm));

    // copy input tensors with optional input_shuffles
    RIE_IGNORE_ABORTED(copy_input_tensors_with_shuffle(
        ctx, node_def, thread_pool, input_tensors, need_copy_inputs, &runtime_io,
        &input_shm_tensors));

    // run inference
    VLOG_TIME("before infer");
    if (TF_PREDICT_FALSE(profile_.enabled_)) {
      VLOG(1) << "profile enabled -- lock stop/start/infer altogether";
      RIE_IGNORE_ABORTED(
          neuron_engine_->infer_with_profiling(&runtime_io, &profile_));
    } else {
      RIE_IGNORE_ABORTED(neuron_engine_->infer(&runtime_io));
    }
    VLOG_TIME("after infer");
    if (TF_PREDICT_FALSE(!use_shm)) {
      RIE_IGNORE_ABORTED(
          runtime_io.finish(&output_tensors, output_shm_tensors, thread_pool));
    }
  }
  VLOG_TIME("exiting compute");
#undef VLOG_TIME
  return Status::OK();
}