private MMultMethod optFindMMultMethodSpark()

in src/main/java/org/apache/sysds/hops/AggBinaryOp.java [1050:1192]


	private MMultMethod optFindMMultMethodSpark( long m1_rows, long m1_cols, long m1_blen, long m1_nnz, 
		long m2_rows, long m2_cols, long m2_blen, long m2_nnz,
		MMTSJType mmtsj, ChainType chainType, boolean leftPMInput, boolean tmmRewrite ) 
	{
		//Notes: Any broadcast needs to fit twice in local memory because we partition the input in CP,
		//and needs to fit once in executor broadcast memory. The 2GB broadcast constraint no longer
		//applies because the max_int byte buffer constraint was fixed in Spark 1.4
		double memBudgetExec = MAPMULT_MEM_MULTIPLIER * SparkExecutionContext.getBroadcastMemoryBudget();
		double memBudgetLocal = OptimizerUtils.getLocalMemBudget();

		//reset spark broadcast memory information (for concurrent parfor jobs, and to account for
		//additional cp memory requirements of spark rdd operations with broadcasts)
		_spBroadcastMemEstimate = 0;
		
		// Step 0: check for forced mmultmethod
		if( FORCED_MMULT_METHOD !=null )
			return FORCED_MMULT_METHOD;
		
		// Step 1: check TSMM
		// If transpose self pattern and result is single block:
		// use specialized TSMM method (always better than generic jobs)
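		// (a single-block output, i.e., ncol(X)<=blocksize for LEFT or nrow(X)<=blocksize for RIGHT,
		// allows aggregating per-block partial products without shuffling the input)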
		if(    ( mmtsj == MMTSJType.LEFT && m2_cols>=0 && m2_cols <= m2_blen )
			|| ( mmtsj == MMTSJType.RIGHT && m1_rows>=0 && m1_rows <= m1_blen ) )
		{
			return MMultMethod.TSMM;
		}
		
		// Step 2: check MapMMChain
		// If mapmultchain pattern and result is a single block:
		// use specialized mapmult method
		if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES )
		{
			//matmultchain if dim2(X)<=blocksize and all vectors fit in mappers
			//(X: m1_cols x m1_rows, v: m1_rows x m2_cols, w: m1_cols x m2_cols) 
			//NOTE: generalization possible: m2_cols>=0 && m2_cols<=m2_blen
			if( chainType!=ChainType.NONE && m1_rows >=0 && m1_rows <= m1_blen && m2_cols==1 )
			{
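				//the XtXv case only accounts for broadcasting the vector v (m1_rows x m2_cols);
				//the XtwXv/XtXvy cases additionally account for w (m1_cols x m2_cols), hence the
				//second size term and the 2x factor for the local partitioned broadcast copies
				//(see the broadcast memory notes above)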
				if( chainType==ChainType.XtXv && m1_rows>=0 && m2_cols>=0 
					&& OptimizerUtils.estimateSize(m1_rows, m2_cols ) < memBudgetExec )
				{
					return MMultMethod.MAPMM_CHAIN;
				}
				else if( (chainType==ChainType.XtwXv || chainType==ChainType.XtXvy ) 
					&& m1_rows>=0 && m2_cols>=0 && m1_cols>=0
					&&   OptimizerUtils.estimateSize(m1_rows, m2_cols) 
					   + OptimizerUtils.estimateSize(m1_cols, m2_cols) < memBudgetExec
					&& 2*(OptimizerUtils.estimateSize(m1_rows, m2_cols) 
					   + OptimizerUtils.estimateSize(m1_cols, m2_cols)) < memBudgetLocal )
				{
					_spBroadcastMemEstimate = 2*(OptimizerUtils.estimateSize(m1_rows, m2_cols) 
						+ OptimizerUtils.estimateSize(m1_cols, m2_cols));
					return MMultMethod.MAPMM_CHAIN;
				}
			}
		}
		
		// Step 3: check for PMM (permutation matrix needs to fit into mapper memory)
		// (needs to be checked before mapmult for consistency with removeEmpty compilation)
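		// the permutation matrix (e.g., from removeEmpty rewrites) is represented compactly as a
		// column vector of target row positions, hence the m1_rows x 1 size used in these estimates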
		double footprintPM1 = getMapmmMemEstimate(m1_rows, 1, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 1, true);
		double footprintPM2 = getMapmmMemEstimate(m2_rows, 1, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 1, true);
		if( (footprintPM1 < memBudgetExec && m1_rows>=0 || footprintPM2 < memBudgetExec && m2_rows>=0)
			&& 2*OptimizerUtils.estimateSize(m1_rows, 1) < memBudgetLocal
			&& leftPMInput ) 
		{
			_spBroadcastMemEstimate = 2*OptimizerUtils.estimateSize(m1_rows, 1);
			return MMultMethod.PMM;
		}
		
		// Step 4: check MapMM
		// If the size of one input is small, choose a method that uses broadcast variables to prevent shuffle
		
		//memory estimates for local partitioning (mb -> partitioned mb)
		double m1Size = OptimizerUtils.estimateSizeExactSparsity(m1_rows, m1_cols, m1_nnz); //m1 single block
		double m2Size = OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols, m2_nnz); //m2 single block
		double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_blen, m1_nnz); //m1 partitioned 
		double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols, m2_blen, m2_nnz); //m2 partitioned
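		//the local (driver) budget must cover both the original single-block matrix and its
		//partitioned broadcast copy, hence the mSize+mSizeP sums checked below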
		
		//memory estimates for remote execution (broadcast and outputs)
		double footprint1 = getMapmmMemEstimate(m1_rows, m1_cols, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 1, false);
		double footprint2 = getMapmmMemEstimate(m1_rows, m1_cols, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 2, false);		
		
		if (   (footprint1 < memBudgetExec && m1Size+m1SizeP < memBudgetLocal && m1_rows>=0 && m1_cols>=0)
			|| (footprint2 < memBudgetExec && m2Size+m2SizeP < memBudgetLocal && m2_rows>=0 && m2_cols>=0) ) 
		{
			//apply map mult if one side fits in remote task memory 
			//(if so pick smaller input for distributed cache)
			//TODO relax requirement of valid CP dimensions once we support broadcast creation from files/RDDs
			double em1Size = getInput().get(0).getOutputMemEstimate(); //w/ worst-case estimate
			double em2Size = getInput().get(1).getOutputMemEstimate(); //w/ worst-case estimate
			if( (m1SizeP < m2SizeP || (m1SizeP==m2SizeP && em1Size<em2Size) )
				&& m1_rows>=0 && m1_cols>=0
				&& OptimizerUtils.isValidCPDimensions(m1_rows, m1_cols) ) {
				_spBroadcastMemEstimate = m1Size+m1SizeP;
				return MMultMethod.MAPMM_L;
			}
			else if( OptimizerUtils.isValidCPDimensions(m2_rows, m2_cols) ) {
				_spBroadcastMemEstimate = m2Size+m2SizeP;
				return MMultMethod.MAPMM_R;
			}
		}
		
		// Step 5: check for TSMM2 (2-pass w/o shuffle, preferred over CPMM/RMM)
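		// TSMM2 broadcasts only the portion of the input beyond the first block column (LEFT) or
		// block row (RIGHT), and applies only if the output spans at most 2x2 blocks (see checks below)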
		if( mmtsj != MMTSJType.NONE && m1_rows >=0 && m1_cols>=0 
			&& m2_rows >= 0 && m2_cols>=0 )
		{
			double mSize = (mmtsj == MMTSJType.LEFT) ? 
					OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols-m2_blen, 1.0) : 
					OptimizerUtils.estimateSizeExactSparsity(m1_rows-m1_blen, m1_cols, 1.0);
			double mSizeP = (mmtsj == MMTSJType.LEFT) ? 
					OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols-m2_blen, m2_blen, 1.0) : 
					OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows-m1_blen, m1_cols, m1_blen, 1.0); 
			if( mSizeP < memBudgetExec && mSize+mSizeP < memBudgetLocal 
				&& ((mmtsj == MMTSJType.LEFT) ? m2_cols<=2*m2_blen : m1_rows<=2*m1_blen) //4 output blocks
				&& mSizeP < 2L*1024*1024*1024) { //2GB limitation as single broadcast
				return MMultMethod.TSMM2;
			}
		}
		
		// Step 6: check for unknowns
		// If the dimensions are unknown at compilation time, simply assume 
		// the worst-case scenario and produce the most robust plan -- which is CPMM
		if ( m1_rows == -1 || m1_cols == -1 || m2_rows == -1 || m2_cols == -1 )
			return MMultMethod.CPMM;

		// Step 7: check for ZIPMM
		// If t(X)%*%y -> t(t(y)%*%X) rewrite and ncol(X)<blocksize
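		// (with both blocksize constraints the ncol(X) x ncol(y) result is a single block, so
		//  co-partitioned row blocks of X and y can be multiplied pairwise and the partial
		//  products aggregated without shuffling the inputs)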
		if( tmmRewrite && m1_rows >= 0 && m1_rows <= m1_blen  //blocksize constraint left
			&& m2_cols >= 0 && m2_cols <= m2_blen )           //blocksize constraint right
		{
			return MMultMethod.ZIPMM;
		}
		
		// Step 8: Decide CPMM vs RMM based on io costs
		//estimate shuffle costs weighted by parallelism
		//TODO currently we reuse the MR estimates; these need to be fine-tuned for our Spark operators
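		//RMM performs the multiply in a single job by replicating blocks of both inputs, whereas
		//CPMM joins the inputs on the common dimension and aggregates partial products in a
		//second job; whichever has the lower estimated cost is chosen below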
		double rmm_costs = getRMMCostEstimate(m1_rows, m1_cols, m1_blen, m2_rows, m2_cols, m2_blen);
		double cpmm_costs = getCPMMCostEstimate(m1_rows, m1_cols, m1_blen, m2_rows, m2_cols, m2_blen);
		
		//final mmult method decision 
		if ( cpmm_costs < rmm_costs ) 
			return MMultMethod.CPMM;
		return MMultMethod.RMM;
	}