ShaderVariable Processor::EvaluateConstant()

in renderdoc/driver/shaders/spirv/spirv_processor.cpp [1096:1695]


// Evaluates the constant identified by constID and returns its value.
//
// constID  - ID to look up in `constants`. If it isn't present an error is logged and a
//            placeholder variable named "unknown" is returned.
// specInfo - specialisation values supplied by the caller. A constant decorated with a SpecId
//            that matches an entry here takes that entry's value instead of its default.
//
// The lookup proceeds in this order:
//  1. If constID has an entry in `specOps` (an OpSpecConstantOp) its operands are evaluated
//     recursively and the operation is applied. Select, CompositeExtract, CompositeInsert,
//     VectorShuffle and the *Convert ops are handled individually, everything else is treated as
//     a component-wise operation on scalars/vectors.
//  2. If the constant is decorated with a SpecId matching an entry in specInfo, the stored value
//     is returned with its first 64 bits replaced by the specialised value.
//  3. If the constant is an OpSpecConstantComposite, its children are evaluated recursively and
//     assembled with ConstructCompositeConstant.
//  4. Otherwise the stored value is returned unchanged.
//
// On malformed input (wrong parameter counts, unhandled operations) an error is logged and a
// mostly-empty variable carrying only the result type is returned.
ShaderVariable Processor::EvaluateConstant(Id constID, const rdcarray<SpecConstant> &specInfo) const
{
  auto it = constants.find(constID);

  if(it == constants.end())
  {
    RDCERR("Lookup of unknown constant %u", constID.value());
    return ShaderVariable("unknown", 0, 0, 0, 0);
  }

  auto specopit = specOps.find(constID);

  if(specopit != specOps.end())
  {
    const SpecOp &specop = specopit->second;
    ShaderVariable ret = {};

    const DataType &retType = dataTypes[specop.type];

    ret.type = retType.scalar().Type();

    if(specop.params.empty())
    {
      RDCERR("Expected paramaters for SpecConstantOp %s", ToStr(specop.op).c_str());
      return ret;
    }

    // these instructions have special rules, so handle them manually
    if(specop.op == Op::Select)
    {
      // evaluate the parameters
      rdcarray<ShaderVariable> params;
      for(size_t i = 0; i < specop.params.size(); i++)
        params.push_back(EvaluateConstant(specop.params[i], specInfo));

      if(params.size() != 3)
      {
        RDCERR("Expected 3 paramaters for SpecConstantOp Select, got %zu", params.size());
        return ret;
      }

      // "If Condition is a scalar and true, the result is Object 1. If Condition is a scalar and
      // false, the result is Object 2."
      if(params[0].columns == 1)
        return params[0].value.u32v[0] ? params[1] : params[2];

      // "If Condition is a vector, Result Type must be a vector with the same number of components
      // as Condition and the result is a [component-wise] mix of Object 1 and Object 2."
      ret = params[1];
      ret.name = "derived";
      for(size_t i = 0; i < params[0].columns; i++)
      {
        // the condition is always read as a 32-bit lane, regardless of the width of the objects
        // being selected between
        const bool pickFirst = params[0].value.u32v[i] != 0;

        if(retType.scalar().width == 64)
          ret.value.u64v[i] = pickFirst ? params[1].value.u64v[i] : params[2].value.u64v[i];
        else if(retType.scalar().width == 32)
          ret.value.u32v[i] = pickFirst ? params[1].value.u32v[i] : params[2].value.u32v[i];
        else if(retType.scalar().width == 16)
          ret.value.u16v[i] = pickFirst ? params[1].value.u16v[i] : params[2].value.u16v[i];
        else
          ret.value.u8v[i] = pickFirst ? params[1].value.u8v[i] : params[2].value.u8v[i];
      }

      // return here - falling through would re-evaluate the parameters and report Select as an
      // unhandled component-wise operation
      return ret;
    }
    else if(specop.op == Op::CompositeExtract)
    {
      ShaderVariable composite = EvaluateConstant(specop.params[0], specInfo);
      // the remaining parameters are actually indices
      rdcarray<uint32_t> indices;
      for(size_t i = 1; i < specop.params.size(); i++)
        indices.push_back(specop.params[i].value());

      // start from the composite so that type information carries over, then clear the value and
      // fill in only the extracted components
      ret = composite;
      ret.name = "derived";

      RDCEraseEl(ret.value);

      if(composite.rows > 1)
      {
        ret.rows = 1;

        if(indices.size() == 1)
        {
          // matrix returning a vector
          uint32_t row = indices[0];

          for(uint32_t c = 0; c < ret.columns; c++)
          {
            if(retType.scalar().width == 64)
              ret.value.u64v[c] = composite.value.u64v[row * composite.columns + c];
            else if(retType.scalar().width == 32)
              ret.value.u32v[c] = composite.value.u32v[row * composite.columns + c];
            else if(retType.scalar().width == 16)
              ret.value.u16v[c] = composite.value.u16v[row * composite.columns + c];
            else
              ret.value.u8v[c] = composite.value.u8v[row * composite.columns + c];
          }
        }
        else if(indices.size() == 2)
        {
          // matrix returning a scalar
          uint32_t row = indices[0];
          uint32_t col = indices[1];

          // the result is a single component, so don't inherit the matrix's column count
          ret.columns = 1;

          if(retType.scalar().width == 64)
            ret.value.u64v[0] = composite.value.u64v[row * composite.columns + col];
          else if(retType.scalar().width == 32)
            ret.value.u32v[0] = composite.value.u32v[row * composite.columns + col];
          else if(retType.scalar().width == 16)
            ret.value.u16v[0] = composite.value.u16v[row * composite.columns + col];
          else
            ret.value.u8v[0] = composite.value.u8v[row * composite.columns + col];
        }
        else
        {
          RDCERR("Unexpected number of indices %zu to SpecConstantOp CompositeExtract",
                 indices.size());
        }
      }
      else
      {
        ret.columns = 1;

        if(indices.size() == 1)
        {
          uint32_t col = indices[0];

          // vector returning a scalar
          if(retType.scalar().width == 64)
            ret.value.u64v[0] = composite.value.u64v[col];
          else if(retType.scalar().width == 32)
            ret.value.u32v[0] = composite.value.u32v[col];
          else if(retType.scalar().width == 16)
            ret.value.u16v[0] = composite.value.u16v[col];
          else
            ret.value.u8v[0] = composite.value.u8v[col];
        }
        else
        {
          RDCERR("Unexpected number of indices %zu to SpecConstantOp CompositeExtract",
                 indices.size());
        }
      }

      return ret;
    }
    else if(specop.op == Op::CompositeInsert)
    {
      if(specop.params.size() < 3)
      {
        RDCERR("Expected at least 3 paramaters for SpecConstantOp CompositeInsert, got %zu",
               specop.params.size());
        return ret;
      }

      ShaderVariable object = EvaluateConstant(specop.params[0], specInfo);
      ShaderVariable composite = EvaluateConstant(specop.params[1], specInfo);
      // the remaining parameters are actually indices
      rdcarray<uint32_t> indices;
      for(size_t i = 2; i < specop.params.size(); i++)
        indices.push_back(specop.params[i].value());

      // the result is the evaluated composite with the selected component(s) overwritten
      composite.name = "derived";

      if(composite.rows > 1)
      {
        if(indices.size() == 1)
        {
          // matrix inserting a vector
          uint32_t row = indices[0];

          // iterate over the matrix's own width. 'ret' is still only value-initialised at this
          // point, so its column count can't be used as the bound
          for(uint32_t c = 0; c < composite.columns; c++)
          {
            if(retType.scalar().width == 64)
              composite.value.u64v[row * composite.columns + c] = object.value.u64v[c];
            else if(retType.scalar().width == 32)
              composite.value.u32v[row * composite.columns + c] = object.value.u32v[c];
            else if(retType.scalar().width == 16)
              composite.value.u16v[row * composite.columns + c] = object.value.u16v[c];
            else if(retType.scalar().width == 8)
              composite.value.u8v[row * composite.columns + c] = object.value.u8v[c];
          }
        }
        else if(indices.size() == 2)
        {
          // matrix inserting a scalar
          uint32_t row = indices[0];
          uint32_t col = indices[1];

          if(retType.scalar().width == 64)
            composite.value.u64v[row * composite.columns + col] = object.value.u64v[0];
          else if(retType.scalar().width == 32)
            composite.value.u32v[row * composite.columns + col] = object.value.u32v[0];
          else if(retType.scalar().width == 16)
            composite.value.u16v[row * composite.columns + col] = object.value.u16v[0];
          else
            composite.value.u8v[row * composite.columns + col] = object.value.u8v[0];
        }
        else
        {
          RDCERR("Unexpected number of indices %zu to SpecConstantOp CompositeInsert",
                 indices.size());
        }
      }
      else
      {
        if(indices.size() == 1)
        {
          // vector inserting a scalar
          if(retType.scalar().width == 64)
            composite.value.u64v[indices[0]] = object.value.u64v[0];
          else if(retType.scalar().width == 32)
            composite.value.u32v[indices[0]] = object.value.u32v[0];
          else if(retType.scalar().width == 16)
            composite.value.u16v[indices[0]] = object.value.u16v[0];
          else
            composite.value.u8v[indices[0]] = object.value.u8v[0];
        }
        else
        {
          RDCERR("Unexpected number of indices %zu to SpecConstantOp CompositeInsert",
                 indices.size());
        }
      }

      return composite;
    }
    else if(specop.op == Op::VectorShuffle)
    {
      if(specop.params.size() < 3)
      {
        RDCERR("Expected at least 3 paramaters for SpecConstantOp VectorShuffle, got %zu",
               specop.params.size());
        return ret;
      }

      ShaderVariable vec1 = EvaluateConstant(specop.params[0], specInfo);
      ShaderVariable vec2 = EvaluateConstant(specop.params[1], specInfo);
      // the remaining parameters are actually indices
      rdcarray<uint32_t> indices;
      for(size_t i = 2; i < specop.params.size(); i++)
        indices.push_back(specop.params[i].value());

      ret = ShaderVariable("derived", 0, 0, 0, 0);
      ret.type = retType.scalar().Type();
      ret.columns = (uint8_t)indices.size();

      for(size_t i = 0; i < indices.size(); i++)
      {
        // indices select from the concatenation of vec1 followed by vec2
        uint32_t idx = indices[i];
        if(idx < vec1.columns)
        {
          if(retType.scalar().width == 64)
            ret.value.u64v[i] = vec1.value.u64v[idx];
          else if(retType.scalar().width == 32)
            ret.value.u32v[i] = vec1.value.u32v[idx];
          else if(retType.scalar().width == 16)
            ret.value.u16v[i] = vec1.value.u16v[idx];
          else
            ret.value.u8v[i] = vec1.value.u8v[idx];
        }
        else if(idx - vec1.columns < vec2.columns)
        {
          idx -= vec1.columns;

          if(retType.scalar().width == 64)
            ret.value.u64v[i] = vec2.value.u64v[idx];
          else if(retType.scalar().width == 32)
            ret.value.u32v[i] = vec2.value.u32v[idx];
          else if(retType.scalar().width == 16)
            ret.value.u16v[i] = vec2.value.u16v[idx];
          else
            ret.value.u8v[i] = vec2.value.u8v[idx];
        }
        // otherwise the index is outside both vectors. SPIR-V uses 0xFFFFFFFF to mean "no source,
        // result component is undefined", so leave the component untouched rather than reading
        // out of bounds.
      }

      return ret;
    }
    else if(specop.op == Op::UConvert || specop.op == Op::SConvert || specop.op == Op::FConvert)
    {
      ShaderVariable param = EvaluateConstant(specop.params[0], specInfo);
      ret = param;

      ret.name = "converted";
      RDCEraseEl(ret.value);
      ret.type = retType.scalar().Type();

      // each component is read at the source type's width into the widest matching type, then
      // written back at the result type's width
      for(uint8_t i = 0; i < param.columns; i++)
      {
        if(specop.op == Op::UConvert)
        {
          uint64_t x = 0;

#undef _IMPL
#define _IMPL(I, S, U) x = comp<U>(param, i);
          IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, param.type);

#undef _IMPL
#define _IMPL(I, S, U) comp<U>(ret, i) = (U)x;
          IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, ret.type);
        }
        else if(specop.op == Op::SConvert)
        {
          int64_t x = 0;

#undef _IMPL
#define _IMPL(I, S, U) x = comp<S>(param, i);
          IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, param.type);

#undef _IMPL
#define _IMPL(I, S, U) comp<S>(ret, i) = (S)x;
          IMPL_FOR_INT_TYPES_FOR_TYPE(_IMPL, ret.type);
        }
        else if(specop.op == Op::FConvert)
        {
          double x = 0.0;

#undef _IMPL
#define _IMPL(T) x = comp<T>(param, i);
          IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, param.type);

#undef _IMPL
#define _IMPL(T) comp<T>(ret, i) = (T)x;
          // IMPL_FOR_FLOAT_TYPES_FOR_TYPE(_IMPL, ret.type);

          // the write-back is spelled out by hand instead of using the macro above; note that the
          // half case goes via float
          if(ret.type == VarType::Float)
            comp<float>(ret, i) = (float)x;
          else if(ret.type == VarType::Half)
            comp<half_float::half>(ret, i) = (float)x;
          else if(ret.type == VarType::Double)
            comp<double>(ret, i) = (double)x;
        }
      }

      return ret;
    }

    // evaluate the parameters
    rdcarray<ShaderVariable> params;
    for(size_t i = 0; i < specop.params.size(); i++)
      params.push_back(EvaluateConstant(specop.params[i], specInfo));

    // all other operations are component-wise on vectors or scalars. Check that rows are all 1 and
    // all cols are identical
    uint8_t cols = params[0].columns;
    for(size_t i = 0; i < params.size(); i++)
    {
      RDCASSERT(cols == params[i].columns, i, params[i].columns);
      RDCASSERT(params[i].rows == 1, i, params[i].rows);
      cols = RDCMIN(cols, params[i].columns);
    }

    // check number of parameters
    switch(specop.op)
    {
      case Op::SNegate:
      case Op::Not:
      case Op::LogicalNot:
        if(params.size() != 1)
        {
          RDCERR("Expected 1 parameter for SpecConstantOp %s, got %zu", ToStr(specop.op).c_str(),
                 params.size());
          return ret;
        }
        break;
      case Op::IAdd:
      case Op::ISub:
      case Op::IMul:
      case Op::UDiv:
      case Op::SDiv:
      case Op::UMod:
      case Op::SRem:
      case Op::SMod:
      case Op::ShiftRightLogical:
      case Op::ShiftRightArithmetic:
      case Op::ShiftLeftLogical:
      case Op::BitwiseOr:
      case Op::BitwiseXor:
      case Op::BitwiseAnd:
      case Op::LogicalOr:
      case Op::LogicalAnd:
      case Op::LogicalEqual:
      case Op::LogicalNotEqual:
      case Op::IEqual:
      case Op::INotEqual:
      case Op::ULessThan:
      case Op::SLessThan:
      case Op::UGreaterThan:
      case Op::SGreaterThan:
      case Op::ULessThanEqual:
      case Op::SLessThanEqual:
      case Op::UGreaterThanEqual:
      case Op::SGreaterThanEqual:
        if(params.size() != 2)
        {
          RDCERR("Expected 2 paramaters for SpecConstantOp %s, got %zu", ToStr(specop.op).c_str(),
                 params.size());
          return ret;
        }
        break;
      default:
        RDCERR("Unhandled SpecConstantOp:: operation %s", ToStr(specop.op).c_str());
        return ret;
    }

    ret = params[0];
    ret.name = "derived";
    // the values below are written back in the result type's layout, which for comparisons is not
    // the operand type, so make the reported type match
    ret.type = retType.scalar().Type();

    for(uint32_t col = 0; col < cols; col++)
    {
      // a and b hold the current component of each operand in element 0. The result of the
      // operation is left in a's element 0.
      ShaderValue a, b;

      bool signedness = retType.scalar().signedness;

      // upcast parameters to 64-bit width to simplify applying operations
      for(size_t p = 0; p < params.size() && p < 2; p++)
      {
        const DataType &paramType = dataTypes[idTypes[specop.params[p]]];

        ShaderValue &val = (p == 0) ? a : b;

        if(paramType.scalar().type == Op::TypeFloat)
        {
          if(paramType.scalar().width == 64)
            val.f64v[0] = params[p].value.f64v[col];
          else if(paramType.scalar().width == 32)
            val.f64v[0] = params[p].value.f32v[col];
          else
            val.f64v[0] = (float)params[p].value.f16v[col];
        }
        else
        {
          if(paramType.scalar().signedness)
          {
            if(paramType.scalar().width == 64)
              val.s64v[0] = params[p].value.s64v[col];
            else if(paramType.scalar().width == 32)
              val.s64v[0] = params[p].value.s32v[col];
            else if(paramType.scalar().width == 16)
              val.s64v[0] = params[p].value.s16v[col];
            else
              val.s64v[0] = params[p].value.s8v[col];
          }
          else
          {
            if(paramType.scalar().width == 64)
              val.u64v[0] = params[p].value.u64v[col];
            else if(paramType.scalar().width == 32)
              val.u64v[0] = params[p].value.u32v[col];
            else if(paramType.scalar().width == 16)
              val.u64v[0] = params[p].value.u16v[col];
            else
              val.u64v[0] = params[p].value.u8v[col];
          }
        }
      }

      switch(specop.op)
      {
        case Op::SNegate: a.s64v[0] = -a.s64v[0]; break;
        case Op::Not: a.u64v[0] = ~a.u64v[0]; break;
        case Op::LogicalNot: a.u64v[0] = a.u64v[0] ? 0 : 1; break;
        case Op::IAdd:
          if(signedness)
            a.s64v[0] += b.s64v[0];
          else
            a.u64v[0] += b.u64v[0];
          break;
        case Op::ISub:
          if(signedness)
            a.s64v[0] -= b.s64v[0];
          else
            a.u64v[0] -= b.u64v[0];
          break;
        case Op::IMul:
          if(signedness)
            a.s64v[0] *= b.s64v[0];
          else
            a.u64v[0] *= b.u64v[0];
          break;
        // division or remainder by zero has an undefined result in SPIR-V, but performing it here
        // would be undefined behaviour in C++ (typically a crash), so produce 0 instead
        case Op::UDiv: a.u64v[0] = b.u64v[0] ? a.u64v[0] / b.u64v[0] : 0; break;
        case Op::SDiv: a.s64v[0] = b.s64v[0] ? a.s64v[0] / b.s64v[0] : 0; break;
        case Op::UMod: a.u64v[0] = b.u64v[0] ? a.u64v[0] % b.u64v[0] : 0; break;
        case Op::SRem:
        case Op::SMod:
        {
          const int64_t lhs = a.s64v[0];
          const int64_t rhs = b.s64v[0];

          int64_t result = 0;

          // a zero divisor is handled as above. A divisor of -1 always leaves remainder 0, and is
          // special-cased because INT64_MIN % -1 overflows.
          if(rhs != 0 && rhs != -1)
          {
            // "the sign of r is the same as the sign of Operand 1."
            // C++'s % already gives the remainder the sign of the dividend, which is SRem.
            result = lhs % rhs;

            // "the sign of r is the same as the sign of Operand 2."
            // a non-zero remainder with the wrong sign is moved into the divisor's range by
            // adding the divisor. Negating it would give the wrong magnitude.
            if(specop.op == Op::SMod && result != 0 && ((result < 0) != (rhs < 0)))
              result += rhs;
          }

          a.s64v[0] = result;

          break;
        }
        case Op::ShiftRightLogical: a.u64v[0] >>= b.u64v[0]; break;
        case Op::ShiftRightArithmetic: a.s64v[0] >>= b.s64v[0]; break;
        case Op::ShiftLeftLogical: a.u64v[0] <<= b.u64v[0]; break;
        case Op::BitwiseOr: a.u64v[0] |= b.u64v[0]; break;
        case Op::BitwiseXor: a.u64v[0] ^= b.u64v[0]; break;
        case Op::BitwiseAnd: a.u64v[0] &= b.u64v[0]; break;
        case Op::LogicalOr: a.u64v[0] = (a.u64v[0] || b.u64v[0]) ? 1 : 0; break;
        case Op::LogicalAnd: a.u64v[0] = (a.u64v[0] && b.u64v[0]) ? 1 : 0; break;
        case Op::LogicalEqual: a.u64v[0] = (a.u64v[0] == b.u64v[0]) ? 1 : 0; break;
        case Op::LogicalNotEqual: a.u64v[0] = (a.u64v[0] != b.u64v[0]) ? 1 : 0; break;
        case Op::IEqual: a.u64v[0] = (a.u64v[0] == b.u64v[0]) ? 1 : 0; break;
        case Op::INotEqual: a.u64v[0] = (a.u64v[0] != b.u64v[0]) ? 1 : 0; break;
        case Op::ULessThan: a.u64v[0] = (a.u64v[0] < b.u64v[0]) ? 1 : 0; break;
        case Op::SLessThan: a.s64v[0] = (a.s64v[0] < b.s64v[0]) ? 1 : 0; break;
        case Op::UGreaterThan: a.u64v[0] = (a.u64v[0] > b.u64v[0]) ? 1 : 0; break;
        case Op::SGreaterThan: a.s64v[0] = (a.s64v[0] > b.s64v[0]) ? 1 : 0; break;
        case Op::ULessThanEqual: a.u64v[0] = (a.u64v[0] <= b.u64v[0]) ? 1 : 0; break;
        case Op::SLessThanEqual: a.s64v[0] = (a.s64v[0] <= b.s64v[0]) ? 1 : 0; break;
        case Op::UGreaterThanEqual: a.u64v[0] = (a.u64v[0] >= b.u64v[0]) ? 1 : 0; break;
        case Op::SGreaterThanEqual: a.s64v[0] = (a.s64v[0] >= b.s64v[0]) ? 1 : 0; break;
        default: break;
      }

      // downcast back to the type required. The result always lives in element 0 of 'a', only the
      // destination is indexed by the current component.
      if(retType.scalar().type == Op::TypeFloat)
      {
        if(retType.scalar().width == 64)
          ret.value.f64v[col] = a.f64v[0];
        else if(retType.scalar().width == 32)
          ret.value.f32v[col] = (float)a.f64v[0];
        else
          ret.value.f16v[col].set((float)a.f64v[0]);
      }
      else if(signedness)
      {
        if(retType.scalar().width == 64)
          ret.value.s64v[col] = a.s64v[0];
        else if(retType.scalar().width == 32)
          ret.value.s32v[col] =
              (int32_t)RDCCLAMP(a.s64v[0], (int64_t)INT32_MIN, (int64_t)INT32_MAX);
        else if(retType.scalar().width == 16)
          ret.value.s16v[col] =
              (int16_t)RDCCLAMP(a.s64v[0], (int64_t)INT16_MIN, (int64_t)INT16_MAX);
        else
          ret.value.s8v[col] = (int8_t)RDCCLAMP(a.s64v[0], (int64_t)INT8_MIN, (int64_t)INT8_MAX);
      }
      else
      {
        if(retType.scalar().width == 64)
          ret.value.u64v[col] = a.u64v[0];
        else if(retType.scalar().width == 32)
          ret.value.u32v[col] = a.u64v[0] & 0xFFFFFFFF;
        else if(retType.scalar().width == 16)
          ret.value.u16v[col] = a.u64v[0] & 0xFFFF;
        else
          ret.value.u8v[col] = a.u64v[0] & 0xFF;
      }
    }

    return ret;
  }

  const Constant &c = it->second;

  if(decorations[c.id].flags & Decorations::HasSpecId)
  {
    for(const SpecConstant &spec : specInfo)
    {
      // if this constant is specialised, read its data instead
      if(spec.specID == decorations[c.id].specID)
      {
        ShaderVariable ret = c.value;

        // we can always just read into u64v - if the type is smaller the LSB maps nicely.
        ret.value.u64v[0] = spec.value;

        return ret;
      }
    }
  }

  if(c.op == Op::SpecConstantComposite)
  {
    ShaderVariable ret = c.value;

    rdcarray<ShaderVariable> children;

    // this is wasteful because we've probably already evaluated these constants, but we don't
    // expect a huge tree of spec constants so it's cleaner to do it here than expect the caller to
    // tidy up from its evaluated cache.
    for(size_t i = 0; i < c.children.size(); i++)
      children.push_back(EvaluateConstant(c.children[i], specInfo));

    ConstructCompositeConstant(ret, children);

    return ret;
  }

  return c.value;
}