void VulkanWrappedFunc::operator()

in src/runtime/vulkan/vulkan_wrapped_func.cc [45:185]


// Launches the compute kernel `func_name_` on the currently active Vulkan device.
//
// `args`: the first `num_buffer_args_` entries are VulkanBuffer handles (passed as void*);
//   `args` is also handed to `launch_param_config_` to extract the dispatch grid size.
// `rv`: not written.
// `pack_args`: `num_pack_args_` packed scalar arguments. They reach the shader through a
//   uniform buffer when `pipeline->use_ubo` is set, and as push constants otherwise.
//
// Two recording paths:
//  * Immediate (`device.UseImmediate()`): descriptors are pushed with
//    vkCmdPushDescriptorSetWithTemplateKHR and the dispatch is recorded synchronously.
//  * Deferred: the pipeline's own descriptor set is updated with vkUpdateDescriptorSets and
//    the recording closure is queued via LaunchDeferred. That closure is treated as possibly
//    running after this call returns, so it owns copies of everything it reads (no `this`,
//    no references to locals).
void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
                                   const ArgUnion64* pack_args) const {
  const int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID();
  auto& device = VulkanDeviceAPI::Global()->device(device_id);
  // The pipeline is created lazily on first use per device and cached in scache_.
  if (!scache_[device_id]) {
    scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
  }
  const auto& pipeline = scache_[device_id];
  const ThreadWorkLoad wl = launch_param_config_.Extract(args);

  // One descriptor per buffer argument, each covering the whole buffer.
  std::vector<VkDescriptorBufferInfo> descriptor_buffers;
  descriptor_buffers.reserve(num_buffer_args_ + 1);  // +1 for the optional UBO binding.
  for (size_t i = 0; i < num_buffer_args_; ++i) {
    void* buf = args[static_cast<int>(i)].cast<void*>();
    ICHECK(buf != nullptr) << "Vulkan kernel " << func_name_ << ": buffer argument " << i
                           << " is null";
    VkDescriptorBufferInfo binfo;
    binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
    binfo.offset = 0;
    binfo.range = VK_WHOLE_SIZE;
    descriptor_buffers.push_back(binfo);
  }

  const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
  if (pipeline->use_ubo) {
    // The UBO is appended as the last binding. The deferred initializer below relies on
    // this when it picks VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER for the final descriptor.
    // NOTE(review): scalars are memcpy'd into this thread-local buffer when the dispatch is
    // recorded, not when it executes. If several UBO kernels are recorded before one
    // submission, they may all observe the last write. Confirm against VulkanStream's
    // submission model.
    auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars);
    VkDescriptorBufferInfo binfo;
    binfo.buffer = ubo.vk_buf.buffer;
    binfo.offset = 0;
    binfo.range = VK_WHOLE_SIZE;
    descriptor_buffers.push_back(binfo);
  }

  // Records the dispatch, then a barrier ordering this dispatch's shader reads/writes before
  // subsequent transfer and compute-shader accesses. Stateless, so it can be copied into the
  // deferred closure.
  const auto record_dispatch = [](VulkanStreamState* state, const ThreadWorkLoad& work) {
    vkCmdDispatch(state->cmd_buffer_, work.grid_dim(0), work.grid_dim(1), work.grid_dim(2));
    VkMemoryBarrier barrier_info;
    barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
    barrier_info.pNext = nullptr;
    barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
    barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
                                  VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
    vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                         VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
                         1, &barrier_info, 0, nullptr, 0, nullptr);
  };

  // Inserts a queue debug label named after the kernel, when debug-utils labels are enabled.
  // Only invoked synchronously from this call, so capturing `this`/`device` is safe.
  const auto insert_debug_label = [&device, this]() {
    if (device.UseDebugUtilsLabel()) {
      VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
                                             nullptr,
                                             func_name_.c_str(),
                                             {0.0f, 0.0f, 0.0f, 0.0f}};
      device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
          device.Queue(), &dispatch_label);
    }
  };

  if (device.UseImmediate()) {
    // Validate before any command is recorded, so a failure does not leave the command
    // buffer with a bound pipeline but no dispatch.
    ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
    // Can safely capture by reference as this lambda is immediately executed on the calling thread.
    device.ThreadLocalStream().Launch([&](VulkanStreamState* state) {
      vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
      device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
          state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
          descriptor_buffers.data());

      if (pipeline->use_ubo) {
        auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars);
        memcpy(ubo.host_addr, pack_args, nbytes_scalars);
      } else if (num_pack_args_ > 0) {
        vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                           VK_SHADER_STAGE_COMPUTE_BIT, 0, nbytes_scalars, pack_args);
      }

      record_dispatch(state, wl);
      insert_debug_label();
    });
    return;
  }

  // Otherwise, the more expensive deferred path.
  // Writes the buffer descriptors into the pipeline's own descriptor set. `device` refers to
  // the object returned by VulkanDeviceAPI::Global()->device(), not to a local.
  const auto deferred_initializer = [&device, pipeline, descriptor_buffers]() {
    std::vector<VkWriteDescriptorSet> write_descriptor_sets(descriptor_buffers.size());
    for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
      VkWriteDescriptorSet& write = write_descriptor_sets[i];
      write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      write.pNext = nullptr;
      write.dstSet = pipeline->descriptor_set;
      write.dstBinding = static_cast<uint32_t>(i);
      write.dstArrayElement = 0;
      write.descriptorCount = 1;
      write.pImageInfo = nullptr;
      write.pBufferInfo = &(descriptor_buffers[i]);
      write.pTexelBufferView = nullptr;
      // The last binding is the UBO when one is used; every other binding is a storage buffer.
      const bool is_ubo_binding = pipeline->use_ubo && i == write_descriptor_sets.size() - 1;
      write.descriptorType =
          is_ubo_binding ? VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER : VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    }
    vkUpdateDescriptorSets(device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0,
                           nullptr);
  };

  // Recording closure. It owns a copy of the scalar arguments and deliberately does NOT
  // capture `this`: it may run after this call returns, so reading members through `this`
  // could dangle. The scalar count is taken from the copy instead.
  const auto deferred_kernel =
      [pipeline, wl, nbytes_scalars, device_id, record_dispatch,
       scalars = std::vector<ArgUnion64>(pack_args, pack_args + num_pack_args_)](
          VulkanStreamState* state) {
        auto& kernel_device = VulkanDeviceAPI::Global()->device(device_id);

        vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
        vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
                                pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
                                nullptr);

        if (pipeline->use_ubo) {
          auto& ubo = kernel_device.ThreadLocalUniformBuffer(nbytes_scalars);
          memcpy(ubo.host_addr, scalars.data(), nbytes_scalars);
        } else if (!scalars.empty()) {
          vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                             VK_SHADER_STAGE_COMPUTE_BIT, 0, nbytes_scalars, scalars.data());
        }

        record_dispatch(state, wl);
      };

  // The token identifies which descriptor set and buffers this launch binds.
  VulkanStreamToken deferred_token;
  deferred_token.descriptor_set_ = pipeline->descriptor_set;
  deferred_token.buffers_.reserve(descriptor_buffers.size());
  for (const VkDescriptorBufferInfo& binfo : descriptor_buffers) {
    deferred_token.buffers_.push_back(binfo.buffer);
  }
  device.ThreadLocalStream().LaunchDeferred(deferred_initializer, deferred_kernel, deferred_token);

  // NOTE(review): in this path the label is inserted at enqueue time, which may precede the
  // actual recording/submission of the dispatch.
  insert_debug_label();
}