public String getTemplate()

in src/main/java/org/apache/sysds/hops/codegen/cplan/cuda/Binary.java [30:196]


	public String getTemplate(CNodeBinary.BinType type, boolean sparseLhs, boolean sparseRhs,
		boolean scalarVector, boolean scalarInput, boolean vectorVector)
	{
		if(type == CNodeBinary.BinType.VECT_CBIND) {
			if(scalarInput)
				return "\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, this);\n";
			else if (!vectorVector)
				return sparseLhs ? 
					"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n" :
					"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, %POS1%, %LEN%, this);\n";
			else //vect/vect
				return sparseLhs ?
					"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, %LEN1%, %LEN2%, this);\n" :
					"\t\tVector<T>& %TMP% = vectCbindWrite(%IN1%, %IN2%, %POS1%, %POS2%, %LEN1%, %LEN2%, this);\n";
		}

		switch(type) {
			case ROWMAXS_VECTMULT:
				return sparseLhs ? "\t\tT %TMP% = rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
						"\t\tT %TMP% = rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
			case DOT_PRODUCT:
				return sparseLhs ? "\t\tT %TMP% = dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : "		T %TMP% = dotProduct(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%);\n";

			case VECT_MATRIXMULT:
				return sparseLhs ? "	T[] %TMP% = vectMatrixMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen, len);\n" : "		Vector<T>& %TMP% = vectMatrixMult(%IN1%, %IN2%, %POS1%, static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
			case VECT_OUTERMULT_ADD:
				return sparseLhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : sparseRhs ? "	LibSpoofPrimitives.vectOuterMultAdd(%IN1%, %IN2v%, %OUT%, %POS1%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN1%, %LEN2%);\n" : "\t\tvectOuterMultAdd(%IN1%, %IN2%, %OUT%, %POS1%, %POS2%, %POSOUT%, %LEN1%, %LEN2%);\n";

			//vector-scalar-add operations
			case VECT_MULT_ADD:
			case VECT_DIV_ADD:
			case VECT_MINUS_ADD:
			case VECT_PLUS_ADD:
			case VECT_POW_ADD:
			case VECT_XOR_ADD:
			case VECT_MIN_ADD:
			case VECT_MAX_ADD:
			case VECT_EQUAL_ADD:
			case VECT_NOTEQUAL_ADD:
			case VECT_LESS_ADD:
			case VECT_LESSEQUAL_ADD:
			case VECT_GREATER_ADD:
			case VECT_GREATEREQUAL_ADD:
			case VECT_CBIND_ADD: {
				String vectName = type.getVectorPrimitiveName();
				if(scalarVector)
					return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1%, %IN2v%, %OUT%, %IN2i%, %POS2%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS2%, %POSOUT%, %LEN%);\n";
				else
					return sparseLhs ? "\t\tvect" + vectName + "Add(%IN1v%, %IN2%, %OUT%, %IN1i%, %POS1%, %POSOUT%, alen, %LEN%);\n" : "\t\tvect" + vectName + "Add(%IN1%, %IN2%, %OUT%, %POS1%, static_cast<uint32_t>(%POSOUT%), %LEN%);\n";
			}

			//vector-scalar operations
			case VECT_MULT_SCALAR:
			case VECT_DIV_SCALAR:
			case VECT_MINUS_SCALAR:
			case VECT_PLUS_SCALAR:
			case VECT_POW_SCALAR:
			case VECT_XOR_SCALAR:
			case VECT_BITWAND_SCALAR:
			case VECT_MIN_SCALAR:
			case VECT_MAX_SCALAR:
			case VECT_EQUAL_SCALAR:
			case VECT_NOTEQUAL_SCALAR:
			case VECT_LESS_SCALAR:
			case VECT_LESSEQUAL_SCALAR:
			case VECT_GREATER_SCALAR:
			case VECT_GREATEREQUAL_SCALAR: {
				String vectName = type.getVectorPrimitiveName();
				if(scalarVector)
					return sparseRhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2v%, %IN2i%, %POS2%, alen, %LEN%);\n" : "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, %POS2%, %LEN%, this);\n";
				else
//						return sparseLhs ? "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : "	T[] %TMP% = LibSpoofPrimitives.vect" + vectName + "Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
					return sparseLhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%, this);\n" : "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, static_cast<uint32_t>(%POS1%), %LEN%, this);\n";
			}

			//vector-vector operations
			case VECT_MULT:
			case VECT_DIV:
			case VECT_MINUS:
			case VECT_PLUS:
			case VECT_XOR:
			case VECT_BITWAND:
			case VECT_BIASADD:
			case VECT_BIASMULT:
			case VECT_MIN:
			case VECT_MAX:
			case VECT_EQUAL:
			case VECT_NOTEQUAL:
			case VECT_LESS:
			case VECT_LESSEQUAL:
			case VECT_GREATER:
			case VECT_GREATEREQUAL: {
				String vectName = type.getVectorPrimitiveName();
				return sparseLhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1v%, %IN2%, %IN1i%, " +
					   "%POS1%, %POS2%, alen, %LEN%, this);\n" :
					   sparseRhs ? "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2v%, %POS1%, " +
					   "%IN2i%, %POS2%, alen, %LEN%);\n" :
					   "		Vector<T>& %TMP% = vect" + vectName + "Write(%IN1%, %IN2%, " +
					   "static_cast<uint32_t>(%POS1%), static_cast<uint32_t>(%POS2%), %LEN%, this);\n";
			}

			//scalar-scalar operations
			case MULT:
				return "		T %TMP% = %IN1% * %IN2%;\n";
			case DIV:
				return "\t\tT %TMP% = %IN1% / %IN2%;\n";
			case PLUS:
				return "\t\tT %TMP% = %IN1% + %IN2%;\n";
			case MINUS:
				return "	T %TMP% = %IN1% - %IN2%;\n";
			case MODULUS:
				return "	T %TMP% = modulus(%IN1%, %IN2%);\n";
			case INTDIV:
				return "	T %TMP% = intDiv(%IN1%, %IN2%);\n";
			case LESS:
				return "	T %TMP% = (%IN1% < %IN2%) ? 1.0 : 0.0;\n";
			case LESSEQUAL:
				return "	T %TMP% = (%IN1% <= %IN2%) ? 1.0 : 0.0;\n";
			case GREATER:
				return "	T %TMP% = (%IN1% > (%IN2% + EPSILON)) ? 1.0 : 0.0;\n";
			case GREATEREQUAL:
				return "	T %TMP% = (%IN1% >= %IN2%) ? 1.0 : 0.0;\n";
			case EQUAL:
				return "	T %TMP% = (%IN1% == %IN2%) ? 1.0 : 0.0;\n";
			case NOTEQUAL:
				return "\t\tT %TMP% = (%IN1% != %IN2%) ? 1.0 : 0.0;\n";
			case MIN:
				if(isSinglePrecision())
					return "\t\tT %TMP% = fminf(%IN1%, %IN2%);\n";
				else
					return "\t\tT %TMP% = min(%IN1%, %IN2%);\n";
			case MAX:
				if(isSinglePrecision())
					return "\t\tT %TMP% = fmaxf(%IN1%, %IN2%);\n";
				else
					return "\t\tT %TMP% = max(%IN1%, %IN2%);\n";
			case LOG:
				if(isSinglePrecision())
					return "\t\tT %TMP% = logf(%IN1%) / logf(%IN2%);\n";
				else
					return "\t\tT %TMP% = log(%IN1%) / log(%IN2%);\n";
			case LOG_NZ:
				if(isSinglePrecision())
					return "\t\tT %TMP% = (%IN1% == 0) ? 0 : logf(%IN1%) / logf(%IN2%);\n";
				else
					return "\t\tT %TMP% = (%IN1% == 0) ? 0 : log(%IN1%) / log(%IN2%);\n";
			case POW:
				if(isSinglePrecision())
					return "\t\tT %TMP% = powf(%IN1%, %IN2%);\n";
				else
					return "\t\tT %TMP% = pow(%IN1%, %IN2%);\n";
			case MINUS1_MULT:
				return "	T %TMP% = 1 - %IN1% * %IN2%;\n";
			case MINUS_NZ:
				return "	T %TMP% = (%IN1% != 0) ? %IN1% - %IN2% : 0;\n";
			case XOR:
//					return "	T %TMP% = ( (%IN1% != 0.0) != (%IN2% != 0.0) ) ? 1.0 : 0.0;\n";
				return "	T %TMP% = ( (%IN1% < EPSILON) != (%IN2% < EPSILON) ) ? 1.0 : 0.0;\n";
			case BITWAND:
				return "	T %TMP% = bwAnd(%IN1%, %IN2%);\n";
			case SEQ_RIX:
				return "\t\tT %TMP% = %IN1% + grix * %IN2%;\n"; //0-based global rix

			default:
				throw new RuntimeException("Invalid binary type: " + this.toString());
		}
	}