private createShadeInternal()

in web/src/webgpu.ts [565:768]


  /**
   * Build a callable that submits one compute shader to the WebGPU device.
   *
   * The bind group layout is derived from `finfo.arg_types`: every "handle"
   * argument becomes a storage buffer binding (writable only when the matching
   * `paramWriteAccess` entry is truthy), and all int/uint/float arguments are
   * packed into a single uniform buffer bound after the storage buffers.
   *
   * @param finfo Function metadata: entry point name, argument types and
   *   launch parameter tags ("blockIdx.*", "threadIdx.*", "paramWriteAccess:[...]").
   * @param code Shader source handed to `createShaderModule`.
   * @param asyncMode When true the pipeline is created with
   *   `createComputePipelineAsync` and a Promise of the submit function is
   *   returned; otherwise the pipeline is created synchronously.
   * @returns The submit function, or a Promise resolving to it in async mode.
   * @throws Error on an unrecognized launch tag or unsupported argument type.
   */
  private createShadeInternal(
    finfo: FunctionInfo,
    code: string,
    asyncMode: boolean
  ): Function | Promise<Function> {
    // Maps the i-th launch argument to a slot in workDim:
    // 0..2 are blockIdx.{x,y,z}, 3..5 are threadIdx.{x,y,z}.
    const dispatchToDim: Array<number> = [];
    let paramWriteAccess: Array<number> = [];

    for (let i = 0; i < finfo.launch_param_tags.length; ++i) {
      const tag: string = finfo.launch_param_tags[i];
      if (tag.startsWith("blockIdx.")) {
        // The trailing character selects the axis: x -> 0, y -> 1, z -> 2.
        const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0));
        assert(target >= 0 && target < 3);
        dispatchToDim.push(target);
      } else if (tag.startsWith("threadIdx.")) {
        const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0));
        assert(target >= 0 && target < 3);
        dispatchToDim.push(target + 3);
      } else if (tag.startsWith("paramWriteAccess:")) {
        // 17 == "paramWriteAccess:".length; the remainder is a JSON array.
        paramWriteAccess = JSON.parse(tag.substring(17));
      } else {
        throw new Error("Cannot handle thread_axis " + tag);
      }
    }

    const layoutEntries: Array<GPUBindGroupLayoutEntry> = [];
    // Positions (within finfo.arg_types / the call arguments) of buffer and POD arguments.
    const bufferArgIndices: Array<number> = [];
    const podArgIndices: Array<number> = [];

    for (let i = 0; i < finfo.arg_types.length; ++i) {
      const dtype = finfo.arg_types[i];
      if (dtype == "handle") {
        layoutEntries.push({
          binding: bufferArgIndices.length,
          visibility: GPUShaderStage.COMPUTE,
          buffer: {
            type: paramWriteAccess[bufferArgIndices.length] ? "storage" : "read-only-storage"
          }
        });
        bufferArgIndices.push(i);
      } else if (dtype.startsWith("int") || dtype.startsWith("uint") || dtype.startsWith("float")) {
        podArgIndices.push(i);
      } else {
        throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader");
      }
    }

    assert(paramWriteAccess.length == bufferArgIndices.length);
    // POD arguments are passed at the end, in one uniform buffer bound after all storage buffers.
    layoutEntries.push({
      binding: bufferArgIndices.length,
      visibility: GPUShaderStage.COMPUTE,
      buffer: {
        type: "uniform"
      }
    });

    const bindGroupLayout = this.device.createBindGroupLayout({
      entries: layoutEntries
    });
    const pipelineLayout = this.device.createPipelineLayout({
      bindGroupLayouts: [bindGroupLayout]
    });

    // Wraps a compiled pipeline into the function that encodes and submits one dispatch.
    const createShaderFunc = (pipeline: GPUComputePipeline): Function => {
      // Expected args: one value per entry of finfo.arg_types, followed by one
      // launch extent per entry of dispatchToDim.
      const submitShader = (...args: Array<GPUPointer | number>): void => {
        // Debug aid: once the submit limit is reached, count the call but skip the work.
        if (this.debugShaderSubmitLimit != -1 &&
          this.shaderSubmitCounter >= this.debugShaderSubmitLimit) {
          this.shaderSubmitCounter += 1;
          return;
        }

        const commandEncoder = this.device.createCommandEncoder();
        const compute = commandEncoder.beginComputePass();
        compute.setPipeline(pipeline);
        const bindGroupEntries: Array<GPUBindGroupEntry> = [];
        const numBufferOrPodArgs = bufferArgIndices.length + podArgIndices.length;

        assert(args.length == numBufferOrPodArgs + dispatchToDim.length);

        // Slots 0..2 are the grid (blockIdx) extents, 3..5 the threadIdx extents.
        // Only slots 0..2 are handed to dispatchWorkgroups below.
        const workDim: Array<number> = [1, 1, 1, 1, 1, 1];
        for (let i = 0; i < dispatchToDim.length; ++i) {
          workDim[dispatchToDim[i]] = args[numBufferOrPodArgs + i];
        }

        // get around 65535 restriction of blockIdx.x
        if (workDim[2] != 1) {
          throw new Error("WebGPU: blockIdx.z is reserved for internal use");
        }
        // Original (unpacked) blockIdx.x extent; forwarded to the shader via the uniform buffer.
        const packDimX = workDim[0];
        // Spread things out into blockIdx.z: halve x (rounding up) and double z
        // until x fits below 1 << 16.
        if (workDim[0] >= (1 << 16)) {
          let wl_x = workDim[0];
          let wl_z = workDim[2];

          while (wl_x >= (1 << 16)) {
            if (wl_x % 2 == 0) {
              wl_x = wl_x / 2;
            } else {
              // pad up
              wl_x = (wl_x + 1) / 2;
            }
            wl_z *= 2;
          }
          workDim[0] = wl_x;
          workDim[2] = wl_z;
          // Padding may over-provision workgroups but must never under-provision.
          assert(wl_x * wl_z >= packDimX);
        }

        for (let i = 0; i < bufferArgIndices.length; ++i) {
          bindGroupEntries.push({
            binding: i,
            resource: {
              buffer: this.gpuBufferFromPtr(args[bufferArgIndices[i]])
            }
          });
        }

        // Push the POD buffer: one 32-bit word per POD argument plus one trailing word.
        const sizeOfI32 = 4;
        const podArgBuffer = this.getPodArgsBuffer((podArgIndices.length + 1) * sizeOfI32);
        // Three typed views over the same bytes so each argument is stored with its own encoding.
        const i32View = new Int32Array(podArgIndices.length + 1);
        const u32View = new Uint32Array(i32View.buffer);
        const f32View = new Float32Array(i32View.buffer);

        for (let i = 0; i < podArgIndices.length; ++i) {
          const value = args[podArgIndices[i]];
          const dtype = finfo.arg_types[podArgIndices[i]];
          if (dtype.startsWith("int")) {
            i32View[i] = value;
          } else if (dtype.startsWith("uint")) {
            u32View[i] = value;
          } else if (dtype.startsWith("float")) {
            f32View[i] = value;
          } else {
            throw new Error("Unknown pod dtype " + dtype);
          }
        }
        // Always append the original blockIdx.x extent as the last word
        // (presumably so the shader can undo the x/z packing above).
        u32View[podArgIndices.length] = packDimX;
        this.device.queue.writeBuffer(podArgBuffer, 0, i32View.buffer);

        bindGroupEntries.push({
          binding: bufferArgIndices.length,
          resource: {
            buffer: podArgBuffer,
            size: i32View.buffer.byteLength
          }
        });

        compute.setBindGroup(0, this.device.createBindGroup({
          layout: bindGroupLayout,
          entries: bindGroupEntries
        }));

        compute.dispatchWorkgroups(workDim[0], workDim[1], workDim[2]);
        compute.end();
        const command = commandEncoder.finish();
        this.device.queue.submit([command]);

        if (this.debugLogFinish) {
          // Capture the counter now; it is incremented before the callback runs.
          const currCounter = this.shaderSubmitCounter;
          this.device.queue.onSubmittedWorkDone().then(() => {
            console.log("[" + currCounter + "][Debug] finish shader" + finfo.name);
          });
        }
        this.shaderSubmitCounter += 1;
      };
      return submitShader;
    };

    const shaderModule = this.device.createShaderModule({
      code: code,
      compilationHints: [
        {
          // The hint must name the same entry point the pipeline is created
          // with below; otherwise it does not apply to this kernel.
          entryPoint: finfo.name,
          layout: pipelineLayout
        }
      ]
    });

    if (asyncMode) {
      return this.device.createComputePipelineAsync({
        layout: pipelineLayout,
        compute: {
          module: shaderModule,
          entryPoint: finfo.name
        }
      }).then((pipeline: GPUComputePipeline) => {
        return createShaderFunc(pipeline);
      });
    } else {
      const pipeline = this.device.createComputePipeline({
        layout: pipelineLayout,
        compute: {
          module: shaderModule,
          entryPoint: finfo.name
        }
      });
      return createShaderFunc(pipeline);
    }
  }