// generateShader() — excerpted from src/webgpu/shader/validation/shader_io/util.ts (lines 12–86).


/**
 * Builds the WGSL source text of a minimal entry point that exercises one
 * shader IO value, for use in shader IO validation tests.
 *
 * Options:
 * - `attribute`: WGSL attribute text placed on the IO value (e.g. `@builtin(position)`).
 *   If it mentions `subgroup` / `clip_distances`, an `enable subgroups;` /
 *   `enable clip_distances;` directive is emitted first.
 * - `type`: WGSL type of the IO value.
 * - `stage`: emitted as `@${stage}`; `'compute'` additionally gets
 *   `@workgroup_size(1)`. An empty string emits no stage attribute at all.
 * - `io`: `'in'` makes the value an entry-point parameter, `'out'` makes it the
 *   return value; any other string emits neither.
 * - `use_struct`: when true, the value is wrapped in a struct named `S`.
 *
 * @returns The generated WGSL source.
 */
export function generateShader({
  attribute,
  type,
  stage,
  io,
  use_struct,
}: {
  attribute: string;
  type: string;
  stage: string;
  io: string;
  use_struct: boolean;
}) {
  const pieces: string[] = [];

  // Emit the `enable` directives the attribute text depends on.
  if (attribute.includes('subgroup')) pieces.push('enable subgroups;\n');
  if (attribute.includes('clip_distances')) pieces.push('enable clip_distances;\n');

  const isVertex = stage === 'vertex';

  if (use_struct) {
    // Wrap the IO value in struct `S`. Vertex outputs also get a position
    // builtin member, unless the wrapped value already is the position.
    const addPosition = isVertex && io === 'out' && !attribute.includes('builtin(position)');
    pieces.push(
      'struct S {\n',
      `  ${attribute} value : ${type},\n`,
      addPosition ? `  @builtin(position) position : vec4<f32>,\n` : '',
      '};\n\n',
    );
  }

  // Stage attribute(s) for the entry point; skipped entirely for an empty stage.
  if (stage !== '') {
    pieces.push(`@${stage}`, stage === 'compute' ? ' @workgroup_size(1)' : '');
  }

  // Work out the entry point's parameter list, return type, and return statement.
  let params = '';
  let returnType = '';
  let returnStmt = '';
  switch (io) {
    case 'in':
      params = use_struct ? `in : S` : `${attribute} value : ${type}`;
      // A vertex entry point must still return `@builtin(position)`.
      if (isVertex) {
        returnType = `-> @builtin(position) vec4<f32>`;
        returnStmt = `return vec4<f32>();`;
      }
      break;
    case 'out':
      returnType = use_struct ? '-> S' : `-> ${attribute} ${type}`;
      returnStmt = use_struct ? `return S();` : `return ${type}();`;
      break;
  }

  // Function text, with the exact whitespace layout callers currently receive.
  pieces.push(`\n    fn main(${params}) ${returnType} {\n      ${returnStmt}\n    }\n  `);

  return pieces.join('');
}