std::shared_ptr<VulkanPipeline> VulkanModuleNode::GetPipeline()

in src/runtime/vulkan/vulkan_wrapped_func.cc [222:405]


std::shared_ptr<VulkanPipeline> VulkanModuleNode::GetPipeline(size_t device_id,
                                                              const std::string& func_name,
                                                              size_t num_pack_args) {
  auto& device = VulkanDeviceAPI::Global()->device(device_id);
  std::lock_guard<std::mutex> lock(mutex_);
  const auto& cp = ecache_[device_id][func_name];
  if (cp) {
    return cp;
  }
  // Create new pipeline
  auto pe = std::make_shared<VulkanPipeline>();
  {
    // create shader
    auto sit = smap_.find(func_name);
    ICHECK(sit != smap_.end());
    pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO);
    const std::vector<uint32_t>& data = sit->second.data;
    VkShaderModuleCreateInfo shader_cinfo;
    shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
    shader_cinfo.pNext = nullptr;
    shader_cinfo.flags = 0;
    shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
    shader_cinfo.pCode = data.data();
    VULKAN_CALL(vkCreateShaderModule(device, &shader_cinfo, nullptr, &(pe->shader)));
  }
  std::vector<VkDescriptorSetLayoutBinding> arg_binding;
  std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
  std::vector<VkDescriptorPoolSize> descriptor_set_pool_sizes;
  uint32_t num_pod = 0, num_buffer = 0;

  auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes](
                           uint32_t binding, VkDescriptorType desc_type) {
    {
      auto result = std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(),
                                 [&](const auto& psize) { return psize.type == desc_type; });
      if (result == descriptor_set_pool_sizes.end()) {
        VkDescriptorPoolSize new_size;
        new_size.type = desc_type;
        new_size.descriptorCount = 1;
        descriptor_set_pool_sizes.push_back(new_size);
      } else {
        result->descriptorCount++;
      }
    }

    {
      VkDescriptorSetLayoutBinding bd;
      bd.binding = binding;
      bd.descriptorType = desc_type;
      bd.descriptorCount = 1;
      bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      bd.pImmutableSamplers = nullptr;
      arg_binding.push_back(bd);
    }
    {
      VkDescriptorUpdateTemplateEntryKHR tpl;
      tpl.dstBinding = binding;
      tpl.dstArrayElement = 0;
      tpl.descriptorCount = 1;
      tpl.descriptorType = desc_type;
      tpl.offset = binding * sizeof(VkDescriptorBufferInfo);
      tpl.stride = sizeof(VkDescriptorBufferInfo);
      arg_template.push_back(tpl);
    }
  };

  {
    auto fit = fmap_.find(func_name);
    ICHECK(fit != fmap_.end());
    for (DLDataType arg_type : fit->second.arg_types) {
      if (arg_type.code == kTVMOpaqueHandle) {
        push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
        ++num_buffer;
      } else {
        ++num_pod;
      }
    }
  }

  size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
  if (pe->use_ubo) {
    // Use UBO instead of push constants
    push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
    device.AllocateThreadLocalUniformBuffer(nbytes_scalars);
  }

  {
    VkDescriptorSetLayoutCreateInfo descrip_cinfo;
    descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descrip_cinfo.pNext = nullptr;
    descrip_cinfo.flags = 0;
    if (device.UseImmediate()) {
      descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
    }
    descrip_cinfo.bindingCount = arg_binding.size();
    descrip_cinfo.pBindings = arg_binding.data();
    VULKAN_CALL(
        vkCreateDescriptorSetLayout(device, &descrip_cinfo, nullptr, &(pe->descriptor_set_layout)));
  }

  if (!device.UseImmediate()) {
    VkDescriptorPoolCreateInfo descrip_pool_cinfo;
    descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
    descrip_pool_cinfo.pNext = nullptr;
    descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
    descrip_pool_cinfo.maxSets = 1;
    descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size();
    descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data();
    VULKAN_CALL(
        vkCreateDescriptorPool(device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool)));

    VkDescriptorSetAllocateInfo alloc_info;
    alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
    alloc_info.pNext = nullptr;
    alloc_info.descriptorPool = pe->descriptor_pool;
    alloc_info.descriptorSetCount = 1;
    alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
    VULKAN_CALL(vkAllocateDescriptorSets(device, &alloc_info, &(pe->descriptor_set)));
  }

  VkPushConstantRange crange;
  crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
  crange.offset = 0;
  crange.size = sizeof(ArgUnion64) * num_pack_args;

  VkPipelineLayoutCreateInfo playout_cinfo;
  playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  playout_cinfo.pNext = nullptr;
  playout_cinfo.flags = 0;
  playout_cinfo.setLayoutCount = 1;
  playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);

  if (0 < nbytes_scalars && !pe->use_ubo) {
    playout_cinfo.pushConstantRangeCount = 1;
    playout_cinfo.pPushConstantRanges = &crange;
    ICHECK_LE(crange.size, device.device_properties.max_push_constants_size)
        << "The Vulkan shader uses " << crange.size
        << " bytes of push constants, but the device only supports "
        << device.device_properties.max_push_constants_size << "bytes. "
        << "Please rebuild the shader using a smaller limit on push constants size "
        << "by passing -max_push_constants_size=N in the Target string, "
        << "or pass -from_device=0 to query all device parameters.";
  } else {
    playout_cinfo.pushConstantRangeCount = 0;
    playout_cinfo.pPushConstantRanges = nullptr;
  }

  VULKAN_CALL(vkCreatePipelineLayout(device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));

  VkComputePipelineCreateInfo pipeline_cinfo;
  pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  pipeline_cinfo.pNext = nullptr;
  pipeline_cinfo.flags = 0;
  pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  pipeline_cinfo.stage.pNext = nullptr;
  pipeline_cinfo.stage.flags = 0;
  pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  pipeline_cinfo.stage.module = pe->shader;
  pipeline_cinfo.stage.pName = func_name.c_str();
  pipeline_cinfo.stage.pSpecializationInfo = nullptr;
  pipeline_cinfo.layout = pe->pipeline_layout;
  pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
  pipeline_cinfo.basePipelineIndex = 0;
  VULKAN_CALL(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
                                       &(pe->pipeline)));

  if (device.UseImmediate()) {
    VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
    descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
    descrip_template_cinfo.pNext = nullptr;
    descrip_template_cinfo.flags = 0;
    descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
    descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
    descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
    descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
    descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
    descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
    descrip_template_cinfo.set = 0;
    VULKAN_CALL(device.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
        device, &descrip_template_cinfo, nullptr, &(pe->descriptor_update_template)));
  }
  ecache_[device_id][func_name] = pe;
  return pe;
}