in codegen/cuda/CUDATarget.cpp [54:679]
std::string CUDATarget::codegen(std::vector<std::string>& inputs, const Command* cmd, std::string& inpName) {
const Op* op = cmd->op;
std::stringstream ss;
switch (op->type()) {
case MNN::OpType_BinaryOp:
{
auto lhs = inputs[0], rhs = inputs[1];
auto type = static_cast<MNN::BinaryOpOperation>(op->main_as_BinaryOp()->opType());
switch (type) {
case BinaryOpOperation_ADD:
if(mVectorize) {
ss << inpName << ".x=(" << lhs << ".x+" << rhs << ".x);\n";
ss << inpName << ".y=(" << lhs << ".y+" << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << lhs << ".z+" << rhs << ".z);\n";
ss << inpName << ".w=(" << lhs << ".w+" << rhs << ".w)";
}
} else {
ss << inpName << "=(" << lhs << "+" << rhs << ")";
}
break;
case BinaryOpOperation_SUB:
if(mVectorize) {
ss << inpName << ".x=(" << lhs << ".x-" << rhs << ".x);\n";
ss << inpName << ".y=(" << lhs << ".y-" << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << lhs << ".z-" << rhs << ".z);\n";
ss << inpName << ".w=(" << lhs << ".w-" << rhs << ".w)";
}
} else {
ss << inpName << "=(" << lhs << "-" << rhs << ")";
}
break;
case BinaryOpOperation_MUL:
if(mVectorize) {
ss << inpName << ".x=(" << lhs << ".x*" << rhs << ".x);\n";
ss << inpName << ".y=(" << lhs << ".y*" << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << lhs << ".z*" << rhs << ".z);\n";
ss << inpName << ".w=(" << lhs << ".w*" << rhs << ".w)";
}
} else {
ss << inpName << "=(" << lhs << "*" << rhs << ")";
}
break;
case BinaryOpOperation_POW:
if(mVectorize) {
ss << inpName << ".x=pow((float)" << lhs << ".x,(float)" << rhs << ".x);\n";
ss << inpName << ".y=pow((float)" << lhs << ".y,(float)" << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=pow((float)" << lhs << ".z,(float)" << rhs << ".z);\n";
ss << inpName << ".w=pow((float)" << lhs << ".w,(float)" << rhs << ".w)";
}
} else {
ss << inpName << "=(pow((float)" << lhs << ",(float)" << rhs << "))";
}
break;
case BinaryOpOperation_DIV:
if(mVectorize) {
ss << inpName << ".x=(" << lhs << ".x/" << rhs << ".x);\n";
ss << inpName << ".y=(" << lhs << ".y/" << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << lhs << ".z/" << rhs << ".z);\n";
ss << inpName << ".w=(" << lhs << ".w/" << rhs << ".w)";
}
} else {
ss << inpName << "=(" << lhs << "/" << rhs << ")";
}
break;
case BinaryOpOperation_MAXIMUM:
if(mVectorize) {
ss << inpName << ".x=fmax(" << lhs << ".x," << rhs << ".x);\n";
ss << inpName << ".y=fmax(" << lhs << ".y," << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=fmax(" << lhs << ".z," << rhs << ".z);\n";
ss << inpName << ".w=fmax(" << lhs << ".w," << rhs << ".w)";
}
} else {
ss << inpName << "=(fmax(" << lhs << "," << rhs << "))";
}
break;
case BinaryOpOperation_MINIMUM:
if(mVectorize) {
ss << inpName << ".x=fmin(" << lhs << ".x," << rhs << ".x);\n";
ss << inpName << ".y=fmin(" << lhs << ".y," << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=fmin(" << lhs << ".z," << rhs << ".z);\n";
ss << inpName << ".w=fmin(" << lhs << ".w," << rhs << ".w)";
}
} else {
ss << inpName << "=(fmin(" << lhs << "," << rhs << "))";
}
break;
case BinaryOpOperation_REALDIV:
if(mVectorize) {
ss << inpName << ".x=(" << lhs << ".x/" << rhs << ".x);\n";
ss << inpName << ".y=(" << lhs << ".y/" << rhs << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << lhs << ".z/" << rhs << ".z);\n";
ss << inpName << ".w=(" << lhs << ".w/" << rhs << ".w)";
}
} else {
ss << inpName << "=(" << lhs << "/" << rhs << ")";//ss << "((" << rhs << ") > 0.0 ? 1.0 : ((" << rhs << ") < 0.0 ? -1.0 : 0.0) * " << lhs << "/ max(abs(" << rhs << "), 0.0000001))";
}
break;
default:
MNN_PRINT("Error: CUDA CodeGen not support Binary type:%d\n", type);
break;
}
break;
}
case MNN::OpType_Eltwise:
{
auto type = op->main_as_Eltwise()->type();
switch (type) {
case EltwiseType_SUM:
case EltwiseType_SUB:
case EltwiseType_PROD:
{
std::unordered_map<int, std::string> elemToOp = {
{EltwiseType_PROD, "*"}, {EltwiseType_SUM, "+"}, {EltwiseType_SUB, "-"}
};
if(mVectorize) {
ss << inpName << ".x=(" << inputs[0] << ".x" << elemToOp[type] << inputs[1] << ".x";
for (int i = 2; i < inputs.size(); i++) {
ss << elemToOp[type] << inputs[i] << ".x";
}
ss << ");\n";
ss << inpName << ".y=(" << inputs[0] << ".y" << elemToOp[type] << inputs[1] << ".y";
for (int i = 2; i < inputs.size(); i++) {
ss << elemToOp[type] << inputs[i] << ".y";
}
ss << ")";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << inputs[0] << ".z" << elemToOp[type] << inputs[1] << ".z";
for (int i = 2; i < inputs.size(); i++) {
ss << elemToOp[type] << inputs[i] << ".z";
}
ss << ");\n";
ss << inpName << ".w=(" << inputs[0] << ".w" << elemToOp[type] << inputs[1] << ".w";
for (int i = 2; i < inputs.size(); i++) {
ss << elemToOp[type] << inputs[i] << ".w";
}
ss << ")";
}
} else {
ss << inpName << "=(" << inputs[0] << elemToOp[type] << inputs[1];
for (int i = 2; i < inputs.size(); i++) {
ss << elemToOp[type] << inputs[i];
}
ss << ")";
}
break;
}
case EltwiseType_MAXIMUM:
{
if(mVectorize) {
MNN_PRINT("Error: CUDA CodeGen not support Eltwise Parallel type:%d, Please Fix it\n", type);
}
std::function<std::string(int)> fmax = [&inputs, &fmax](int d) {
if (d == inputs.size() - 1) {
return inputs[d];
}
return "fmax(" + inputs[d] + ", " + fmax(d+1) + ")";
};
ss << inpName << "=" << fmax(0);
break;
}
default:
MNN_PRINT("Error: CUDA CodeGen not support Eltwise type:%d\n", type);
break;
}
break;
}
case MNN::OpType_UnaryOp:
{
auto unary = op->main_as_UnaryOp();
auto type = unary->opType();
auto operand = inputs[0];
switch (type) {
case UnaryOpOperation_SQUARE:
if(mVectorize) {
ss << inpName << ".x=(" << operand << ".x * " << operand << ".x);\n";
ss << inpName << ".y=(" << operand << ".y * " << operand << ".y)";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(" << operand << ".z * " << operand << ".z);\n";
ss << inpName << ".w=(" << operand << ".w * " << operand << ".w)";
}
} else {
ss << inpName << "=(" << operand << " * " << operand << ")";
}
break;
case UnaryOpOperation_ERF:
if(mVectorize) {
ss << inpName << ".x=(erf(" << operand << ".x));\n";
ss << inpName << ".y=(erf(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(erf(" << operand << ".z));\n";
ss << inpName << ".w=(erf(" << operand << ".w))";
}
} else {
ss << inpName << "=(erf(" << operand << "))";
}
break;
case UnaryOpOperation_ERFC:
if(mVectorize) {
ss << inpName << ".x=(erfc(" << operand << ".x));\n";
ss << inpName << ".y=(erfc(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(erfc(" << operand << ".z));\n";
ss << inpName << ".w=(erfc(" << operand << ".w))";
}
} else {
ss << inpName << "=(erfc(" << operand << "))";
}
break;
case UnaryOpOperation_ERFINV:
if(mVectorize) {
ss << inpName << ".x=(erfinv(" << operand << ".x));\n";
ss << inpName << ".y=(erfinv(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(erfinv(" << operand << ".z));\n";
ss << inpName << ".w=(erfinv(" << operand << ".w))";
}
} else {
ss << inpName << "=(erfinv(" << operand << "))";
}
break;
case UnaryOpOperation_SQRT:
if(mVectorize) {
ss << inpName << ".x=(sqrt(" << operand << ".x));\n";
ss << inpName << ".y=(sqrt(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(sqrt(" << operand << ".z));\n";
ss << inpName << ".w=(sqrt(" << operand << ".w))";
}
} else {
ss << inpName << "=(sqrt(" << operand << "))";
}
break;
case UnaryOpOperation_RSQRT:
if(mVectorize) {
ss << inpName << ".x=(rsqrt(" << operand << ".x));\n";
ss << inpName << ".y=(rsqrt(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(rsqrt(" << operand << ".z));\n";
ss << inpName << ".w=(rsqrt(" << operand << ".w))";
}
} else {
ss << inpName << "=(rsqrt(" << operand << "))";
}
break;
case UnaryOpOperation_ABS:
if(mVectorize) {
ss << inpName << ".x=(fabs(" << operand << ".x));\n";
ss << inpName << ".y=(fabs(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(fabs(" << operand << ".z));\n";
ss << inpName << ".w=(fabs(" << operand << ".w))";
}
} else {
ss << inpName << "=(fabs(" << operand << "))";
}
break;
case UnaryOpOperation_SIN:
if(mVectorize) {
ss << inpName << ".x=(sin(" << operand << ".x));\n";
ss << inpName << ".y=(sin(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(sin(" << operand << ".z));\n";
ss << inpName << ".w=(sin(" << operand << ".w))";
}
} else {
ss << inpName << "=(sin(" << operand << "))";
}
break;
case UnaryOpOperation_COS:
if(mVectorize) {
ss << inpName << ".x=(cos(" << operand << ".x));\n";
ss << inpName << ".y=(cos(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(cos(" << operand << ".z));\n";
ss << inpName << ".w=(cos(" << operand << ".w))";
}
} else {
ss << inpName << "=(cos(" << operand << "))";
}
break;
case UnaryOpOperation_ASIN:
if(mVectorize) {
ss << inpName << ".x=(asin(" << operand << ".x));\n";
ss << inpName << ".y=(asin(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(asin(" << operand << ".z));\n";
ss << inpName << ".w=(asin(" << operand << ".w))";
}
} else {
ss << inpName << "=(asin(" << operand << "))";
}
break;
case UnaryOpOperation_ACOS:
if(mVectorize) {
ss << inpName << ".x=(acos(" << operand << ".x));\n";
ss << inpName << ".y=(acos(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(acos(" << operand << ".z));\n";
ss << inpName << ".w=(acos(" << operand << ".w))";
}
} else {
ss << inpName << "=(acos(" << operand << "))";
}
break;
case UnaryOpOperation_SIGN:
if(mVectorize) {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << ".x=((" << operand << ".x > (half)0.0) ? (half)1.0 : ((" << operand << ".x < (half)0.0) ? (half)(-1.0) : (half)0.0));\n";
ss << inpName << ".y=((" << operand << ".y > (half)0.0) ? (half)1.0 : ((" << operand << ".y < (half)0.0) ? (half)(-1.0) : (half)0.0))";
} else {
ss << inpName << ".x=((" << operand << ".x > 0.0) ? 1.0 : ((" << operand << ".x < 0.0) ? (-1.0) : 0.0));\n";
ss << inpName << ".y=((" << operand << ".y > 0.0) ? 1.0 : ((" << operand << ".y < 0.0) ? (-1.0) : 0.0))";
ss << ";\n";
ss << inpName << ".z=((" << operand << ".z > 0.0) ? 1.0 : ((" << operand << ".z < 0.0) ? (-1.0) : 0.0));\n";
ss << inpName << ".w=((" << operand << ".w > 0.0) ? 1.0 : ((" << operand << ".w < 0.0) ? (-1.0) : 0.0))";
}
} else {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << "=(" << operand << "> (half)0.0 ? (half)1.0 : (" << operand << "<(half)0.0 ? (half)(-1.0) : (half)0.0))";
} else {
ss << inpName << "=(" << operand << "> 0.0 ? 1.0 : (" << operand << "<0.0 ? (-1.0) : 0.0))";
}
}
break;
case UnaryOpOperation_EXP:
if(mVectorize) {
ss << inpName << ".x=(exp(" << operand << ".x));\n";
ss << inpName << ".y=(exp(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(exp(" << operand << ".z));\n";
ss << inpName << ".w=(exp(" << operand << ".w))";
}
} else {
ss << inpName << "=(exp(" << operand << "))";
}
break;
case UnaryOpOperation_NEG:
if(mVectorize) {
ss << inpName << ".x=(-(" << operand << ".x));\n";
ss << inpName << ".y=(-(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(-(" << operand << ".z));\n";
ss << inpName << ".w=(-(" << operand << ".w))";
}
} else {
ss << inpName << "=(-(" << operand << "))";
}
break;
case UnaryOpOperation_TAN:
if(mVectorize) {
ss << inpName << ".x=(tan(" << operand << ".x));\n";
ss << inpName << ".y=(tan(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(tan(" << operand << ".z));\n";
ss << inpName << ".w=(tan(" << operand << ".w))";
}
} else {
ss << inpName << "=(tan(" << operand << "))";
}
break;
case UnaryOpOperation_ATAN:
if(mVectorize) {
ss << inpName << ".x=(atan(" << operand << ".x));\n";
ss << inpName << ".y=(atan(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(atan(" << operand << ".z));\n";
ss << inpName << ".w=(atan(" << operand << ".w))";
}
} else {
ss << inpName << "=(atan(" << operand << "))";
}
break;
case UnaryOpOperation_CEIL:
if(mVectorize) {
ss << inpName << ".x=(ceil(" << operand << ".x));\n";
ss << inpName << ".y=(ceil(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(ceil(" << operand << ".z));\n";
ss << inpName << ".w=(ceil(" << operand << ".w))";
}
} else {
ss << inpName << "=(ceil(" << operand << "))";
}
break;
case UnaryOpOperation_LOG1P:
if(mVectorize) {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << ".x=(half)(log(1.0+(float)" << operand << ".x));\n";
ss << inpName << ".y=(half)(log(1.0+(float)" << operand << ".y))";
} else {
ss << inpName << ".x=(log(1.0+" << operand << ".x));\n";
ss << inpName << ".y=(log(1.0+" << operand << ".y))";
ss << ";\n";
ss << inpName << ".z=(log(1.0+" << operand << ".z));\n";
ss << inpName << ".w=(log(1.0+" << operand << ".w))";
}
} else {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << "=(log((half)1.0+" << operand << "))";
} else {
ss << inpName << "=(log(1.0+" << operand << "))";
}
}
break;
case UnaryOpOperation_FLOOR:
if(mVectorize) {
ss << inpName << ".x=(floor(" << operand << ".x));\n";
ss << inpName << ".y=(floor(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(floor(" << operand << ".z));\n";
ss << inpName << ".w=(floor(" << operand << ".w))";
}
} else {
ss << inpName << "=(floor(" << operand << "))";
}
break;
case UnaryOpOperation_ROUND:
if(mVectorize) {
ss << inpName << ".x=(round(" << operand << ".x));\n";
ss << inpName << ".y=(round(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(round(" << operand << ".z));\n";
ss << inpName << ".w=(round(" << operand << ".w))";
}
} else {
ss << inpName << "=(round(" << operand << "))";
}
break;
case UnaryOpOperation_SIGMOID:
if(mVectorize) {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << ".x=(half)(1.0/(1.0+(float)exp(-" << operand << ".x)));\n";
ss << inpName << ".y=(half)(1.0/(1.0+(float)exp(-" << operand << ".y)))";
} else {
ss << inpName << ".x=(1.0/(1.0+exp(-" << operand << ".x)));\n";
ss << inpName << ".y=(1.0/(1.0+exp(-" << operand << ".y)))";
ss << ";\n";
ss << inpName << ".z=(1.0/(1.0+exp(-" << operand << ".z)));\n";
ss << inpName << ".w=(1.0/(1.0+exp(-" << operand << ".w)))";
}
} else {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << "=(half)(1.0/(1.0+(float)exp(-" << operand << ")))";
} else {
ss << inpName << "=(1.0/(1.0+exp(-" << operand << ")))";
}
}
break;
case UnaryOpOperation_TANH:
if(mVectorize) {
ss << inpName << ".x=(tanh(" << operand << ".x));\n";
ss << inpName << ".y=(tanh(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(tanh(" << operand << ".z));\n";
ss << inpName << ".w=(tanh(" << operand << ".w));";
}
} else {
ss << inpName << "=(tanh(" << operand << "))";
}
break;
case UnaryOpOperation_RECIPROCAL:
if(mVectorize) {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << ".x=(half)(1.0/(float)" << operand << ".x);\n";
ss << inpName << ".y=(half)(1.0/(float)" << operand << ".y)";
} else {
ss << inpName << ".x=(1.0/" << operand << ".x);\n";
ss << inpName << ".y=(1.0/" << operand << ".y)";
ss << ";\n";
ss << inpName << ".z=(1.0/" << operand << ".z);\n";
ss << inpName << ".w=(1.0/" << operand << ".w)";
}
} else {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << "=(half)(1.0/(float)" << operand << ")";
} else {
ss << inpName << "=(1.0/" << operand << ")";
}
}
break;
case UnaryOpOperation_LOG:
if(mVectorize) {
ss << inpName << ".x=(log(" << operand << ".x));\n";
ss << inpName << ".y=(log(" << operand << ".y))";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=(log(" << operand << ".z));\n";
ss << inpName << ".w=(log(" << operand << ".w))";
}
} else {
ss << inpName << "=(log(" << operand << "))";
}
break;
case UnaryOpOperation_GELU:
if(mVectorize) {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << ".x=(half)((1.0f + tanh(0.79788458f * (0.044715f * (float)" << operand << ".x*(float)" << operand << ".x*(float)" << operand << ".x+(float)" << operand + ".x))) * (float)" << operand << ".x* 0.5f);\n";
ss << inpName << ".y=(half)((1.0f + tanh(0.79788458f * (0.044715f * (float)" << operand << ".y*(float)" << operand << ".y*(float)" << operand << ".y+(float)" << operand + ".y))) * (float)" << operand << ".y* 0.5f)";
} else {
ss << inpName << ".x=((1.0f + tanh(0.79788458f * (0.044715f * " << operand << ".x*" << operand << ".x*" << operand << ".x+" << operand + ".x))) * " << operand << ".x* 0.5f);\n";
ss << inpName << ".y=((1.0f + tanh(0.79788458f * (0.044715f * " << operand << ".y*" << operand << ".y*" << operand << ".y+" << operand + ".y))) * " << operand << ".y* 0.5f)";
ss << ";\n";
ss << inpName << ".z=((1.0f + tanh(0.79788458f * (0.044715f * " << operand << ".z*" << operand << ".z*" << operand << ".z+" << operand + ".z))) * " << operand << ".z* 0.5f);\n";
ss << inpName << ".w=((1.0f + tanh(0.79788458f * (0.044715f * " << operand << ".w*" << operand << ".w*" << operand << ".w+" << operand + ".w))) * " << operand << ".w* 0.5f)";
}
} else {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << "=(half)((1.0f + tanh(0.79788458f * (0.044715f * (float)" << operand << "*(float)" << operand << "*(float)" << operand << "+(float)" << operand + "))) * (float)" << operand << "* 0.5f)";
} else {
ss << inpName << "=((1.0f + tanh(0.79788458f * (0.044715f * " << operand << "*" << operand << "*" << operand << "+" << operand + "))) * " << operand << "* 0.5f)";
}
}
break;
case UnaryOpOperation_GELU_STANDARD:
if(mVectorize) {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << ".x=(half)((erf((float)" << operand << ".x*0.7071067932881648f)+1.f)*(float)" << operand << ".x*0.5f);\n";
ss << inpName << ".y=(half)((erf((float)" << operand << ".y*0.7071067932881648f)+1.f)*(float)" << operand << ".y*0.5f)";
} else {
ss << inpName << ".x=((erf(" << operand << ".x*0.7071067932881648f)+1.f)*" << operand << ".x*0.5f);\n";
ss << inpName << ".y=((erf(" << operand << ".y*0.7071067932881648f)+1.f)*" << operand << ".y*0.5f)";
ss << ";\n";
ss << inpName << ".z=((erf(" << operand << ".z*0.7071067932881648f)+1.f)*" << operand << ".z*0.5f);\n";
ss << inpName << ".w=((erf(" << operand << ".w*0.7071067932881648f)+1.f)*" << operand << ".w*0.5f)";
}
} else {
if(mPrecision == BackendConfig::Precision_Low) {
ss << inpName << "=(half)((erf((float)" << operand << "*0.7071067932881648f)+1.f)*(float)" << operand << "*0.5f)";
} else {
ss << inpName << "=((erf(" << operand << "*0.7071067932881648f)+1.f)*" << operand << "*0.5f)";
}
}
break;
default:
MNN_PRINT("Error: CUDA CodeGen not support Unary type:%d\n", type);
break;
}
break;
}
case MNN::OpType_ReLU6:
{
auto operand = inputs[0];
auto relu6 = op->main_as_Relu6();
float minv = relu6->minValue();
float maxv = relu6->maxValue();
if(mVectorize) {
ss << inpName << ".x=fmin(fmax(" << operand << ".x," << numval(minv) << "), " << numval(maxv) << ");\n";
ss << inpName << ".y=fmin(fmax(" << operand << ".y," << numval(minv) << "), " << numval(maxv) << ")";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=fmin(fmax(" << operand << ".z," << numval(minv) << "), " << numval(maxv) << ");\n";
ss << inpName << ".w=fmin(fmax(" << operand << ".w," << numval(minv) << "), " << numval(maxv) << ")";
}
} else {
ss << inpName << "=fmin(fmax(" << operand << "," << numval(minv) << "), " << numval(maxv) << ")";
}
break;
}
case MNN::OpType_ReLU:
{
auto operand = inputs[0];
auto relu = op->main_as_Relu();
float slope = relu->slope();
if(mVectorize) {
ss << inpName << ".x=fmax(" << operand << ".x," << numval(0) << ");\n";
ss << inpName << ".y=fmax(" << operand << ".y," << numval(0) << ")";
if(mPrecision != BackendConfig::Precision_Low) {
ss << ";\n";
ss << inpName << ".z=fmax(" << operand << ".z," << numval(0) << ");\n";
ss << inpName << ".w=fmax(" << operand << ".w," << numval(0) << ")";
}
} else {
ss << inpName << "=fmax(" << operand << "," << numval(0) << ")";
}
break;
}
case MNN::OpType_Raster:
{
auto operand = inputs[0];
ss << inpName << "=(" << operand << ")";
break;
}
default:
break;
}
return ss.str();
}