in src/contrib/msc/core/utils.cc [391:488]
/*!
 * \brief Assign a role name to every input of an operator, in argument order.
 * \param optype The operator type, e.g. "nn.conv2d" or "msc.linear_bias".
 * \param inputs_num The number of inputs to name; the result must have exactly this size.
 * \param as_relax Enables the relax-specific layouts: broadcast_to, reshape, clip, full and
 *        image.resize2d only get dedicated role names when it is true, and it allows the
 *        optional trailing "expand_bias" input of msc.conv1d_bias / msc.conv2d_bias.
 * \return One role string per input. Op types without a dedicated layout (or not matching
 *         the required as_relax / inputs_num condition) get "input" for every slot.
 * \note Fails with ICHECK_EQ when the derived role count differs from inputs_num.
 */
const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t inputs_num,
                                             bool as_relax) {
  Array<String> input_types;
  // Appends every given role name, left to right.
  auto append = [&input_types](auto... roles) { (input_types.push_back(roles), ...); };
  // Most layouts name only the data input unless further inputs were supplied.
  const bool has_params = inputs_num > 1;

  // Layouts that only apply in relax mode.
  if (as_relax && (optype == "broadcast_to" || optype == "reshape")) {
    append("input");
    if (has_params) append("shape");
  } else if (as_relax && optype == "clip") {
    append("input");
    if (has_params) append("min", "max");
  } else if (as_relax && optype == "full") {
    append("shape", "input");
  } else if (as_relax && optype == "image.resize2d") {
    append("input");
    if (has_params) append("size");
    // Layouts shared by every frontend.
  } else if (optype == "strided_slice") {
    append("input");
    if (has_params) append("axes", "begin", "end", "strides");
  } else if (optype == "triu" || optype == "tril" || optype == "trilu") {
    append("input", "k");
  } else if (optype == "nn.conv1d" || optype == "nn.conv2d" || optype == "nn.conv3d" ||
             optype == "msc.linear") {
    append("input");
    if (has_params) append("weight");
  } else if (optype == "nn.batch_norm") {
    append("input");
    if (has_params) append("gamma", "beta", "mean", "var");
  } else if (optype == "nn.layer_norm" || optype == "nn.group_norm") {
    append("input");
    if (has_params) append("gamma", "beta");
  } else if (optype == "msc.conv1d_bias" || optype == "msc.conv2d_bias") {
    append("input");
    if (has_params) append("weight", "bias");
    // In relax mode a fourth (or later) input is labelled "expand_bias".
    if (as_relax && inputs_num > 3) append("expand_bias");
  } else if (optype == "msc.linear_bias") {
    append("input");
    if (has_params) append("weight", "bias");
  } else if (optype == "msc.embedding" && inputs_num == 2) {
    append("input", "weight");
  } else if (optype == "msc.embedding" && inputs_num == 4) {
    append("input", "reduce_in", "weight", "expand_out");
  } else if (optype == "msc.gelu") {
    append("input", "factor_1", "factor_2", "factor_3");
  } else {
    // No dedicated layout: treat every slot as a plain data input.
    for (size_t idx = 0; idx < inputs_num; ++idx) append("input");
  }

  ICHECK_EQ(input_types.size(), inputs_num)
      << "Optype " << optype << " get input types " << input_types << " and inputs_num "
      << inputs_num << " mismatch";
  return input_types;
}