void VulkanReplay::FetchMeshOut()

in renderdoc/driver/vulkan/vk_postvs.cpp [2823:4067]


void VulkanReplay::FetchMeshOut(uint32_t eventId, VulkanRenderState &state)
{
  VulkanCreationInfo &creationInfo = m_pDriver->m_CreationInfo;

  ActionDescription action = *m_pDriver->GetAction(eventId);

  // for indirect dispatches, fetch up to date dispatch sizes in case they're non-deterministic
  if(action.flags & ActionFlags::Indirect)
  {
    uint32_t chunkIdx = action.events.back().chunkIndex;

    const SDFile *file = GetStructuredFile();

    // it doesn't matter if this is an indirect sub command or an inlined 1-draw non-indirect count,
    // either way the 'offset' is valid - either from the start, or updated for this particular draw
    // when we originally patched (and fortunately that part doesn't change).
    if(chunkIdx < file->chunks.size())
    {
      const SDChunk *chunk = file->chunks[chunkIdx];

      ResourceId buf = chunk->FindChild("buffer")->AsResourceId();
      uint64_t offs = chunk->FindChild("offset")->AsUInt64();

      buf = GetResourceManager()->GetLiveID(buf);

      bytebuf dispatchArgs;
      GetBufferData(buf, offs, sizeof(VkDrawMeshTasksIndirectCommandEXT), dispatchArgs);

      if(dispatchArgs.size() >= sizeof(VkDrawMeshTasksIndirectCommandEXT))
      {
        VkDrawMeshTasksIndirectCommandEXT *meshArgs =
            (VkDrawMeshTasksIndirectCommandEXT *)dispatchArgs.data();

        action.dispatchDimension[0] = meshArgs->groupCountX;
        action.dispatchDimension[1] = meshArgs->groupCountY;
        action.dispatchDimension[2] = meshArgs->groupCountZ;
      }
    }
  }

  uint32_t totalNumMeshlets =
      action.dispatchDimension[0] * action.dispatchDimension[1] * action.dispatchDimension[2];

  const VulkanCreationInfo::Pipeline &pipeInfo = creationInfo.m_Pipeline[state.graphics.pipeline];

  const VulkanCreationInfo::Pipeline::Shader &meshShad = pipeInfo.shaders[7];

  const VulkanCreationInfo::ShaderModule &meshInfo = creationInfo.m_ShaderModule[meshShad.module];
  ShaderReflection *meshrefl = meshShad.refl;

  VulkanPostVSData &ret = m_PostVS.Data[eventId];

  // set defaults so that we don't try to fetch this output again if something goes wrong and the
  // same event is selected again
  {
    ret.meshout.buf = VK_NULL_HANDLE;
    ret.meshout.bufmem = VK_NULL_HANDLE;
    ret.meshout.instStride = 0;
    ret.meshout.vertStride = 0;
    ret.meshout.numViews = 1;
    ret.meshout.nearPlane = 0.0f;
    ret.meshout.farPlane = 0.0f;
    ret.meshout.useIndices = false;
    ret.meshout.hasPosOut = false;
    ret.meshout.flipY = false;
    ret.meshout.idxbuf = VK_NULL_HANDLE;
    ret.meshout.idxbufmem = VK_NULL_HANDLE;

    ret.meshout.topo = meshShad.refl->outputTopology;

    ret.taskout = ret.meshout;
  }

  if(meshShad.patchData->invalidTaskPayload)
  {
    ret.meshout.status = ret.taskout.status = "Invalid task payload, likely generated by dxc bug";
    return;
  }

  if(meshrefl->outputSignature.empty())
  {
    ret.meshout.status = "mesh shader has no declared outputs";
    return;
  }

  if(!m_pDriver->GetExtensions(NULL).ext_KHR_buffer_device_address ||
     Vulkan_Debug_DisableBufferDeviceAddress())
  {
    ret.meshout.status =
        "KHR_buffer_device_address extension not available, can't fetch mesh shader output";
    return;
  }

  if(!m_pDriver->GetExtensions(NULL).ext_EXT_scalar_block_layout)
  {
    ret.meshout.status =
        "EXT_scalar_block_layout extension not available, can't fetch mesh shader output";
    return;
  }

  if(!m_pDriver->GetDeviceEnabledFeatures().shaderInt64)
  {
    ret.meshout.status = "int64 device feature not available, can't fetch mesh shader output";
    return;
  }

  VkGraphicsPipelineCreateInfo pipeCreateInfo;

  // get pipeline create info
  m_pDriver->GetShaderCache()->MakeGraphicsPipelineInfo(pipeCreateInfo, state.graphics.pipeline);

  uint32_t bufSpecConstant = 0;

  bytebuf meshSpecData;
  rdcarray<VkSpecializationMapEntry> meshSpecEntries;
  bytebuf taskSpecData;
  rdcarray<VkSpecializationMapEntry> taskSpecEntries;

  // copy over specialization info
  for(uint32_t s = 0; s < pipeCreateInfo.stageCount; s++)
  {
    if(pipeCreateInfo.pStages[s].stage == VK_SHADER_STAGE_MESH_BIT_EXT &&
       pipeCreateInfo.pStages[s].pSpecializationInfo)
    {
      meshSpecData.append((const byte *)pipeCreateInfo.pStages[s].pSpecializationInfo->pData,
                          pipeCreateInfo.pStages[s].pSpecializationInfo->dataSize);
      meshSpecEntries.append(pipeCreateInfo.pStages[s].pSpecializationInfo->pMapEntries,
                             pipeCreateInfo.pStages[s].pSpecializationInfo->mapEntryCount);
    }
    else if(pipeCreateInfo.pStages[s].stage == VK_SHADER_STAGE_TASK_BIT_EXT &&
            pipeCreateInfo.pStages[s].pSpecializationInfo)
    {
      taskSpecData.append((const byte *)pipeCreateInfo.pStages[s].pSpecializationInfo->pData,
                          pipeCreateInfo.pStages[s].pSpecializationInfo->dataSize);
      taskSpecEntries.append(pipeCreateInfo.pStages[s].pSpecializationInfo->pMapEntries,
                             pipeCreateInfo.pStages[s].pSpecializationInfo->mapEntryCount);
    }
  }

  // don't overlap with existing pipeline constants
  for(const VkSpecializationMapEntry &specConst : meshSpecEntries)
    bufSpecConstant = RDCMAX(bufSpecConstant, specConst.constantID + 1);
  for(const VkSpecializationMapEntry &specConst : taskSpecEntries)
    bufSpecConstant = RDCMAX(bufSpecConstant, specConst.constantID + 1);

  // forcibly set input assembly state to NULL, as AMD's driver still processes this and may crash
  // if the contents are not sensible. Since this does nothing otherwise we don't make it conditional
  pipeCreateInfo.pInputAssemblyState = NULL;

  // use the load RP if an RP is specified
  if(pipeCreateInfo.renderPass != VK_NULL_HANDLE)
  {
    pipeCreateInfo.renderPass =
        creationInfo.m_RenderPass[GetResID(pipeCreateInfo.renderPass)].loadRPs[pipeCreateInfo.subpass];
    pipeCreateInfo.subpass = 0;
  }

  const VkMemoryAllocateFlagsInfo memFlags = {
      VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO,
      NULL,
      VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT,
  };

  // we go through the driver for all these creations since they need to be properly
  // registered in order to be put in the partial replay state
  VkResult vkr = VK_SUCCESS;
  VkDevice dev = m_Device;

  VkBuffer taskBuffer = VK_NULL_HANDLE, readbackTaskBuffer = VK_NULL_HANDLE;
  VkDeviceMemory taskMem = VK_NULL_HANDLE, readbackTaskMem = VK_NULL_HANDLE;

  VkDeviceSize taskBufSize = 0;
  uint32_t taskPayloadSize = 0;
  VkDeviceAddress taskDataAddress = 0;

  rdcarray<VulkanPostVSData::InstData> taskDispatchSizes;
  const uint32_t totalNumTaskGroups = totalNumMeshlets;

  // if we have a task shader, we fetch both outputs together as a necessary component.
  // In order to properly pre-allocate the mesh output buffer we need to run the task shader, cache
  // all of its payloads and mesh dispatches per-group, then run a dispatch for each task group that
  // passes along the cached payloads. With a CPU sync point this ensures that any non-deterministic
  // behaviour or ordering will remain consistent between both passes and still allow for the
  // allocation after we know the average case. This is necessary because with task expansion the
  // worst case buffer size could be massive
  if(pipeInfo.shaders[(size_t)ShaderStage::Task].refl)
  {
    const VulkanCreationInfo::Pipeline::Shader &taskShad =
        pipeInfo.shaders[(size_t)ShaderStage::Task];

    if(taskShad.patchData->invalidTaskPayload)
    {
      ret.meshout.status = ret.taskout.status = "Invalid task payload, likely generated by dxc bug";
      return;
    }

    const VulkanCreationInfo::ShaderModule &taskInfo = creationInfo.m_ShaderModule[taskShad.module];

    rdcarray<uint32_t> taskSpirv = taskInfo.spirv.GetSPIRV();

    if(!Vulkan_Debug_PostVSDumpDirPath().empty())
      FileIO::WriteAll(Vulkan_Debug_PostVSDumpDirPath() + "/debug_postts_before.spv", taskSpirv);

    AddTaskShaderPayloadStores(taskShad.specialization, meshShad.entryPoint, bufSpecConstant + 1,
                               taskSpirv, taskPayloadSize);

    if(!Vulkan_Debug_PostVSDumpDirPath().empty())
      FileIO::WriteAll(Vulkan_Debug_PostVSDumpDirPath() + "/debug_postts_after.spv", taskSpirv);

    {
      // now that we know the stride, create buffer of sufficient size for the worst case (maximum
      // generation) of the meshlets
      VkBufferCreateInfo bufInfo = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};

      // we add an extra vec4u so that when feeding from this buffer we can load the oversized
      // payload, read "out of bounds" into that padding with the extra uint offset, and then fix
      // the uint offset with a composite insert
      taskBufSize = bufInfo.size =
          (taskPayloadSize + sizeof(Vec4u)) * totalNumTaskGroups + sizeof(Vec4u);

      bufInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufInfo.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT;
      bufInfo.usage |= VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT;

      vkr = m_pDriver->vkCreateBuffer(dev, &bufInfo, NULL, &taskBuffer);
      CheckVkResult(vkr);

      bufInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;

      vkr = m_pDriver->vkCreateBuffer(dev, &bufInfo, NULL, &readbackTaskBuffer);
      CheckVkResult(vkr);

      VkMemoryRequirements mrq = {0};
      m_pDriver->vkGetBufferMemoryRequirements(dev, taskBuffer, &mrq);

      VkMemoryAllocateInfo allocInfo = {
          VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
          NULL,
          mrq.size,
          m_pDriver->GetGPULocalMemoryIndex(mrq.memoryTypeBits),
      };

      allocInfo.pNext = &memFlags;

      vkr = m_pDriver->vkAllocateMemory(dev, &allocInfo, NULL, &taskMem);

      if(vkr == VK_ERROR_OUT_OF_DEVICE_MEMORY || vkr == VK_ERROR_OUT_OF_HOST_MEMORY)
      {
        m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
        m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);

        RDCWARN("Failed to allocate %llu bytes for output", mrq.size);
        ret.meshout.status = ret.taskout.status =
            StringFormat::Fmt("Failed to allocate %llu bytes", mrq.size);
        return;
      }

      CheckVkResult(vkr);

      vkr = m_pDriver->vkBindBufferMemory(dev, taskBuffer, taskMem, 0);
      CheckVkResult(vkr);

      m_pDriver->vkGetBufferMemoryRequirements(dev, readbackTaskBuffer, &mrq);

      allocInfo.pNext = NULL;
      allocInfo.memoryTypeIndex = m_pDriver->GetReadbackMemoryIndex(mrq.memoryTypeBits);

      vkr = m_pDriver->vkAllocateMemory(dev, &allocInfo, NULL, &readbackTaskMem);

      if(vkr == VK_ERROR_OUT_OF_DEVICE_MEMORY || vkr == VK_ERROR_OUT_OF_HOST_MEMORY)
      {
        m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
        m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
        m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);

        RDCWARN("Failed to allocate %llu bytes for readback", mrq.size);
        ret.meshout.status = ret.taskout.status =
            StringFormat::Fmt("Failed to allocate %llu bytes", mrq.size);
        return;
      }

      CheckVkResult(vkr);

      vkr = m_pDriver->vkBindBufferMemory(dev, readbackTaskBuffer, readbackTaskMem, 0);
      CheckVkResult(vkr);

      // register address as specialisation constant

      // ensure we're 64-bit aligned first
      taskSpecData.resize(AlignUp(taskSpecData.size(), (size_t)8));

      VkBufferDeviceAddressInfo getAddressInfo = {
          VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
          NULL,
          taskBuffer,
      };

      taskDataAddress = m_pDriver->vkGetBufferDeviceAddress(dev, &getAddressInfo);

      VkSpecializationMapEntry entry;
      entry.offset = (uint32_t)taskSpecData.size();
      entry.constantID = bufSpecConstant + 1;
      entry.size = sizeof(uint64_t);
      taskSpecEntries.push_back(entry);
      taskSpecData.append((const byte *)&taskDataAddress, sizeof(uint64_t));
    }

    VkSpecializationInfo taskSpecInfo = {};
    taskSpecInfo.dataSize = taskSpecData.size();
    taskSpecInfo.pData = taskSpecData.data();
    taskSpecInfo.mapEntryCount = (uint32_t)taskSpecEntries.size();
    taskSpecInfo.pMapEntries = taskSpecEntries.data();

    // create mesh shader with modified code
    VkShaderModuleCreateInfo moduleCreateInfo = {
        VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
        NULL,
        0,
        taskSpirv.size() * sizeof(uint32_t),
        taskSpirv.data(),
    };

    VkShaderModule taskModule;
    vkr = m_pDriver->vkCreateShaderModule(dev, &moduleCreateInfo, NULL, &taskModule);
    CheckVkResult(vkr);

    for(uint32_t s = 0; s < pipeCreateInfo.stageCount; s++)
    {
      if(pipeCreateInfo.pStages[s].stage == VK_SHADER_STAGE_TASK_BIT_EXT)
      {
        VkPipelineShaderStageCreateInfo &taskStage =
            (VkPipelineShaderStageCreateInfo &)pipeCreateInfo.pStages[s];
        taskStage.module = taskModule;
        taskStage.pSpecializationInfo = &taskSpecInfo;
      }
    }

    // create new pipeline
    VkPipeline taskPipe;
    vkr = m_pDriver->vkCreateGraphicsPipelines(m_Device, VK_NULL_HANDLE, 1, &pipeCreateInfo, NULL,
                                               &taskPipe);

    // delete shader/shader module
    m_pDriver->vkDestroyShaderModule(dev, taskModule, NULL);

    if(vkr != VK_SUCCESS)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackTaskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);

      ret.meshout.status = ret.taskout.status =
          StringFormat::Fmt("Failed to create patched mesh shader pipeline: %s", ToStr(vkr).c_str());
      RDCERR("%s", ret.meshout.status.c_str());
      return;
    }

    // make copy of state to draw from
    VulkanRenderState modifiedstate = state;

    // bind created pipeline to partial replay state
    modifiedstate.graphics.pipeline = GetResID(taskPipe);

    VkCommandBuffer cmd = m_pDriver->GetNextCmd();

    if(cmd == VK_NULL_HANDLE)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackTaskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);

      m_pDriver->vkDestroyPipeline(dev, taskPipe, NULL);
      return;
    }

    VkCommandBufferBeginInfo beginInfo = {VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, NULL,
                                          VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT};

    vkr = ObjDisp(dev)->BeginCommandBuffer(Unwrap(cmd), &beginInfo);
    CheckVkResult(vkr);

    // fill destination buffer with 0s to ensure unwritten vertices have sane data
    ObjDisp(dev)->CmdFillBuffer(Unwrap(cmd), Unwrap(taskBuffer), 0, taskBufSize, 0);

    VkBufferMemoryBarrier taskbufbarrier = {
        VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER,
        NULL,
        VK_ACCESS_TRANSFER_WRITE_BIT,
        VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT,
        VK_QUEUE_FAMILY_IGNORED,
        VK_QUEUE_FAMILY_IGNORED,
    };

    taskbufbarrier.buffer = Unwrap(taskBuffer);
    taskbufbarrier.size = taskBufSize;

    // wait for the above fill to finish.
    DoPipelineBarrier(cmd, 1, &taskbufbarrier);

    modifiedstate.subpassContents = VK_SUBPASS_CONTENTS_INLINE;
    modifiedstate.dynamicRendering.flags &= ~VK_RENDERING_CONTENTS_SECONDARY_COMMAND_BUFFERS_BIT;

    // do single draw
    modifiedstate.BeginRenderPassAndApplyState(m_pDriver, cmd, VulkanRenderState::BindGraphics,
                                               false);

    m_pDriver->ReplayDraw(cmd, action);

    modifiedstate.EndRenderPass(cmd);

    // wait for task output writing to finish
    taskbufbarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_SHADER_WRITE_BIT;
    taskbufbarrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT;

    DoPipelineBarrier(cmd, 1, &taskbufbarrier);

    VkBufferCopy bufcopy = {
        0,
        0,
        taskBufSize,
    };

    // copy to readback buffer
    ObjDisp(dev)->CmdCopyBuffer(Unwrap(cmd), Unwrap(taskBuffer), Unwrap(readbackTaskBuffer), 1,
                                &bufcopy);

    taskbufbarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
    taskbufbarrier.dstAccessMask = VK_ACCESS_HOST_READ_BIT;
    taskbufbarrier.buffer = Unwrap(readbackTaskBuffer);

    // wait for copy to finish
    DoPipelineBarrier(cmd, 1, &taskbufbarrier);

    vkr = ObjDisp(dev)->EndCommandBuffer(Unwrap(cmd));
    CheckVkResult(vkr);

    // submit & flush so that we don't have to keep pipeline around for a while
    m_pDriver->SubmitCmds();
    m_pDriver->FlushQ();

    // delete pipeline
    m_pDriver->vkDestroyPipeline(dev, taskPipe, NULL);

    // readback task data
    const byte *taskData = NULL;
    vkr = m_pDriver->vkMapMemory(m_Device, readbackTaskMem, 0, VK_WHOLE_SIZE, 0, (void **)&taskData);
    CheckVkResult(vkr);
    if(vkr != VK_SUCCESS || !taskData)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackTaskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);

      if(!taskData)
      {
        RDCERR("Manually reporting failed memory map");
        CheckVkResult(VK_ERROR_MEMORY_MAP_FAILED);
      }
      ret.meshout.status = ret.taskout.status = "Couldn't read back task output data from GPU";
      return;
    }

    VkMappedMemoryRange range = {
        VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE, NULL, readbackTaskMem, 0, VK_WHOLE_SIZE,
    };

    vkr = m_pDriver->vkInvalidateMappedMemoryRanges(m_Device, 1, &range);
    CheckVkResult(vkr);

    totalNumMeshlets = 0;
    const byte *taskDataBegin = taskData;

    cmd = m_pDriver->GetNextCmd();

    if(cmd == VK_NULL_HANDLE)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackTaskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);
      return;
    }

    vkr = ObjDisp(dev)->BeginCommandBuffer(Unwrap(cmd), &beginInfo);
    CheckVkResult(vkr);

    for(uint32_t taskGroup = 0; taskGroup < totalNumTaskGroups; taskGroup++)
    {
      Vec4u meshDispatchSize = *(Vec4u *)taskData;
      RDCASSERT(meshDispatchSize.y <= 0xffff);
      RDCASSERT(meshDispatchSize.z <= 0xffff);

      // while we're going, we record writes into the real buffer with the cumulative sizes. This
      // should in theory be better than updating it via a buffer copy since the count should be
      // much smaller than the payload
      ObjDisp(dev)->CmdUpdateBuffer(Unwrap(cmd), Unwrap(taskBuffer),
                                    taskData - taskDataBegin + offsetof(Vec4u, w), 4,
                                    &totalNumMeshlets);

      totalNumMeshlets += meshDispatchSize.x * meshDispatchSize.y * meshDispatchSize.z;

      VulkanPostVSData::InstData i;
      i.taskDispatchSizeX = meshDispatchSize.x;
      i.taskDispatchSizeYZ.y = meshDispatchSize.y & 0xffff;
      i.taskDispatchSizeYZ.z = meshDispatchSize.z & 0xffff;
      taskDispatchSizes.push_back(i);

      taskData += sizeof(Vec4u) + taskPayloadSize;
    }

    m_pDriver->vkUnmapMemory(m_Device, readbackTaskMem);

    taskbufbarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
    taskbufbarrier.dstAccessMask = VK_ACCESS_SHADER_READ_BIT;
    taskbufbarrier.buffer = Unwrap(taskBuffer);

    // wait for copy to finish
    DoPipelineBarrier(cmd, 1, &taskbufbarrier);

    vkr = ObjDisp(dev)->EndCommandBuffer(Unwrap(cmd));
    CheckVkResult(vkr);
  }

  // clean up temporary memories
  m_pDriver->vkDestroyBuffer(m_Device, readbackTaskBuffer, NULL);
  m_pDriver->vkFreeMemory(m_Device, readbackTaskMem, NULL);

  VkBuffer meshBuffer = VK_NULL_HANDLE, readbackBuffer = VK_NULL_HANDLE;
  VkDeviceMemory meshMem = VK_NULL_HANDLE, readbackMem = VK_NULL_HANDLE;

  VkDeviceSize bufSize = 0;

  uint32_t numViews = 1;

  if(state.dynamicRendering.active)
  {
    numViews = RDCMAX(numViews, Log2Ceil(state.dynamicRendering.viewMask + 1));
  }
  else
  {
    const VulkanCreationInfo::RenderPass &rp = creationInfo.m_RenderPass[state.GetRenderPass()];

    if(state.subpass < rp.subpasses.size())
    {
      numViews = RDCMAX(numViews, (uint32_t)rp.subpasses[state.subpass].multiviews.size());
    }
    else
    {
      RDCERR("Subpass is out of bounds to renderpass creation info");
    }
  }

  rdcarray<uint32_t> modSpirv = meshInfo.spirv.GetSPIRV();

  if(!Vulkan_Debug_PostVSDumpDirPath().empty())
    FileIO::WriteAll(Vulkan_Debug_PostVSDumpDirPath() + "/debug_postms_before.spv", modSpirv);

  OutMeshletLayout layout;

  AddMeshShaderOutputStores(*meshrefl, meshShad.specialization, *meshShad.patchData,
                            meshShad.entryPoint, bufSpecConstant, modSpirv, taskDataAddress != 0,
                            layout);

  if(!Vulkan_Debug_PostVSDumpDirPath().empty())
    FileIO::WriteAll(Vulkan_Debug_PostVSDumpDirPath() + "/debug_postms_after.spv", modSpirv);

  if(totalNumMeshlets > 0)
  {
    // now that we know the stride, create buffer of sufficient size for the worst case (maximum
    // generation) of the meshlets
    VkBufferCreateInfo bufInfo = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};

    bufSize = bufInfo.size = layout.meshletByteSize * totalNumMeshlets;

    bufInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
    bufInfo.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    bufInfo.usage |= VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT;

    vkr = m_pDriver->vkCreateBuffer(dev, &bufInfo, NULL, &meshBuffer);
    CheckVkResult(vkr);

    bufInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;

    vkr = m_pDriver->vkCreateBuffer(dev, &bufInfo, NULL, &readbackBuffer);
    CheckVkResult(vkr);

    VkMemoryRequirements mrq = {0};
    m_pDriver->vkGetBufferMemoryRequirements(dev, meshBuffer, &mrq);

    VkMemoryAllocateInfo allocInfo = {
        VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
        NULL,
        mrq.size,
        m_pDriver->GetGPULocalMemoryIndex(mrq.memoryTypeBits),
    };

    allocInfo.pNext = &memFlags;

    vkr = m_pDriver->vkAllocateMemory(dev, &allocInfo, NULL, &meshMem);

    if(vkr == VK_ERROR_OUT_OF_DEVICE_MEMORY || vkr == VK_ERROR_OUT_OF_HOST_MEMORY)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);

      RDCWARN("Failed to allocate %llu bytes for output", mrq.size);
      ret.meshout.status = StringFormat::Fmt("Failed to allocate %llu bytes", mrq.size);
      return;
    }

    CheckVkResult(vkr);

    vkr = m_pDriver->vkBindBufferMemory(dev, meshBuffer, meshMem, 0);
    CheckVkResult(vkr);

    m_pDriver->vkGetBufferMemoryRequirements(dev, readbackBuffer, &mrq);

    allocInfo.pNext = NULL;
    allocInfo.memoryTypeIndex = m_pDriver->GetReadbackMemoryIndex(mrq.memoryTypeBits);

    vkr = m_pDriver->vkAllocateMemory(dev, &allocInfo, NULL, &readbackMem);

    if(vkr == VK_ERROR_OUT_OF_DEVICE_MEMORY || vkr == VK_ERROR_OUT_OF_HOST_MEMORY)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);

      RDCWARN("Failed to allocate %llu bytes for readback", mrq.size);
      ret.meshout.status = StringFormat::Fmt("Failed to allocate %llu bytes", mrq.size);
      return;
    }

    CheckVkResult(vkr);

    vkr = m_pDriver->vkBindBufferMemory(dev, readbackBuffer, readbackMem, 0);
    CheckVkResult(vkr);

    // register address as specialisation constant

    // ensure we're 64-bit aligned first
    meshSpecData.resize(AlignUp(meshSpecData.size(), (size_t)8));

    VkBufferDeviceAddressInfo getAddressInfo = {
        VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO,
        NULL,
        meshBuffer,
    };

    VkDeviceAddress address = m_pDriver->vkGetBufferDeviceAddress(dev, &getAddressInfo);

    VkSpecializationMapEntry entry;
    entry.offset = (uint32_t)meshSpecData.size();
    entry.constantID = bufSpecConstant;
    entry.size = sizeof(uint64_t);
    meshSpecEntries.push_back(entry);
    meshSpecData.append((const byte *)&address, sizeof(uint64_t));
  }

  VkSpecializationInfo meshSpecInfo = {};
  meshSpecInfo.dataSize = meshSpecData.size();
  meshSpecInfo.pData = meshSpecData.data();
  meshSpecInfo.mapEntryCount = (uint32_t)meshSpecEntries.size();
  meshSpecInfo.pMapEntries = meshSpecEntries.data();

  VkSpecializationInfo taskSpecInfo = {};
  taskSpecInfo.dataSize = taskSpecData.size();
  taskSpecInfo.pData = taskSpecData.data();
  taskSpecInfo.mapEntryCount = (uint32_t)taskSpecEntries.size();
  taskSpecInfo.pMapEntries = taskSpecEntries.data();

  // create mesh shader with modified code
  VkShaderModuleCreateInfo moduleCreateInfo = {
      VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, NULL,         0,
      modSpirv.size() * sizeof(uint32_t),          &modSpirv[0],
  };

  VkShaderModule module, taskFeedModule = VK_NULL_HANDLE;
  vkr = m_pDriver->vkCreateShaderModule(dev, &moduleCreateInfo, NULL, &module);
  CheckVkResult(vkr);

  if(taskDataAddress != 0)
  {
    const VulkanCreationInfo::Pipeline::Shader &taskShad =
        pipeInfo.shaders[(size_t)ShaderStage::Task];

    const VulkanCreationInfo::ShaderModule &taskInfo = creationInfo.m_ShaderModule[taskShad.module];

    modSpirv = taskInfo.spirv.GetSPIRV();

    ConvertToFixedTaskFeeder(taskShad.specialization, taskShad.entryPoint, bufSpecConstant + 1,
                             taskPayloadSize, modSpirv);

    if(!Vulkan_Debug_PostVSDumpDirPath().empty())
      FileIO::WriteAll(Vulkan_Debug_PostVSDumpDirPath() + "/debug_postts_feeder.spv", modSpirv);

    moduleCreateInfo.pCode = modSpirv.data();
    moduleCreateInfo.codeSize = modSpirv.byteSize();

    vkr = m_pDriver->vkCreateShaderModule(dev, &moduleCreateInfo, NULL, &taskFeedModule);
    CheckVkResult(vkr);
  }

  for(uint32_t s = 0; s < pipeCreateInfo.stageCount; s++)
  {
    if(pipeCreateInfo.pStages[s].stage == VK_SHADER_STAGE_MESH_BIT_EXT)
    {
      VkPipelineShaderStageCreateInfo &meshStage =
          (VkPipelineShaderStageCreateInfo &)pipeCreateInfo.pStages[s];
      meshStage.module = module;
      meshStage.pSpecializationInfo = &meshSpecInfo;
    }
    else if(pipeCreateInfo.pStages[s].stage == VK_SHADER_STAGE_TASK_BIT_EXT)
    {
      VkPipelineShaderStageCreateInfo &taskStage =
          (VkPipelineShaderStageCreateInfo &)pipeCreateInfo.pStages[s];
      taskStage.module = taskFeedModule;
      taskStage.pSpecializationInfo = &taskSpecInfo;
    }
  }

  // create new pipeline
  VkPipeline pipe;
  vkr = m_pDriver->vkCreateGraphicsPipelines(m_Device, VK_NULL_HANDLE, 1, &pipeCreateInfo, NULL,
                                             &pipe);

  // delete shader/shader module
  m_pDriver->vkDestroyShaderModule(dev, module, NULL);

  // delete shader/shader module
  m_pDriver->vkDestroyShaderModule(dev, taskFeedModule, NULL);

  if(vkr != VK_SUCCESS)
  {
    m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
    m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
    m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
    m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
    m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);
    m_pDriver->vkFreeMemory(m_Device, readbackMem, NULL);

    ret.meshout.status =
        StringFormat::Fmt("Failed to create patched mesh shader pipeline: %s", ToStr(vkr).c_str());
    RDCERR("%s", ret.meshout.status.c_str());
    return;
  }

  // make copy of state to draw from
  VulkanRenderState modifiedstate = state;

  // bind created pipeline to partial replay state
  modifiedstate.graphics.pipeline = GetResID(pipe);

  if(totalNumMeshlets > 0)
  {
    VkCommandBuffer cmd = m_pDriver->GetNextCmd();

    if(cmd == VK_NULL_HANDLE)
    {
      m_pDriver->vkDestroyPipeline(dev, pipe, NULL);
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackMem, NULL);

      return;
    }

    VkCommandBufferBeginInfo beginInfo = {VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, NULL,
                                          VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT};

    vkr = ObjDisp(dev)->BeginCommandBuffer(Unwrap(cmd), &beginInfo);
    CheckVkResult(vkr);

    // fill destination buffer with 0s to ensure unwritten vertices have sane data
    ObjDisp(dev)->CmdFillBuffer(Unwrap(cmd), Unwrap(meshBuffer), 0, bufSize, 0);

    VkBufferMemoryBarrier meshbufbarrier = {
        VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER,
        NULL,
        VK_ACCESS_TRANSFER_WRITE_BIT,
        VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT,
        VK_QUEUE_FAMILY_IGNORED,
        VK_QUEUE_FAMILY_IGNORED,
    };

    meshbufbarrier.buffer = Unwrap(meshBuffer);
    meshbufbarrier.size = bufSize;

    // wait for the above fill to finish.
    DoPipelineBarrier(cmd, 1, &meshbufbarrier);

    modifiedstate.subpassContents = VK_SUBPASS_CONTENTS_INLINE;
    modifiedstate.dynamicRendering.flags &= ~VK_RENDERING_CONTENTS_SECONDARY_COMMAND_BUFFERS_BIT;

    // do single draw
    modifiedstate.BeginRenderPassAndApplyState(m_pDriver, cmd, VulkanRenderState::BindGraphics,
                                               false);

    m_pDriver->ReplayDraw(cmd, action);

    modifiedstate.EndRenderPass(cmd);

    // wait for mesh output writing to finish
    meshbufbarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT | VK_ACCESS_SHADER_WRITE_BIT;
    meshbufbarrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT;

    DoPipelineBarrier(cmd, 1, &meshbufbarrier);

    VkBufferCopy bufcopy = {
        0,
        0,
        bufSize,
    };

    // copy to readback buffer
    ObjDisp(dev)->CmdCopyBuffer(Unwrap(cmd), Unwrap(meshBuffer), Unwrap(readbackBuffer), 1, &bufcopy);

    meshbufbarrier.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT;
    meshbufbarrier.dstAccessMask = VK_ACCESS_HOST_READ_BIT;
    meshbufbarrier.buffer = Unwrap(readbackBuffer);

    // wait for copy to finish
    DoPipelineBarrier(cmd, 1, &meshbufbarrier);

    vkr = ObjDisp(dev)->EndCommandBuffer(Unwrap(cmd));
    CheckVkResult(vkr);

    // submit & flush so that we don't have to keep pipeline around for a while
    m_pDriver->SubmitCmds();
    m_pDriver->FlushQ();
  }

  // delete pipeline
  m_pDriver->vkDestroyPipeline(dev, pipe, NULL);

  rdcarray<VulkanPostVSData::InstData> meshletOffsets;

  uint32_t baseIndex = 0;

  rdcarray<uint32_t> rebasedIndices;
  bytebuf compactedVertices;

  float nearp = 0.1f;
  float farp = 100.0f;

  uint32_t totalVerts = 0, totalPrims = 0;
  uint32_t totalVertStride = 0;
  uint32_t totalPrimStride = 0;

  if(totalNumMeshlets > 0)
  {
    // readback mesh data
    const byte *meshletData = NULL;
    vkr = m_pDriver->vkMapMemory(m_Device, readbackMem, 0, VK_WHOLE_SIZE, 0, (void **)&meshletData);
    CheckVkResult(vkr);
    if(vkr != VK_SUCCESS || !meshletData)
    {
      if(!meshletData)
      {
        RDCERR("Manually reporting failed memory map");
        CheckVkResult(VK_ERROR_MEMORY_MAP_FAILED);
      }
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackMem, NULL);
      ret.meshout.status = "Couldn't read back mesh output data from GPU";
      return;
    }

    VkMappedMemoryRange range = {
        VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE, NULL, readbackMem, 0, VK_WHOLE_SIZE,
    };

    vkr = m_pDriver->vkInvalidateMappedMemoryRanges(m_Device, 1, &range);
    CheckVkResult(vkr);

    // do a super quick sum of the number of verts and prims
    for(uint32_t m = 0; m < totalNumMeshlets; m++)
    {
      Vec4u *counts = (Vec4u *)(meshletData + m * layout.meshletByteSize);
      totalVerts += counts->x;
      totalPrims += counts->y;
    }

    if(totalPrims == 0)
    {
      m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
      m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, readbackMem, NULL);
      ret.meshout.status = "No mesh output data generated by GPU";
      return;
    }

    for(size_t o = 0; o < layout.sigLocations.size(); o++)
    {
      if(meshrefl->outputSignature[o].systemValue == ShaderBuiltin::OutputIndices)
        continue;

      const SigParameter &sig = meshrefl->outputSignature[o];
      const uint32_t byteSize = VarTypeByteSize(sig.varType) * sig.compCount;

      if(meshrefl->outputSignature[o].perPrimitiveRate)
        totalPrimStride += byteSize;
      else
        totalVertStride += byteSize;
    }

    rdcarray<uint32_t> sigOffsets;
    sigOffsets.resize(meshrefl->outputSignature.size());

    {
      uint32_t vertOffset = 0;
      uint32_t primOffset = 0;
      for(size_t o = 0; o < meshrefl->outputSignature.size(); o++)
      {
        const SigParameter &sig = meshrefl->outputSignature[o];
        const uint32_t byteSize = VarTypeByteSize(sig.varType) * sig.compCount;

        if(sig.systemValue == ShaderBuiltin::OutputIndices)
          continue;

        // move position to the front when compacting
        if(sig.systemValue == ShaderBuiltin::Position)
        {
          RDCASSERT(!sig.perPrimitiveRate);
          sigOffsets[o] = 0;
          vertOffset += byteSize;

          // shift all previous signatures up
          for(size_t prev = 0; prev < o; prev++)
            sigOffsets[prev] += byteSize;

          continue;
        }

        if(sig.perPrimitiveRate)
        {
          sigOffsets[o] = primOffset;
          primOffset += byteSize;
        }
        else
        {
          sigOffsets[o] = vertOffset;
          vertOffset += byteSize;
        }
      }

      RDCASSERT(vertOffset == totalVertStride);
      RDCASSERT(primOffset == totalPrimStride);
    }

    // now we reorganise and compact the data.
    // Some arrays will need to be decomposed (any non-struct outputs will be SoA and we want full
    // AoS). We also rebase indices so they can be used as a contiguous index buffer

    rebasedIndices.reserve(totalPrims * layout.indexCountPerPrim);
    compactedVertices.resize(totalVerts * totalVertStride + totalPrims * totalPrimStride);

    byte *vertData = compactedVertices.begin();
    byte *primData = vertData + totalVerts * totalVertStride;

    // calculate near/far as we're going
    bool found = false;
    Vec4f pos0;

    for(uint32_t meshlet = 0; meshlet < totalNumMeshlets; meshlet++)
    {
      Vec4u *counts = (Vec4u *)meshletData;
      const uint32_t numVerts = counts->x;
      const uint32_t numPrims = counts->y;

      const uint32_t padding = counts->z;
      const uint32_t padding2 = counts->w;
      RDCASSERTEQUAL(padding, 0);
      RDCASSERTEQUAL(padding2, 0);

      if(numVerts > layout.vertArrayLength)
      {
        RDCERR("Meshlet returned invalid vertex count %u with declared max %u", numVerts,
               layout.vertArrayLength);
        ret.meshout.status = "Got corrupted mesh output data from GPU";
      }

      if(numPrims > layout.primArrayLength)
      {
        RDCERR("Meshlet returned invalid primitive count %u with declared max %u", numPrims,
               layout.primArrayLength);
        ret.meshout.status = "Got corrupted mesh output data from GPU";
      }

      if(!ret.meshout.status.empty())
      {
        m_pDriver->vkFreeMemory(m_Device, taskMem, NULL);
        m_pDriver->vkDestroyBuffer(m_Device, taskBuffer, NULL);
        m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
        m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
        m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);
        m_pDriver->vkFreeMemory(m_Device, readbackMem, NULL);
        return;
      }

      VulkanPostVSData::InstData meshletOffsetData;
      meshletOffsetData.numIndices = numPrims * layout.indexCountPerPrim;
      meshletOffsetData.numVerts = numVerts;
      meshletOffsets.push_back(meshletOffsetData);

      uint32_t *indices = (uint32_t *)(counts + 2);

      for(uint32_t p = 0; p < numPrims; p++)
      {
        for(uint32_t idx = 0; idx < layout.indexCountPerPrim; idx++)
          rebasedIndices.push_back(indices[p * layout.indexCountPerPrim + idx] + baseIndex);
      }

      for(size_t o = 0; o < meshrefl->outputSignature.size(); o++)
      {
        const SigParameter &sig = meshrefl->outputSignature[o];
        const uint32_t byteSize = VarTypeByteSize(sig.varType) * sig.compCount;

        if(sig.systemValue == ShaderBuiltin::OutputIndices)
          continue;

        if(meshrefl->outputSignature[o].perPrimitiveRate)
        {
          for(uint32_t p = 0; p < numPrims; p++)
          {
            memcpy(primData + sigOffsets[o] + totalPrimStride * p,
                   meshletData + layout.sigLocations[o].offset + layout.sigLocations[o].stride * p,
                   byteSize);
          }
        }
        else
        {
          for(uint32_t v = 0; v < numVerts; v++)
          {
            byte *dst = vertData + sigOffsets[o] + totalVertStride * v;

            memcpy(dst,
                   meshletData + layout.sigLocations[o].offset + layout.sigLocations[o].stride * v,
                   byteSize);

            if(!found && sig.systemValue == ShaderBuiltin::Position)
            {
              Vec4f pos = *(Vec4f *)dst;

              if(v == 0)
              {
                pos0 = pos;
              }
              else
              {
                DeriveNearFar(pos, pos0, nearp, farp, found);
              }
            }
          }
        }
      }

      baseIndex += numVerts;
      meshletData += layout.meshletByteSize;
      vertData += totalVertStride * numVerts;
      primData += totalPrimStride * numPrims;
    }

    RDCASSERT(vertData == compactedVertices.begin() + totalVerts * totalVertStride);
    RDCASSERT(primData == compactedVertices.end());

    // if we didn't find any near/far plane, all z's and w's were identical.
    // If the z is positive and w greater for the first element then we detect this projection as
    // reversed z with infinite far plane
    if(!found && pos0.z > 0.0f && pos0.w > pos0.z)
    {
      nearp = pos0.z;
      farp = FLT_MAX;
    }

    m_pDriver->vkUnmapMemory(m_Device, readbackMem);
  }

  // clean up temporary memories
  m_pDriver->vkDestroyBuffer(m_Device, readbackBuffer, NULL);
  m_pDriver->vkFreeMemory(m_Device, readbackMem, NULL);

  // clean up temporary memories
  m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
  m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);

  // fill out m_PostVS.Data
  if(layout.indexCountPerPrim == 3)
    ret.meshout.topo = Topology::TriangleList;
  else if(layout.indexCountPerPrim == 2)
    ret.meshout.topo = Topology::LineList;
  else if(layout.indexCountPerPrim == 1)
    ret.meshout.topo = Topology::PointList;

  if(totalNumMeshlets > 0)
  {
    VkBufferCreateInfo bufInfo = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};

    bufInfo.size = AlignUp16(compactedVertices.byteSize()) + rebasedIndices.byteSize();

    bufInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
    bufInfo.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT;
    bufInfo.usage |= VK_BUFFER_USAGE_VERTEX_BUFFER_BIT;
    bufInfo.usage |= VK_BUFFER_USAGE_INDEX_BUFFER_BIT;

    vkr = m_pDriver->vkCreateBuffer(dev, &bufInfo, NULL, &meshBuffer);
    CheckVkResult(vkr);

    VkMemoryRequirements mrq = {0};
    m_pDriver->vkGetBufferMemoryRequirements(dev, meshBuffer, &mrq);

    VkMemoryAllocateInfo allocInfo = {
        VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
        NULL,
        mrq.size,
        m_pDriver->GetUploadMemoryIndex(mrq.memoryTypeBits),
    };

    vkr = m_pDriver->vkAllocateMemory(dev, &allocInfo, NULL, &meshMem);

    if(vkr == VK_ERROR_OUT_OF_DEVICE_MEMORY || vkr == VK_ERROR_OUT_OF_HOST_MEMORY)
    {
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      RDCWARN("Failed to allocate %llu bytes for output", mrq.size);
      ret.meshout.status = StringFormat::Fmt("Failed to allocate %llu bytes", mrq.size);
      return;
    }

    CheckVkResult(vkr);

    vkr = m_pDriver->vkBindBufferMemory(dev, meshBuffer, meshMem, 0);
    CheckVkResult(vkr);

    byte *uploadData = NULL;
    vkr = m_pDriver->vkMapMemory(m_Device, meshMem, 0, VK_WHOLE_SIZE, 0, (void **)&uploadData);
    CheckVkResult(vkr);
    if(vkr != VK_SUCCESS || !uploadData)
    {
      m_pDriver->vkDestroyBuffer(m_Device, meshBuffer, NULL);
      m_pDriver->vkFreeMemory(m_Device, meshMem, NULL);
      if(!uploadData)
      {
        RDCERR("Manually reporting failed memory map");
        CheckVkResult(VK_ERROR_MEMORY_MAP_FAILED);
      }
      ret.meshout.status = "Couldn't upload mesh output data to GPU";
      return;
    }

    memcpy(uploadData, compactedVertices.data(), compactedVertices.byteSize());
    memcpy(uploadData + AlignUp16(compactedVertices.byteSize()), rebasedIndices.data(),
           rebasedIndices.byteSize());

    VkMappedMemoryRange range = {
        VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE, NULL, meshMem, 0, VK_WHOLE_SIZE,
    };

    vkr = m_pDriver->vkFlushMappedMemoryRanges(m_Device, 1, &range);
    CheckVkResult(vkr);

    m_pDriver->vkUnmapMemory(m_Device, meshMem);
  }

  ret.taskout.buf = taskBuffer;
  ret.taskout.bufmem = taskMem;

  if(!pipeInfo.shaders[6].refl)
    ret.taskout.status = "No task shader bound";

  ret.taskout.baseVertex = 0;

  // TODO handle multiple views
  ret.taskout.numViews = 1;

  ret.taskout.dispatchSize = action.dispatchDimension;

  ret.taskout.vertStride = taskPayloadSize + sizeof(Vec4u);
  ret.taskout.nearPlane = 0.0f;
  ret.taskout.farPlane = 1.0f;

  ret.taskout.primStride = 0;
  ret.taskout.primOffset = 0;

  ret.taskout.useIndices = false;
  ret.taskout.numVerts = totalNumTaskGroups;
  ret.taskout.instData = taskDispatchSizes;

  ret.taskout.instStride = 0;

  ret.taskout.idxbuf = VK_NULL_HANDLE;
  ret.taskout.idxOffset = 0;
  ret.taskout.idxbufmem = VK_NULL_HANDLE;
  ret.taskout.idxFmt = VK_INDEX_TYPE_UINT32;

  ret.taskout.hasPosOut = false;
  ret.taskout.flipY = state.views.empty() ? false : state.views[0].height < 0.0f;

  ret.meshout.buf = meshBuffer;
  ret.meshout.bufmem = meshMem;

  ret.meshout.baseVertex = 0;

  // TODO handle multiple views
  ret.meshout.numViews = 1;

  ret.meshout.dispatchSize = action.dispatchDimension;

  ret.meshout.vertStride = totalVertStride;
  ret.meshout.nearPlane = nearp;
  ret.meshout.farPlane = farp;

  ret.meshout.primStride = totalPrimStride;
  ret.meshout.primOffset = totalVertStride * totalVerts;

  ret.meshout.useIndices = true;
  ret.meshout.numVerts = totalPrims * layout.indexCountPerPrim;
  ret.meshout.instData = meshletOffsets;

  ret.meshout.instStride = 0;

  ret.meshout.idxbuf = meshBuffer;
  ret.meshout.idxOffset = AlignUp16(compactedVertices.byteSize());
  ret.meshout.idxbufmem = VK_NULL_HANDLE;
  ret.meshout.idxFmt = VK_INDEX_TYPE_UINT32;

  ret.meshout.hasPosOut = true;
  ret.meshout.flipY = state.views.empty() ? false : state.views[0].height < 0.0f;
}