in src/main/java/org/apache/sysds/hops/AggBinaryOp.java [1050:1192]
private MMultMethod optFindMMultMethodSpark( long m1_rows, long m1_cols, long m1_blen, long m1_nnz,
long m2_rows, long m2_cols, long m2_blen, long m2_nnz,
MMTSJType mmtsj, ChainType chainType, boolean leftPMInput, boolean tmmRewrite )
{
		//Notes: Any broadcast needs to fit twice in local memory because we create a
		//partitioned copy of the input in cp, and needs to fit once in the executor
		//broadcast memory. The 2GB broadcast constraint is no longer required because
		//the max_int byte buffer constraint was fixed in Spark 1.4.
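		//Example (hypothetical sizes): broadcasting a dense 10^4 x 10^3 input (~80MB)
		//requires ~160MB locally (original block plus partitioned copy) and ~80MB of
		//executor broadcast memory; the checks below follow this 2x-local/1x-executor pattern.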
double memBudgetExec = MAPMULT_MEM_MULTIPLIER * SparkExecutionContext.getBroadcastMemoryBudget();
double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
		//reset spark broadcast memory information (for concurrent parfor jobs, to account
		//for additional cp memory requirements of spark rdd operations with broadcasts)
_spBroadcastMemEstimate = 0;
// Step 0: check for forced mmultmethod
if( FORCED_MMULT_METHOD !=null )
return FORCED_MMULT_METHOD;
// Step 1: check TSMM
		// If transpose-self pattern and the result is a single block:
		// use specialized TSMM method (always better than generic jobs)
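		//Example (hypothetical sizes): for t(X)%*%X (MMTSJType.LEFT) with X: 10^7 x 500
		//and blen=1000, m2_cols=500 <= blen, so the 500x500 result is a single block that
		//can be aggregated from per-block partials t(Xi)%*%Xi without shuffling X.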
if( ( mmtsj == MMTSJType.LEFT && m2_cols>=0 && m2_cols <= m2_blen )
|| ( mmtsj == MMTSJType.RIGHT && m1_rows>=0 && m1_rows <= m1_blen ) )
{
return MMultMethod.TSMM;
}
// Step 2: check MapMMChain
		// If mapmultchain pattern and the result is a single block:
		// use specialized mapmult method
if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES )
{
//matmultchain if dim2(X)<=blocksize and all vectors fit in mappers
//(X: m1_cols x m1_rows, v: m1_rows x m2_cols, w: m1_cols x m2_cols)
			//NOTE: generalization possible: m2_cols>=0 && m2_cols<=m2_blen
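			//Example (hypothetical sizes): for the XtwXv chain t(X)%*%(w*(X%*%v)) with
			//X: 10^7 x 1000 and blen=1000, the broadcast vectors v (1000 x 1, ~8KB) and
			//w (10^7 x 1, ~80MB) must fit once in the executor budget and twice locally.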
if( chainType!=ChainType.NONE && m1_rows >=0 && m1_rows <= m1_blen && m2_cols==1 )
{
if( chainType==ChainType.XtXv && m1_rows>=0 && m2_cols>=0
&& OptimizerUtils.estimateSize(m1_rows, m2_cols ) < memBudgetExec )
{
return MMultMethod.MAPMM_CHAIN;
}
else if( (chainType==ChainType.XtwXv || chainType==ChainType.XtXvy )
&& m1_rows>=0 && m2_cols>=0 && m1_cols>=0
&& OptimizerUtils.estimateSize(m1_rows, m2_cols)
+ OptimizerUtils.estimateSize(m1_cols, m2_cols) < memBudgetExec
&& 2*(OptimizerUtils.estimateSize(m1_rows, m2_cols)
+ OptimizerUtils.estimateSize(m1_cols, m2_cols)) < memBudgetLocal )
{
_spBroadcastMemEstimate = 2*(OptimizerUtils.estimateSize(m1_rows, m2_cols)
+ OptimizerUtils.estimateSize(m1_cols, m2_cols));
return MMultMethod.MAPMM_CHAIN;
}
}
}
		// Step 3: check for PMM (permutation matrix needs to fit into mapper memory)
		// (needs to be checked before mapmult for consistency with removeEmpty compilation)
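		//Note: PMM handles products with a permutation matrix (e.g., compiled from
		//removeEmpty rewrites); the permutation is representable as a single column vector
		//of target row indexes, hence the estimateSize(m1_rows, 1) broadcast estimates below.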
double footprintPM1 = getMapmmMemEstimate(m1_rows, 1, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 1, true);
double footprintPM2 = getMapmmMemEstimate(m2_rows, 1, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 1, true);
if( (footprintPM1 < memBudgetExec && m1_rows>=0 || footprintPM2 < memBudgetExec && m2_rows>=0)
&& 2*OptimizerUtils.estimateSize(m1_rows, 1) < memBudgetLocal
&& leftPMInput )
{
_spBroadcastMemEstimate = 2*OptimizerUtils.estimateSize(m1_rows, 1);
return MMultMethod.PMM;
}
// Step 4: check MapMM
// If the size of one input is small, choose a method that uses broadcast variables to prevent shuffle
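		//Example (hypothetical sizes): for X %*% Y with X: 10^7 x 1000 (~80GB dense) and
		//Y: 1000 x 100 (~800KB), broadcasting Y (MAPMM_R) lets each task multiply its local
		//blocks of X with Y, avoiding a shuffle of the 80GB input.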
//memory estimates for local partitioning (mb -> partitioned mb)
double m1Size = OptimizerUtils.estimateSizeExactSparsity(m1_rows, m1_cols, m1_nnz); //m1 single block
double m2Size = OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols, m2_nnz); //m2 single block
double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_blen, m1_nnz); //m1 partitioned
double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols, m2_blen, m2_nnz); //m2 partitioned
//memory estimates for remote execution (broadcast and outputs)
double footprint1 = getMapmmMemEstimate(m1_rows, m1_cols, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 1, false);
double footprint2 = getMapmmMemEstimate(m1_rows, m1_cols, m1_blen, m1_nnz, m2_rows, m2_cols, m2_blen, m2_nnz, 2, false);
if ( (footprint1 < memBudgetExec && m1Size+m1SizeP < memBudgetLocal && m1_rows>=0 && m1_cols>=0)
|| (footprint2 < memBudgetExec && m2Size+m2SizeP < memBudgetLocal && m2_rows>=0 && m2_cols>=0) )
{
//apply map mult if one side fits in remote task memory
//(if so pick smaller input for distributed cache)
//TODO relax requirement of valid CP dimensions once we support broadcast creation from files/RDDs
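			//(on equal partitioned sizes, the smaller worst-case output estimate below
			//breaks the tie for the broadcast side)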
double em1Size = getInput().get(0).getOutputMemEstimate(); //w/ worst-case estimate
double em2Size = getInput().get(1).getOutputMemEstimate(); //w/ worst-case estimate
if( (m1SizeP < m2SizeP || (m1SizeP==m2SizeP && em1Size<em2Size) )
&& m1_rows>=0 && m1_cols>=0
&& OptimizerUtils.isValidCPDimensions(m1_rows, m1_cols) ) {
_spBroadcastMemEstimate = m1Size+m1SizeP;
return MMultMethod.MAPMM_L;
}
else if( OptimizerUtils.isValidCPDimensions(m2_rows, m2_cols) ) {
_spBroadcastMemEstimate = m2Size+m2SizeP;
return MMultMethod.MAPMM_R;
}
}
		// Step 5: check for TSMM2 (2 passes w/o shuffle, preferred over CPMM/RMM)
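		//Example (hypothetical sizes): for t(X)%*%X with X: 10^5 x 1500 and blen=1000,
		//the result spans 2x2=4 output blocks; TSMM2 broadcasts the remaining column-block
		//of X (10^5 x 500, ~400MB dense) and computes all four result blocks in two passes.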
if( mmtsj != MMTSJType.NONE && m1_rows >=0 && m1_cols>=0
&& m2_rows >= 0 && m2_cols>=0 )
{
double mSize = (mmtsj == MMTSJType.LEFT) ?
OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols-m2_blen, 1.0) :
OptimizerUtils.estimateSizeExactSparsity(m1_rows-m1_blen, m1_cols, 1.0);
double mSizeP = (mmtsj == MMTSJType.LEFT) ?
OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols-m2_blen, m2_blen, 1.0) :
OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows-m1_blen, m1_cols, m1_blen, 1.0);
if( mSizeP < memBudgetExec && mSize+mSizeP < memBudgetLocal
&& ((mmtsj == MMTSJType.LEFT) ? m2_cols<=2*m2_blen : m1_rows<=2*m1_blen) //4 output blocks
&& mSizeP < 2L*1024*1024*1024) { //2GB limitation as single broadcast
return MMultMethod.TSMM2;
}
}
// Step 6: check for unknowns
// If the dimensions are unknown at compilation time, simply assume
// the worst-case scenario and produce the most robust plan -- which is CPMM
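		//(CPMM is robust because it joins both inputs on the common dimension via shuffle
		//and aggregates partial products, without any size-dependent broadcasts)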
if ( m1_rows == -1 || m1_cols == -1 || m2_rows == -1 || m2_cols == -1 )
return MMultMethod.CPMM;
// Step 7: check for ZIPMM
		// If t(X)%*%y -> t(t(y)%*%X) rewrite applies and ncol(X)<=blocksize
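		//Example (hypothetical sizes): for t(X)%*%y with X: 10^8 x 800, y: 10^8 x 1, and
		//blen=1000, both block-size constraints hold; ZIPMM joins the row-aligned blocks of
		//X and y and aggregates the partial products into the single-block result.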
if( tmmRewrite && m1_rows >= 0 && m1_rows <= m1_blen //blocksize constraint left
&& m2_cols >= 0 && m2_cols <= m2_blen ) //blocksize constraint right
{
return MMultMethod.ZIPMM;
}
		// Step 8: Decide CPMM vs RMM based on IO costs
		//estimate shuffle costs weighted by parallelism
		//TODO currently we reuse the MR estimates; these need to be fine-tuned for our spark operators
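		//Background: RMM replicates blocks of both inputs in a single shuffle job, whereas
		//CPMM shuffles both inputs once by the common dimension and aggregates partial
		//products in a second job; the estimates below weigh these IO costs by parallelism.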
double rmm_costs = getRMMCostEstimate(m1_rows, m1_cols, m1_blen, m2_rows, m2_cols, m2_blen);
double cpmm_costs = getCPMMCostEstimate(m1_rows, m1_cols, m1_blen, m2_rows, m2_cols, m2_blen);
//final mmult method decision
if ( cpmm_costs < rmm_costs )
return MMultMethod.CPMM;
return MMultMethod.RMM;
}