in src/runtime/vulkan/vulkan_wrapped_func.cc [45:185]
void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
const ArgUnion64* pack_args) const {
int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID();
auto& device = VulkanDeviceAPI::Global()->device(device_id);
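// Lazily fetch and cache the compute pipeline for this kernel on the active device;
// later calls on the same device reuse the cached entry.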
if (!scache_[device_id]) {
scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_);
}
const auto& pipeline = scache_[device_id];
ThreadWorkLoad wl = launch_param_config_.Extract(args);
std::vector<VkDescriptorBufferInfo> descriptor_buffers;
descriptor_buffers.resize(num_buffer_args_);
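// Expose each buffer argument to the shader as a full-range descriptor (offset 0, VK_WHOLE_SIZE).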
for (size_t i = 0; i < num_buffer_args_; ++i) {
void* buf = args[static_cast<int>(i)].cast<void*>();
VkDescriptorBufferInfo binfo;
binfo.buffer = static_cast<VulkanBuffer*>(buf)->buffer;
binfo.offset = 0;
binfo.range = VK_WHOLE_SIZE;
descriptor_buffers[i] = binfo;
}
const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
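// Shaders that take their scalar arguments through a uniform buffer (rather than push
// constants) get one extra descriptor pointing at the thread-local UBO.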
if (pipeline->use_ubo) {
auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars);
VkDescriptorBufferInfo binfo;
binfo.buffer = ubo.vk_buf.buffer;
binfo.offset = 0;
binfo.range = VK_WHOLE_SIZE;
descriptor_buffers.push_back(binfo);
}
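// Immediate path: descriptors can be pushed straight into the command buffer
// (VK_KHR_push_descriptor), so the launch is recorded on the calling thread's stream
// without any deferred setup.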
if (device.UseImmediate()) {
// Can safely capture by reference as this lambda is immediately executed on the calling thread.
device.ThreadLocalStream().Launch([&](VulkanStreamState* state) {
vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE);
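// Push all buffer (and UBO) descriptors in a single call through the descriptor update template.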
device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
descriptor_buffers.data());
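// Scalar arguments travel either through the thread-local UBO or through push constants.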
if (pipeline->use_ubo) {
auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars);
memcpy(ubo.host_addr, pack_args, nbytes_scalars);
} else if (num_pack_args_ > 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
pack_args);
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
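// Barrier: make this dispatch's shader reads/writes visible to later transfers and dispatches.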
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1, &barrier_info, 0, nullptr, 0, nullptr);
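// Optionally insert a debug label carrying the kernel name so the dispatch can be
// identified in tooling such as validation layers or frame debuggers.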
if (device.UseDebugUtilsLabel()) {
VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
nullptr,
func_name_.c_str(),
{0.0f, 0.0f, 0.0f, 0.0f}};
device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
device.Queue(), &dispatch_label);
}
});
return;
}
// Otherwise, the more expensive deferred path.
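// pack_args is only guaranteed valid for the duration of this call, so copy the scalar
// arguments before handing them to the deferred closures below.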
std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
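// deferred_initializer fills in the pipeline's descriptor set on the host
// (vkUpdateDescriptorSets is not a command-buffer operation).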
const auto& deferred_initializer = [&device, pipeline, descriptor_buffers]() {
std::vector<VkWriteDescriptorSet> write_descriptor_sets;
write_descriptor_sets.resize(descriptor_buffers.size());
for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write_descriptor_sets[i].pNext = nullptr;
write_descriptor_sets[i].dstSet = pipeline->descriptor_set;
write_descriptor_sets[i].dstBinding = i;
write_descriptor_sets[i].dstArrayElement = 0;
write_descriptor_sets[i].descriptorCount = 1;
write_descriptor_sets[i].pImageInfo = nullptr;
write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
write_descriptor_sets[i].pTexelBufferView = nullptr;
if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) {
// The last binding is for the UBO that carries the scalar arguments.
write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
} else {
write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
}
}
vkUpdateDescriptorSets(device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0,
nullptr);
};
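// deferred_kernel records the bind/dispatch/barrier sequence into the stream's command
// buffer; it captures everything by value so it remains valid if it runs later.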
const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars,
device_id](VulkanStreamState* state) {
auto& device = VulkanDeviceAPI::Global()->device(device_id);
vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
nullptr);
if (pipeline->use_ubo) {
auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars);
memcpy(ubo.host_addr, pack_args_storage.data(), nbytes_scalars);
} else if (num_pack_args_ > 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
0, pack_args_storage.size() * sizeof(ArgUnion64),
pack_args_storage.data());
}
vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
VkMemoryBarrier barrier_info;
barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
barrier_info.pNext = nullptr;
barrier_info.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT;
barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT |
VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT);
vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1, &barrier_info, 0, nullptr, 0, nullptr);
};
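// The token records the descriptor set and buffers this launch uses so the stream can
// track dependencies between deferred launches.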
VulkanStreamToken deferred_token;
deferred_token.descriptor_set_ = pipeline->descriptor_set;
deferred_token.buffers_.resize(descriptor_buffers.size());
for (size_t i = 0; i < descriptor_buffers.size(); ++i) {
deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
}
device.ThreadLocalStream().LaunchDeferred(deferred_initializer, deferred_kernel, deferred_token);
if (device.UseDebugUtilsLabel()) {
VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
nullptr,
func_name_.c_str(),
{0.0f, 0.0f, 0.0f, 0.0f}};
device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
device.Queue(), &dispatch_label);
}
}