in src/runtime/vulkan/vulkan_wrapped_func.cc [222:405]
std::shared_ptr<VulkanPipeline> VulkanModuleNode::GetPipeline(size_t device_id,
const std::string& func_name,
size_t num_pack_args) {
auto& device = VulkanDeviceAPI::Global()->device(device_id);
std::lock_guard<std::mutex> lock(mutex_);
const auto& cp = ecache_[device_id][func_name];
if (cp) {
return cp;
}
// Create new pipeline
auto pe = std::make_shared<VulkanPipeline>();
{
// create shader
auto sit = smap_.find(func_name);
ICHECK(sit != smap_.end());
pe->use_ubo = sit->second.flag & (1 << ShaderMetaDataFlagMask::kUseUBO);
const std::vector<uint32_t>& data = sit->second.data;
VkShaderModuleCreateInfo shader_cinfo;
shader_cinfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
shader_cinfo.pNext = nullptr;
shader_cinfo.flags = 0;
shader_cinfo.codeSize = data.size() * sizeof(uint32_t);
shader_cinfo.pCode = data.data();
VULKAN_CALL(vkCreateShaderModule(device, &shader_cinfo, nullptr, &(pe->shader)));
}
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
std::vector<VkDescriptorPoolSize> descriptor_set_pool_sizes;
uint32_t num_pod = 0, num_buffer = 0;
auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes](
uint32_t binding, VkDescriptorType desc_type) {
{
auto result = std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(),
[&](const auto& psize) { return psize.type == desc_type; });
if (result == descriptor_set_pool_sizes.end()) {
VkDescriptorPoolSize new_size;
new_size.type = desc_type;
new_size.descriptorCount = 1;
descriptor_set_pool_sizes.push_back(new_size);
} else {
result->descriptorCount++;
}
}
{
VkDescriptorSetLayoutBinding bd;
bd.binding = binding;
bd.descriptorType = desc_type;
bd.descriptorCount = 1;
bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
bd.pImmutableSamplers = nullptr;
arg_binding.push_back(bd);
}
{
VkDescriptorUpdateTemplateEntryKHR tpl;
tpl.dstBinding = binding;
tpl.dstArrayElement = 0;
tpl.descriptorCount = 1;
tpl.descriptorType = desc_type;
tpl.offset = binding * sizeof(VkDescriptorBufferInfo);
tpl.stride = sizeof(VkDescriptorBufferInfo);
arg_template.push_back(tpl);
}
};
{
auto fit = fmap_.find(func_name);
ICHECK(fit != fmap_.end());
for (DLDataType arg_type : fit->second.arg_types) {
if (arg_type.code == kTVMOpaqueHandle) {
push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
++num_buffer;
} else {
++num_pod;
}
}
}
size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
if (pe->use_ubo) {
// Use UBO instead of push constants
push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
device.AllocateThreadLocalUniformBuffer(nbytes_scalars);
}
{
VkDescriptorSetLayoutCreateInfo descrip_cinfo;
descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
descrip_cinfo.pNext = nullptr;
descrip_cinfo.flags = 0;
if (device.UseImmediate()) {
descrip_cinfo.flags |= VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR;
}
descrip_cinfo.bindingCount = arg_binding.size();
descrip_cinfo.pBindings = arg_binding.data();
VULKAN_CALL(
vkCreateDescriptorSetLayout(device, &descrip_cinfo, nullptr, &(pe->descriptor_set_layout)));
}
if (!device.UseImmediate()) {
VkDescriptorPoolCreateInfo descrip_pool_cinfo;
descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descrip_pool_cinfo.pNext = nullptr;
descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
descrip_pool_cinfo.maxSets = 1;
descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size();
descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data();
VULKAN_CALL(
vkCreateDescriptorPool(device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool)));
VkDescriptorSetAllocateInfo alloc_info;
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.pNext = nullptr;
alloc_info.descriptorPool = pe->descriptor_pool;
alloc_info.descriptorSetCount = 1;
alloc_info.pSetLayouts = &(pe->descriptor_set_layout);
VULKAN_CALL(vkAllocateDescriptorSets(device, &alloc_info, &(pe->descriptor_set)));
}
VkPushConstantRange crange;
crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
crange.offset = 0;
crange.size = sizeof(ArgUnion64) * num_pack_args;
VkPipelineLayoutCreateInfo playout_cinfo;
playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
playout_cinfo.pNext = nullptr;
playout_cinfo.flags = 0;
playout_cinfo.setLayoutCount = 1;
playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
if (0 < nbytes_scalars && !pe->use_ubo) {
playout_cinfo.pushConstantRangeCount = 1;
playout_cinfo.pPushConstantRanges = &crange;
ICHECK_LE(crange.size, device.device_properties.max_push_constants_size)
<< "The Vulkan shader uses " << crange.size
<< " bytes of push constants, but the device only supports "
<< device.device_properties.max_push_constants_size << "bytes. "
<< "Please rebuild the shader using a smaller limit on push constants size "
<< "by passing -max_push_constants_size=N in the Target string, "
<< "or pass -from_device=0 to query all device parameters.";
} else {
playout_cinfo.pushConstantRangeCount = 0;
playout_cinfo.pPushConstantRanges = nullptr;
}
VULKAN_CALL(vkCreatePipelineLayout(device, &playout_cinfo, nullptr, &(pe->pipeline_layout)));
VkComputePipelineCreateInfo pipeline_cinfo;
pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
pipeline_cinfo.pNext = nullptr;
pipeline_cinfo.flags = 0;
pipeline_cinfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
pipeline_cinfo.stage.pNext = nullptr;
pipeline_cinfo.stage.flags = 0;
pipeline_cinfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
pipeline_cinfo.stage.module = pe->shader;
pipeline_cinfo.stage.pName = func_name.c_str();
pipeline_cinfo.stage.pSpecializationInfo = nullptr;
pipeline_cinfo.layout = pe->pipeline_layout;
pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE;
pipeline_cinfo.basePipelineIndex = 0;
VULKAN_CALL(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
&(pe->pipeline)));
if (device.UseImmediate()) {
VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
descrip_template_cinfo.pNext = nullptr;
descrip_template_cinfo.flags = 0;
descrip_template_cinfo.descriptorUpdateEntryCount = arg_template.size();
descrip_template_cinfo.pDescriptorUpdateEntries = arg_template.data();
descrip_template_cinfo.templateType = VK_DESCRIPTOR_UPDATE_TEMPLATE_TYPE_PUSH_DESCRIPTORS_KHR;
descrip_template_cinfo.descriptorSetLayout = pe->descriptor_set_layout;
descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
descrip_template_cinfo.pipelineLayout = pe->pipeline_layout;
descrip_template_cinfo.set = 0;
VULKAN_CALL(device.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR(
device, &descrip_template_cinfo, nullptr, &(pe->descriptor_update_template)));
}
ecache_[device_id][func_name] = pe;
return pe;
}