static void AddDXILMeshShaderOutputStores()

in renderdoc/driver/d3d12/d3d12_postvs.cpp [1101:1958]


// Patches a DXIL mesh shader so that everything it outputs is also written to a raw UAV.
//
// The edit adds a RWByteAddressBuffer UAV (register 0 in register space 'space') to the shader's
// resource metadata, sets the shader flags needed to use it, and then inserts dx.op.rawBufferStore
// calls in front of every dx.op.setMeshOutputCounts, dx.op.storeVertexOutput,
// dx.op.storePrimitiveOutput and dx.op.emitIndices call found in the entry point function. The
// counts passed to setMeshOutputCounts are replaced with zero so the shader no longer produces any
// real output.
//
// Each mesh group writes to its own region of layout.meshletByteSize bytes, laid out as:
//   [0, 32)   output counts. Thread 0 writes the two setMeshOutputCounts arguments at +0 and +4,
//             every other thread writes to +16 and +20 (a dummy location, so that the store can be
//             done without branching).
//   [32, ...) indices, indexCountPerPrim uint32s per primitive, padded to 16 bytes
//   then      per-vertex outputs, vertArrayLength elements of vertStride bytes
//   then      per-primitive outputs, primArrayLength elements of primStride bytes
//
// ampPayloadSize: byte size of the amplification shader payload. Only used if readAmpOffset is
//                 true, to size the mesh shader payload consistently.
// dxbc:           the container holding the DXIL program to edit.
// space:          register space used for the added UAV.
// readAmpOffset:  if true, the x/y dispatch dimensions and the base meshlet index are read from
//                 four extra i32s appended to the end of the payload. If false the constants in
//                 dispatchDim are used and no base index is added.
// dispatchDim:    dispatch dimensions, only [0] and [1] are used here and only when readAmpOffset
//                 is false.
// layout:         filled out with the array lengths, strides and per-signature-element byte
//                 offsets of the data written per meshlet.
// editedBlob:     handed to the ProgramEditor - presumably receives the edited container when the
//                 editor is destroyed (not visible in this function).
static void AddDXILMeshShaderOutputStores(uint32_t ampPayloadSize, const DXBC::DXBCContainer *dxbc,
                                          uint32_t space, bool readAmpOffset,
                                          rdcfixedarray<uint32_t, 3> dispatchDim,
                                          OutDXILMeshletLayout &layout, bytebuf &editedBlob)
{
  using namespace DXIL;

  ProgramEditor editor(dxbc, editedBlob);

  bool isShaderModel6_6OrAbove =
      dxbc->m_Version.Major > 6 || (dxbc->m_Version.Major == 6 && dxbc->m_Version.Minor >= 6);

  const Type *i32 = editor.GetInt32Type();
  const Type *i8 = editor.GetInt8Type();
  const Type *i1 = editor.GetBoolType();
  const Type *voidType = editor.GetVoidType();

  const Type *handleType = editor.CreateNamedStructType(
      "dx.types.Handle", {editor.CreatePointerType(i8, Type::PointerAddrSpace::Default)});

  // exactly one of createHandle or createHandleFromBinding/annotateHandle is declared below,
  // depending on the shader model
  const Function *createHandle = NULL;
  const Function *createHandleFromBinding = NULL;
  const Function *annotateHandle = NULL;

  // reading from a binding uses a different function in SM6.6+
  if(isShaderModel6_6OrAbove)
  {
    const Type *resBindType = editor.CreateNamedStructType("dx.types.ResBind", {i32, i32, i32, i8});
    createHandleFromBinding = editor.DeclareFunction("dx.op.createHandleFromBinding", handleType,
                                                     {i32, resBindType, i32, i1},
                                                     Attribute::NoUnwind | Attribute::ReadNone);

    const Type *resourcePropertiesType =
        editor.CreateNamedStructType("dx.types.ResourceProperties", {i32, i32});
    annotateHandle = editor.DeclareFunction("dx.op.annotateHandle", handleType,
                                            {i32, handleType, resourcePropertiesType},
                                            Attribute::NoUnwind | Attribute::ReadNone);
  }
  else
  {
    createHandle = editor.DeclareFunction("dx.op.createHandle", handleType, {i32, i8, i32, i32, i1},
                                          Attribute::NoUnwind | Attribute::ReadOnly);
  }

  const Function *flattenedThreadIdInGroup = editor.DeclareFunction(
      "dx.op.flattenedThreadIdInGroup.i32", i32, {i32}, Attribute::NoUnwind | Attribute::ReadNone);
  const Function *groupId = editor.DeclareFunction("dx.op.groupId.i32", i32, {i32, i32},
                                                   Attribute::NoUnwind | Attribute::ReadNone);

  // this is only looked up, not declared - it will be NULL if the shader never fetches its payload
  const Function *getMeshPayload = editor.GetFunctionByPrefix("dx.op.getMeshPayload");

  // these are declared so that we have a name to match calls against further down
  const Function *setMeshOutputCounts = editor.DeclareFunction(
      "dx.op.setMeshOutputCounts", voidType, {i32, i32, i32}, Attribute::NoUnwind);
  const Function *emitIndices = editor.DeclareFunction(
      "dx.op.emitIndices", voidType, {i32, i32, i32, i32, i32}, Attribute::NoUnwind);

  // declare the resource, this happens purely in metadata but we need to store the slot
  uint32_t regSlot = 0;
  Metadata *reslist = NULL;
  {
    const Type *rw = editor.CreateNamedStructType("struct.RWByteAddressBuffer", {i32});
    const Type *rwptr = editor.CreatePointerType(rw, Type::PointerAddrSpace::Default);

    Metadata *resources = editor.CreateNamedMetadata("dx.resources");
    if(resources->children.empty())
      resources->children.push_back(editor.CreateMetadata());

    reslist = resources->children[0];

    if(reslist->children.empty())
      reslist->children.resize(4);

    Metadata *uavs = reslist->children[1];
    // if there isn't a UAV list, create an empty one so we can add our own
    if(!uavs)
      uavs = reslist->children[1] = editor.CreateMetadata();

    // our UAV takes the first ID after all existing UAVs
    for(size_t i = 0; i < uavs->children.size(); i++)
    {
      // each UAV child should have a fixed format, [0] is the reg ID and I think this should always
      // be == the index
      const Metadata *uav = uavs->children[i];
      const Constant *slot = cast<Constant>(uav->children[(size_t)ResField::ID]->value);

      if(!slot)
      {
        RDCWARN("Unexpected non-constant slot ID in UAV");
        continue;
      }

      RDCASSERT(slot->getU32() == i);

      uint32_t id = slot->getU32();
      regSlot = RDCMAX(id + 1, regSlot);
    }

    Constant rwundef;
    rwundef.type = rwptr;
    rwundef.setUndef(true);

    // create the new UAV record
    Metadata *uav = editor.CreateMetadata();
    uav->children = {
        editor.CreateConstantMetadata(regSlot),
        editor.CreateConstantMetadata(editor.CreateConstant(rwundef)),
        editor.CreateConstantMetadata(""),
        editor.CreateConstantMetadata(space),
        editor.CreateConstantMetadata(0U),                                   // reg base
        editor.CreateConstantMetadata(1U),                                   // reg count
        editor.CreateConstantMetadata(uint32_t(ResourceKind::RawBuffer)),    // shape
        editor.CreateConstantMetadata(false),                                // globally coherent
        editor.CreateConstantMetadata(false),                                // hidden counter
        editor.CreateConstantMetadata(false),                                // raster order
        NULL,                                                                // UAV tags
    };

    uavs->children.push_back(uav);
  }

  rdcstr entryName;

  // add the entry point tags
  bool hadPayload = false;

  Metadata *outSig = NULL, *primOutSig = NULL;
  {
    Metadata *entryPoints = editor.GetMetadataByName("dx.entryPoints");

    if(!entryPoints)
    {
      RDCERR("Couldn't find entry point list");
      return;
    }

    // TODO select the entry point for multiple entry points? RT only for now
    Metadata *entry = entryPoints->children[0];

    entryName = entry->children[1]->str;

    Metadata *taglist = entry->children[4];
    if(!taglist)
      taglist = entry->children[4] = editor.CreateMetadata();

    Metadata *sigs = entry->children[2];
    outSig = sigs->children[1];
    primOutSig = sigs->children[2];

    // find existing shader flags tag, if there is one, as well as the mesh shader properties. The
    // tag list is a flat list of (tag, data) pairs
    Metadata *shaderFlagsTag = NULL;
    Metadata *shaderFlagsData = NULL;
    Metadata *meshData = NULL;
    size_t flagsIndex = 0;
    for(size_t t = 0; taglist && t < taglist->children.size(); t += 2)
    {
      RDCASSERT(taglist->children[t]->isConstant);
      const uint32_t tag = cast<Constant>(taglist->children[t]->value)->getU32();
      if(tag == (uint32_t)ShaderEntryTag::ShaderFlags)
      {
        shaderFlagsTag = taglist->children[t];
        shaderFlagsData = taglist->children[t + 1];
        flagsIndex = t + 1;
      }
      else if(tag == (uint32_t)ShaderEntryTag::Mesh)
      {
        meshData = taglist->children[t + 1];
      }
    }

    // everything below needs the mesh properties (topology, array lengths, payload size), so
    // there's nothing useful we can do without them
    if(!meshData)
    {
      RDCERR("Couldn't find mesh shader properties in entry point tags");
      return;
    }

    uint32_t shaderFlagsValue =
        shaderFlagsData ? cast<Constant>(shaderFlagsData->value)->getU32() : 0U;

    // raw and structured buffers
    shaderFlagsValue |= 0x10;

    // UAVs on non-PS/CS stages
    shaderFlagsValue |= 0x10000;

    // (re-)create shader flags tag
    Type *i64 = editor.CreateScalarType(Type::Int, 64);
    shaderFlagsData =
        editor.CreateConstantMetadata(editor.CreateConstant(Constant(i64, shaderFlagsValue)));

    // if we didn't have a shader tags entry at all, create the metadata node for the shader flags
    // tag
    if(!shaderFlagsTag)
      shaderFlagsTag = editor.CreateConstantMetadata((uint32_t)ShaderEntryTag::ShaderFlags);

    // if we had a tag already, we can just re-use that tag node and replace the data node.
    // Otherwise we need to add both, and we insert them first. flagsIndex is the index of the data
    // node so it is never 0 when a tag was found
    if(flagsIndex)
    {
      taglist->children[flagsIndex] = shaderFlagsData;
    }
    else
    {
      taglist->children.insert(0, shaderFlagsTag);
      taglist->children.insert(1, shaderFlagsData);
    }

    // set reslist and taglist in case they were null before
    entry->children[3] = reslist;
    entry->children[4] = taglist;

    // patch payload size in mesh tags if we're reading from amplification shader
    if(readAmpOffset)
    {
      uint32_t payloadSize = cast<Constant>(meshData->children[4]->value)->getU32();
      // DXIL payload can't be empty, so if the previous size was non-zero we had one previously
      hadPayload = payloadSize != 0;

      // if the amplification shader declares a payload, but mesh shader doesn't, we need to be sure
      // we match them in size for validation
      if(!hadPayload && ampPayloadSize != 0)
        payloadSize = ampPayloadSize;

      // if the mesh shader did have a payload, these sizes should match!
      RDCASSERTEQUAL(payloadSize, ampPayloadSize);

      // room for the four i32s we append to the payload
      payloadSize += 16;
      meshData->children[4] = editor.CreateConstantMetadata(payloadSize);
      editor.SetMSPayloadSize(payloadSize);
    }

    // if the topology (child [3]) is 1, then it's lines, otherwise triangles
    layout.indexCountPerPrim = cast<Constant>(meshData->children[3]->value)->getU32() == 1 ? 2 : 3;

    layout.vertArrayLength = cast<Constant>(meshData->children[1]->value)->getU32();
    layout.primArrayLength = cast<Constant>(meshData->children[2]->value)->getU32();
  }

  // get the editor to patch PSV0 with our extra UAV
  editor.RegisterUAV(DXILResourceType::ByteAddressUAV, space, 0, 0, ResourceKind::RawBuffer);

  Function *f = editor.GetFunctionByName(entryName);

  if(!f)
  {
    RDCERR("Couldn't find entry point function '%s'", entryName.c_str());
    return;
  }

  // determine the payload struct type, extended with four trailing i32s. hadPayload can only be
  // true when readAmpOffset is set, so this is only non-NULL when we read from the payload
  Type *payloadType = NULL;
  if(hadPayload)
  {
    if(getMeshPayload)
    {
      // if we had a payload and it was loaded, seek the dx.op.getMeshPayload to find its type
      for(size_t i = 0; i < f->instructions.size(); i++)
      {
        const Instruction &inst = *f->instructions[i];

        if(inst.op == Operation::Call && inst.getFuncCall()->name == getMeshPayload->name)
        {
          payloadType = (Type *)inst.type;

          RDCASSERT(payloadType->type == Type::Pointer);
          payloadType = (Type *)payloadType->inner;

          // the existing struct type is extended in-place
          payloadType->members.append({i32, i32, i32, i32});

          break;
        }
      }
    }
    else
    {
      // if we had a payload declared but it wasn't ever fetched, there will be no function or type.
      // We create a synthetic type of the right size then patch it

      rdcarray<const Type *> members;
      for(uint32_t i = 0; i < ampPayloadSize / sizeof(uint32_t); i++)
        members.push_back(i32);

      // unclear if HLSL allows non-4byte aligned types
      RDCASSERT((ampPayloadSize % sizeof(uint32_t)) == 0);

      members.append({i32, i32, i32, i32});

      // no payload before. We get to make up our own!
      payloadType = editor.CreateNamedStructType("struct.payload_t", members);

      const Type *payloadPtrType =
          editor.CreatePointerType(payloadType, Type::PointerAddrSpace::Default);

      getMeshPayload = editor.DeclareFunction("dx.op.getMeshPayload.struct.payload_t", payloadPtrType,
                                              {i32}, Attribute::NoUnwind | Attribute::ReadOnly);
    }
  }
  else if(readAmpOffset)
  {
    // no payload before. We get to make up our own!
    payloadType = editor.CreateNamedStructType("struct.payload_t", {i32, i32, i32, i32});

    const Type *payloadPtrType =
        editor.CreatePointerType(payloadType, Type::PointerAddrSpace::Default);

    getMeshPayload = editor.DeclareFunction("dx.op.getMeshPayload.struct.payload_t", payloadPtrType,
                                            {i32}, Attribute::NoUnwind | Attribute::ReadOnly);
  }

  if(readAmpOffset)
  {
    RDCASSERT(payloadType && payloadType->type == Type::Struct);
  }

  uint32_t byteCounter = 0;

  // sigLocations holds the per-vertex signature elements first, then the per-primitive ones
  layout.sigLocations.resize((outSig ? outSig->children.size() : 0) +
                             (primOutSig ? primOutSig->children.size() : 0));
  size_t firstPrimOutput = (outSig ? outSig->children.size() : 0);

  // offsets calculated in these two loops are relative to the start of one vertex/primitive
  for(size_t i = 0; outSig && i < outSig->children.size(); i++)
  {
    OutDXILSigLocation &loc = layout.sigLocations[i];

    Metadata *sigMeta = outSig->children[i];

    uint32_t semantic = cast<Constant>(sigMeta->children[3]->value)->getU32();

    loc.offset = byteCounter;

    VarType type =
        VarTypeForComponentType((ComponentType)cast<Constant>(sigMeta->children[2]->value)->getU32());

    loc.scalarElemSize = VarTypeByteSize(type);
    loc.rowCount = cast<Constant>(sigMeta->children[6]->value)->getU32();
    loc.colCount = cast<Constant>(sigMeta->children[7]->value)->getU32();

    // move position to the front when storing, if semantic 3 (position, guaranteed to be per-vertex
    // by definition) isn't at index 0, we shuffle up everything we've added so far by 16 bytes and
    // add position here regardless of byte offset.
    if(semantic == 3 && i != 0)
    {
      RDCASSERT(loc.scalarElemSize * loc.rowCount * loc.colCount == sizeof(Vec4f),
                loc.scalarElemSize, loc.rowCount, loc.colCount);

      // shift all previous signatures up
      for(size_t prev = 0; prev < i; prev++)
        layout.sigLocations[prev].offset += sizeof(Vec4f);

      loc.offset = 0;
    }

    byteCounter += loc.scalarElemSize * loc.rowCount * loc.colCount;
  }

  layout.vertStride = AlignUp4(byteCounter);
  byteCounter = 0;

  // per primitive outputs are after output signature outputs
  for(size_t i = 0; primOutSig && i < primOutSig->children.size(); i++)
  {
    OutDXILSigLocation &loc = layout.sigLocations[firstPrimOutput + i];

    Metadata *sigMeta = primOutSig->children[i];

    loc.offset = byteCounter;

    VarType type =
        VarTypeForComponentType((ComponentType)cast<Constant>(sigMeta->children[2]->value)->getU32());

    loc.scalarElemSize = VarTypeByteSize(type);
    loc.rowCount = cast<Constant>(sigMeta->children[6]->value)->getU32();
    loc.colCount = cast<Constant>(sigMeta->children[7]->value)->getU32();

    byteCounter += loc.scalarElemSize * loc.rowCount * loc.colCount;
  }

  layout.primStride = AlignUp4(byteCounter);

  // rebase the offsets to be relative to the start of the meshlet's data
  for(size_t i = 0; i < layout.sigLocations.size(); i++)
  {
    // prim/vert counts
    layout.sigLocations[i].offset += 32;

    // indices
    layout.sigLocations[i].offset +=
        AlignUp16(layout.primArrayLength * layout.indexCountPerPrim * (uint32_t)sizeof(uint32_t));

    // per-primitive data comes after all of the per-vertex data
    if(i >= firstPrimOutput)
      layout.sigLocations[i].offset += layout.vertArrayLength * layout.vertStride;
  }

  // meshlet data begins with real and fake meshlet size (prim/vert count)
  layout.meshletByteSize = 32;
  const uint32_t idxDataOffset = layout.meshletByteSize;

  // then comes indices
  layout.meshletByteSize +=
      (uint32_t)AlignUp16(layout.primArrayLength * layout.indexCountPerPrim * sizeof(uint32_t));

  // after that per-vertex data
  layout.meshletByteSize += layout.vertArrayLength * layout.vertStride;

  // and finally per-primitive data
  layout.meshletByteSize += layout.primArrayLength * layout.primStride;

  // create our handle first thing. prelimInst tracks where the next instruction of our preamble
  // goes, everything is inserted at the start of the function in order
  Constant *annotateConstant = NULL;
  Instruction *handle = NULL;
  size_t prelimInst = 0;
  if(createHandle)
  {
    RDCASSERT(!isShaderModel6_6OrAbove);
    handle = editor.InsertInstruction(
        f, prelimInst++,
        editor.CreateInstruction(createHandle, DXOp::createHandle,
                                 {
                                     // kind = UAV
                                     editor.CreateConstant((uint8_t)HandleKind::UAV),
                                     // ID/slot
                                     editor.CreateConstant(regSlot),
                                     // register
                                     editor.CreateConstant(0U),
                                     // non-uniform
                                     editor.CreateConstant(false),
                                 }));
  }
  else if(createHandleFromBinding)
  {
    RDCASSERT(isShaderModel6_6OrAbove);
    const Type *resBindType = editor.CreateNamedStructType("dx.types.ResBind", {});
    Constant *resBindConstant =
        editor.CreateConstant(resBindType, {
                                               // Lower id bound
                                               editor.CreateConstant(0U),
                                               // Upper id bound
                                               editor.CreateConstant(0U),
                                               // Space ID
                                               editor.CreateConstant(space),
                                               // kind = UAV
                                               editor.CreateConstant((uint8_t)HandleKind::UAV),
                                           });

    Instruction *unannotatedHandle = editor.InsertInstruction(
        f, prelimInst++,
        editor.CreateInstruction(createHandleFromBinding, DXOp::createHandleFromBinding,
                                 {
                                     // resBind
                                     resBindConstant,
                                     // ID/slot
                                     editor.CreateConstant(0U),
                                     // non-uniform
                                     editor.CreateConstant(false),
                                 }));

    annotateConstant = editor.CreateConstant(
        editor.CreateNamedStructType("dx.types.ResourceProperties", {}),
        {
            // IsUav : (1 << 12)
            editor.CreateConstant(uint32_t((1 << 12) | (uint32_t)ResourceKind::RawBuffer)),
            //
            editor.CreateConstant(0U),
        });

    handle = editor.InsertInstruction(f, prelimInst++,
                                      editor.CreateInstruction(annotateHandle, DXOp::annotateHandle,
                                                               {
                                                                   // Resource handle
                                                                   unannotatedHandle,
                                                                   // Resource properties
                                                                   annotateConstant,
                                                               }));
  }

  RDCASSERT(handle);

  // now calculate our offset
  Constant *i32_0 = editor.CreateConstant(0U);
  Constant *i32_1 = editor.CreateConstant(1U);
  Constant *i32_2 = editor.CreateConstant(2U);
  Constant *i32_4 = editor.CreateConstant(4U);

  Instruction *baseOffset = NULL;

  Instruction *groupX = NULL, *groupY = NULL, *groupZ = NULL;

  {
    // get our output location from group ID
    groupX = editor.InsertInstruction(f, prelimInst++,
                                      editor.CreateInstruction(groupId, DXOp::groupId, {i32_0}));
    groupY = editor.InsertInstruction(f, prelimInst++,
                                      editor.CreateInstruction(groupId, DXOp::groupId, {i32_1}));
    groupZ = editor.InsertInstruction(f, prelimInst++,
                                      editor.CreateInstruction(groupId, DXOp::groupId, {i32_2}));
  }

  // get the flat thread ID for comparisons
  Instruction *flatId = editor.InsertInstruction(
      f, prelimInst++,
      editor.CreateInstruction(flattenedThreadIdInGroup, DXOp::flattenedThreadIdInGroup, {}));

  Value *dimX = NULL, *dimY = NULL;
  Instruction *dispatchBaseMeshletIdx = NULL;

  if(readAmpOffset)
  {
    // reading the payload has no dependencies but can only happen once per shader. If there was a
    // load before we search for it and bring it to the front here so we can use it ourselves. The
    // llvm value-referencing will continue to work as normal since the pointer remains the same
    Instruction *payloadLoad = NULL;
    for(size_t i = 0; i < f->instructions.size(); i++)
    {
      const Instruction &inst = *f->instructions[i];
      if(inst.op == Operation::Call && inst.getFuncCall()->name == getMeshPayload->name)
      {
        payloadLoad = editor.InsertInstruction(f, prelimInst++, f->instructions.takeAt(i));
        break;
      }
    }

    // if there wasn't one before (because we added the payload, or it was unused) we can just add our own
    if(!payloadLoad)
      payloadLoad = editor.InsertInstruction(
          f, prelimInst++, editor.CreateInstruction(getMeshPayload, DXOp::getMeshPayload, {}));

    Type *i32ptr = editor.CreatePointerType(i32, Type::PointerAddrSpace::Default);

    // our data is in the last four members of the payload struct

    // .x = x dimension
    Instruction *dimXPtr = editor.InsertInstruction(
        f, prelimInst++,
        editor.CreateInstruction(
            Operation::GetElementPtr, i32ptr,
            {payloadLoad, i32_0, editor.CreateConstant(uint32_t(payloadType->members.size() - 4))}));
    // .y = y dimension
    Instruction *dimYPtr = editor.InsertInstruction(
        f, prelimInst++,
        editor.CreateInstruction(
            Operation::GetElementPtr, i32ptr,
            {payloadLoad, i32_0, editor.CreateConstant(uint32_t(payloadType->members.size() - 3))}));
    // .w = offset for this set of mesh groups
    Instruction *offsetPtr = editor.InsertInstruction(
        f, prelimInst++,
        editor.CreateInstruction(
            Operation::GetElementPtr, i32ptr,
            {payloadLoad, i32_0, editor.CreateConstant(uint32_t(payloadType->members.size() - 1))}));

    Instruction *dimXLoad = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Load, i32, {dimXPtr}));
    dimXLoad->align = 4;
    dimX = dimXLoad;

    Instruction *dimYLoad = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Load, i32, {dimYPtr}));
    dimYLoad->align = 4;
    dimY = dimYLoad;

    dispatchBaseMeshletIdx = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Load, i32, {offsetPtr}));
    dispatchBaseMeshletIdx->align = 4;
  }
  else
  {
    dimX = editor.CreateConstant(dispatchDim[0]);
    dimY = editor.CreateConstant(dispatchDim[1]);
  }

  {
    Instruction *dimXY = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Mul, i32, {dimX, dimY}));

    // linearise to slot based on the number of dispatches
    Instruction *groupYMul = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Mul, i32, {groupY, dimX}));
    Instruction *groupZMul = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Mul, i32, {groupZ, dimXY}));
    Instruction *groupYZAdd = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Add, i32, {groupYMul, groupZMul}));
    Instruction *flatIndex = editor.InsertInstruction(
        f, prelimInst++, editor.CreateInstruction(Operation::Add, i32, {groupX, groupYZAdd}));

    if(dispatchBaseMeshletIdx)
    {
      flatIndex = editor.InsertInstruction(
          f, prelimInst++,
          editor.CreateInstruction(Operation::Add, i32, {flatIndex, dispatchBaseMeshletIdx}));
    }

    // byte offset of this group's meshlet data in the buffer
    baseOffset = editor.InsertInstruction(
        f, prelimInst++,
        editor.CreateInstruction(Operation::Mul, i32,
                                 {flatIndex, editor.CreateConstant(layout.meshletByteSize)}));
  }

  Constant *threadZeroCountOffset = i32_0;
  Constant *threadOtherCountOffset = editor.CreateConstant(uint32_t(16U));

  Constant *indexStride =
      editor.CreateConstant(uint32_t(layout.indexCountPerPrim * sizeof(uint32_t)));

  // Inserts a raw buffer store in front of a dx.op.storeVertexOutput or dx.op.storePrimitiveOutput
  // call, which both have the same arguments. instIdx is the index of storeInst in the function
  // and is advanced past every instruction inserted, so that it still refers to storeInst
  // afterwards. loc is the location of the signature element being written and elemStride is the
  // byte stride between consecutive vertices or primitives.
  auto addOutputStore = [&](size_t &instIdx, const Instruction &storeInst,
                            OutDXILSigLocation &loc, uint32_t elemStride) {
    Value *row = storeInst.args[2];
    Value *col = storeInst.args[3];
    Value *value = storeInst.args[4];
    Value *elemIndex = storeInst.args[5];

    Instruction *colByteOffset = NULL;

    // col is i8, but DXIL doesn't support i8 as values (sigh...). So if that value is a constant
    // (currently must be true) then we re-create it as u32. We handle the case where it's not a
    // constant in future perhaps
    Constant *colConst = cast<Constant>(col);
    if(colConst)
    {
      colByteOffset = editor.InsertInstruction(
          f, instIdx++,
          editor.CreateInstruction(Operation::Mul, i32,
                                   {editor.CreateConstant(colConst->getU32()),
                                    editor.CreateConstant(loc.scalarElemSize)}));
    }
    else
    {
      colByteOffset = editor.InsertInstruction(
          f, instIdx++,
          editor.CreateInstruction(Operation::Mul, i8,
                                   {col, editor.CreateConstant(uint8_t(loc.scalarElemSize))}));

      // extend the byte offset we just calculated, not the column index itself, otherwise the
      // multiply above is unused and we'd offset by columns instead of bytes
      colByteOffset = editor.InsertInstruction(
          f, instIdx++, editor.CreateInstruction(Operation::ZExt, i32, {colByteOffset}));
    }

    Instruction *elemByteOffset = colByteOffset;

    // rows are only taken into account for elements that have more than one
    if(loc.rowCount > 1)
    {
      uint32_t rowStride = loc.scalarElemSize * loc.colCount;

      Instruction *rowOffset = editor.InsertInstruction(
          f, instIdx++,
          editor.CreateInstruction(Operation::Mul, i32, {row, editor.CreateConstant(rowStride)}));

      elemByteOffset = editor.InsertInstruction(
          f, instIdx++, editor.CreateInstruction(Operation::Add, i32, {rowOffset, colByteOffset}));
    }

    Instruction *indexedOffset = editor.InsertInstruction(
        f, instIdx++,
        editor.CreateInstruction(Operation::Mul, i32,
                                 {elemIndex, editor.CreateConstant(elemStride)}));

    // base + sig indexed offset + vertex/primitive indexed offset + elem offset

    Instruction *writeOffset = editor.InsertInstruction(
        f, instIdx++,
        editor.CreateInstruction(Operation::Add, i32,
                                 {baseOffset, editor.CreateConstant(loc.offset)}));

    writeOffset = editor.InsertInstruction(
        f, instIdx++, editor.CreateInstruction(Operation::Add, i32, {writeOffset, indexedOffset}));

    writeOffset = editor.InsertInstruction(
        f, instIdx++, editor.CreateInstruction(Operation::Add, i32, {writeOffset, elemByteOffset}));

    // the store function is overloaded on the type of the value being written
    rdcstr suffix = makeBufferLoadStoreSuffix(value->type);

    const Function *rawBufferStore = editor.DeclareFunction(
        "dx.op.rawBufferStore." + suffix, voidType,
        {i32, handleType, i32, i32, value->type, value->type, value->type, value->type, i8, i32},
        Attribute::NoUnwind);

    // only the first component is written (mask 0x1), the others are undef
    editor.InsertInstruction(
        f, instIdx++,
        editor.CreateInstruction(
            rawBufferStore, DXOp::rawBufferStore,
            {handle, writeOffset, editor.CreateUndef(i32), value, editor.CreateUndef(value->type),
             editor.CreateUndef(value->type), editor.CreateUndef(value->type),
             editor.CreateConstant((uint8_t)0x1), i32_4}));
  };

  // walk the function and add stores in front of each output call. Every insertion increments i,
  // so f->instructions[i] is still the original call once a branch has finished and the loop
  // increment then moves past it
  for(size_t i = 0; i < f->instructions.size(); i++)
  {
    const Instruction &inst = *f->instructions[i];
    if(inst.op == Operation::Call && inst.getFuncCall()->name == setMeshOutputCounts->name)
    {
      Instruction *threadIsZero = editor.InsertInstruction(
          f, i++, editor.CreateInstruction(Operation::IEqual, i1, {flatId, i32_0}));

      // to avoid messing up phi nodes in the application where this is called, we do this
      // branchless by either writing to offset 0 (for threadIndex == 0) or offset 16 (for
      // threadIndex > 0). Then we can ignore the second one
      Instruction *byteOffset = editor.InsertInstruction(
          f, i++,
          editor.CreateInstruction(Operation::Select, i32,
                                   {threadZeroCountOffset, threadOtherCountOffset, threadIsZero}));

      Instruction *writeOffset = editor.InsertInstruction(
          f, i++, editor.CreateInstruction(Operation::Add, i32, {baseOffset, byteOffset}));

      const Function *rawBufferStore = editor.DeclareFunction(
          "dx.op.rawBufferStore.i32", voidType,
          {i32, handleType, i32, i32, i32, i32, i32, i32, i8, i32}, Attribute::NoUnwind);

      // first count argument
      editor.InsertInstruction(
          f, i++,
          editor.CreateInstruction(
              rawBufferStore, DXOp::rawBufferStore,
              {handle, writeOffset, editor.CreateUndef(i32), inst.args[1], editor.CreateUndef(i32),
               editor.CreateUndef(i32), editor.CreateUndef(i32),
               editor.CreateConstant((uint8_t)0x1), i32_4}));

      writeOffset = editor.InsertInstruction(
          f, i++, editor.CreateInstruction(Operation::Add, i32, {writeOffset, i32_4}));

      // second count argument, 4 bytes later
      editor.InsertInstruction(
          f, i++,
          editor.CreateInstruction(
              rawBufferStore, DXOp::rawBufferStore,
              {handle, writeOffset, editor.CreateUndef(i32), inst.args[2], editor.CreateUndef(i32),
               editor.CreateUndef(i32), editor.CreateUndef(i32),
               editor.CreateConstant((uint8_t)0x1), i32_4}));

      // disable the actual output
      f->instructions[i]->args[1] = i32_0;
      f->instructions[i]->args[2] = i32_0;
    }
    else if(inst.op == Operation::Call &&
            inst.getFuncCall()->name.beginsWith("dx.op.storeVertexOutput"))
    {
      uint32_t sigId = cast<Constant>(inst.args[1])->getU32();

      addOutputStore(i, inst, layout.sigLocations[sigId], layout.vertStride);
    }
    else if(inst.op == Operation::Call &&
            inst.getFuncCall()->name.beginsWith("dx.op.storePrimitiveOutput"))
    {
      uint32_t sigId = cast<Constant>(inst.args[1])->getU32();

      // per-primitive signature elements are stored after the per-vertex ones
      addOutputStore(i, inst, layout.sigLocations[firstPrimOutput + sigId], layout.primStride);
    }
    else if(inst.op == Operation::Call && inst.getFuncCall()->name == emitIndices->name)
    {
      // primitive index in args[1], so multiply to get location of indices
      Instruction *byteOffset = editor.InsertInstruction(
          f, i++, editor.CreateInstruction(Operation::Mul, i32, {inst.args[1], indexStride}));

      Instruction *writeOffset = editor.InsertInstruction(
          f, i++,
          editor.CreateInstruction(Operation::Add, i32,
                                   {baseOffset, editor.CreateConstant(idxDataOffset)}));

      writeOffset = editor.InsertInstruction(
          f, i++, editor.CreateInstruction(Operation::Add, i32, {writeOffset, byteOffset}));

      const Function *rawBufferStore = editor.DeclareFunction(
          "dx.op.rawBufferStore.i32", voidType,
          {i32, handleType, i32, i32, i32, i32, i32, i32, i8, i32}, Attribute::NoUnwind);

      // idx0
      editor.InsertInstruction(
          f, i++,
          editor.CreateInstruction(
              rawBufferStore, DXOp::rawBufferStore,
              {handle, writeOffset, editor.CreateUndef(i32), inst.args[2], editor.CreateUndef(i32),
               editor.CreateUndef(i32), editor.CreateUndef(i32),
               editor.CreateConstant((uint8_t)0x1), i32_4}));

      // idx1
      writeOffset = editor.InsertInstruction(
          f, i++, editor.CreateInstruction(Operation::Add, i32, {writeOffset, i32_4}));

      editor.InsertInstruction(
          f, i++,
          editor.CreateInstruction(
              rawBufferStore, DXOp::rawBufferStore,
              {handle, writeOffset, editor.CreateUndef(i32), inst.args[3], editor.CreateUndef(i32),
               editor.CreateUndef(i32), editor.CreateUndef(i32),
               editor.CreateConstant((uint8_t)0x1), i32_4}));

      // the third index is only stored for triangles, not for lines
      if(layout.indexCountPerPrim > 2)
      {
        // idx2
        writeOffset = editor.InsertInstruction(
            f, i++, editor.CreateInstruction(Operation::Add, i32, {writeOffset, i32_4}));

        editor.InsertInstruction(
            f, i++,
            editor.CreateInstruction(
                rawBufferStore, DXOp::rawBufferStore,
                {handle, writeOffset, editor.CreateUndef(i32), inst.args[4],
                 editor.CreateUndef(i32), editor.CreateUndef(i32), editor.CreateUndef(i32),
                 editor.CreateConstant((uint8_t)0x1), i32_4}));
      }
    }
  }
}