public void incrementalAggregate()

in src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java [3151:3381]


	/**
	 * Merges a partial aggregate into this block in place, keeping the auxiliary
	 * correction values in a separate block that is updated in place as well.
	 * <p>
	 * The code path is chosen by {@code aggOp.correction}:
	 * <ul>
	 * <li>{@code LASTROW}: per cell (r,c) the running correction is kept in row 0 of
	 * {@code correction}; the values at (r,c) and (r+1,c) of {@code newWithCorrection}
	 * are both handed to the aggregation function (the second one is presumably the
	 * partial's correction term).</li>
	 * <li>{@code LASTCOLUMN}: same idea with the correction in column 0 of
	 * {@code correction} and the extra value read from column c+1. For the builtin
	 * {@code MAXINDEX}/{@code MININDEX} functions a special-case path is used, see the
	 * inline comment.</li>
	 * <li>{@code NONE}: {@code correction} has the same cell layout as this block.
	 * {@code KahanPlus} is delegated to {@code LibMatrixAgg.aggregateBinaryMatrix};
	 * other functions are applied cell by cell, followed by {@code examSparsity()}.</li>
	 * <li>{@code LASTTWOROWS} / {@code LASTTWOCOLUMNS}: incremental update of the form
	 * {@code sum += (mu2 - sum) * n2 / (n + n2)}, where n and the Kahan correction are
	 * kept in rows (columns) 0 and 1 of {@code correction}.</li>
	 * <li>{@code LASTFOURROWS} / {@code LASTFOURCOLUMNS}: only for a {@code CM} function
	 * of type {@code VARIANCE}; mean, count, m2 correction and mean correction are kept
	 * in rows (columns) 0..3 of {@code correction}.</li>
	 * </ul>
	 *
	 * @param aggOp aggregate operator; supplies the correction location, the
	 *        increment function ({@code increOp.fn}) and the {@code sparseSafe} flag
	 * @param correction block with the current correction values (converted via
	 *        {@code checkType}); modified in place
	 * @param newWithCorrection partial result including its own correction entries
	 *        (converted via {@code checkType}); only read by the code in this method
	 * @param deep forwarded unchanged to {@code LibMatrixAgg.aggregateBinaryMatrix};
	 *        not read on any other path
	 * @throws DMLRuntimeException if no branch matches, which includes
	 *         {@code LASTFOURROWS}/{@code LASTFOURCOLUMNS} combined with a function
	 *         that is not a {@code CM} of type {@code VARIANCE}
	 */
	public void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, MatrixValue newWithCorrection, boolean deep) {
		final MatrixBlock corr = checkType(correction);
		final MatrixBlock part = checkType(newWithCorrection);
		// scratch object shared by all Kahan-based paths below
		KahanObject kahan = new KahanObject(0, 0);
		final CorrectionLocationType where = aggOp.correction;

		if (where == CorrectionLocationType.LASTROW) {
			for (int i = 0; i < rlen; i++) {
				for (int j = 0; j < clen; j++) {
					kahan._sum = get(i, j);
					kahan._correction = corr.get(0, j);
					kahan = (KahanObject) aggOp.increOp.fn.execute(
						kahan, part.get(i, j), part.get(i + 1, j));
					set(i, j, kahan._sum);
					corr.set(0, j, kahan._correction);
				}
			}
			return;
		}

		if (where == CorrectionLocationType.LASTCOLUMN) {
			boolean indexAgg = false;
			if (aggOp.increOp.fn instanceof Builtin) {
				final Builtin builtin = (Builtin) aggOp.increOp.fn;
				indexAgg = builtin.bFunc == Builtin.BuiltinCode.MAXINDEX
					|| builtin.bFunc == Builtin.BuiltinCode.MININDEX;
			}

			if (indexAgg) {
				// *** HACK ALERT ***
				// Index aggregates (rowIndexMax() and friends) do not fit the standard
				// aggregate framework, so the "correction" block is (ab)used to carry
				// the best value seen so far per row, while column 0 of this block
				// carries the index of that value.
				//
				// execute() only compares the two candidate values; the index
				// bookkeeping (ties resolved in favor of the higher index) happens here.
				// Note that both versions of incrementalAggregate() contain very similar
				// special-case code; if one is modified, the other has to follow.
				for (int i = 0; i < rlen; i++) {
					final double bestValue = corr.get(i, 0);
					final long candIndex = (long) part.get(i, 0);
					final double candValue = part.get(i, 1);
					final double verdict = aggOp.increOp.fn.execute(candValue, bestValue);

					if (verdict == 1.0) {
						// 1 ==> candidate wins: take over its index and value
						set(i, 0, candIndex);
						corr.set(i, 0, candValue);
					}
					else if (verdict == 2.0) {
						// 2 ==> tie on the value: keep the larger of the two indexes
						final long keptIndex = (long) get(i, 0);
						set(i, 0, Math.max(keptIndex, candIndex));
					}
					// any other verdict ==> the current answer stays untouched
				}
				// *** END HACK ***
			}
			else {
				for (int i = 0; i < rlen; i++) {
					for (int j = 0; j < clen; j++) {
						kahan._sum = get(i, j);
						kahan._correction = corr.get(i, 0);
						kahan = (KahanObject) aggOp.increOp.fn.execute(
							kahan, part.get(i, j), part.get(i, j + 1));
						set(i, j, kahan._sum);
						corr.set(i, 0, kahan._correction);
					}
				}
			}
			return;
		}

		if (where == CorrectionLocationType.NONE) {
			// e.g., ak+ (kahan plus) as used in sum, mapmult, mmcj and tsmm
			if (aggOp.increOp.fn instanceof KahanPlus) {
				LibMatrixAgg.aggregateBinaryMatrix(part, this, corr, deep);
				return;
			}

			if (part.isInSparseFormat() && aggOp.sparseSafe) {
				// SPARSE: visit only the stored entries of the partial block
				final SparseBlock sblock = part.getSparseBlock();
				if (sblock == null) {
					// empty partial block: nothing to do (examSparsity is skipped too)
					return;
				}
				final int rowLimit = Math.min(rlen, sblock.numRows());
				for (int i = 0; i < rowLimit; i++) {
					if (sblock.isEmpty(i)) {
						continue;
					}
					final int start = sblock.pos(i);
					final int end = start + sblock.size(i);
					final int[] colIndexes = sblock.indexes(i);
					final double[] values = sblock.values(i);
					for (int k = start; k < end; k++) {
						final int j = colIndexes[k];
						kahan._sum = get(i, j);
						kahan._correction = corr.get(i, j);
						kahan = (KahanObject) aggOp.increOp.fn.execute(kahan, values[k]);
						set(i, j, kahan._sum);
						corr.set(i, j, kahan._correction);
					}
				}
			}
			else {
				// DENSE, or SPARSE with a function that is not sparse-safe: visit every cell
				for (int i = 0; i < rlen; i++) {
					for (int j = 0; j < clen; j++) {
						kahan._sum = get(i, j);
						kahan._correction = corr.get(i, j);
						kahan = (KahanObject) aggOp.increOp.fn.execute(kahan, part.get(i, j));
						set(i, j, kahan._sum);
						corr.set(i, j, kahan._correction);
					}
				}
			}

			// change representation if required
			// (note since ak+ on blocks is currently only applied in MR, hence no need
			// to account for this in mem estimates)
			examSparsity();
			return;
		}

		if (where == CorrectionLocationType.LASTTWOROWS) {
			// correction rows: 0 = running n, 1 = Kahan correction
			for (int i = 0; i < rlen; i++) {
				for (int j = 0; j < clen; j++) {
					kahan._sum = get(i, j);
					final double curN = corr.get(0, j);
					kahan._correction = corr.get(1, j);
					final double partMu = part.get(i, j);
					final double partN = part.get(i + 1, j);
					final double totalN = curN + partN;
					final double delta = (partMu - kahan._sum) * partN / totalN;
					kahan = (KahanObject) aggOp.increOp.fn.execute(kahan, delta);
					set(i, j, kahan._sum);
					corr.set(0, j, totalN);
					corr.set(1, j, kahan._correction);
				}
			}
			return;
		}

		if (where == CorrectionLocationType.LASTTWOCOLUMNS) {
			// correction columns: 0 = running n, 1 = Kahan correction
			for (int i = 0; i < rlen; i++) {
				for (int j = 0; j < clen; j++) {
					kahan._sum = get(i, j);
					final double curN = corr.get(i, 0);
					kahan._correction = corr.get(i, 1);
					final double partMu = part.get(i, j);
					final double partN = part.get(i, j + 1);
					final double totalN = curN + partN;
					final double delta = (partMu - kahan._sum) * partN / totalN;
					kahan = (KahanObject) aggOp.increOp.fn.execute(kahan, delta);
					set(i, j, kahan._sum);
					corr.set(i, 0, totalN);
					corr.set(i, 1, kahan._correction);
				}
			}
			return;
		}

		if (where == CorrectionLocationType.LASTFOURROWS
			&& aggOp.increOp.fn instanceof CM
			&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE)
		{
			// reusable holders for the current and the incoming partial state
			CM_COV_Object cur = new CM_COV_Object();
			CM_COV_Object inc = new CM_COV_Object();

			for (int i = 0; i < rlen; i++) {
				for (int j = 0; j < clen; j++) {
					// current state: { var | mean, count, m2 correction, mean correction }
					// with m2 = var * (n - 1)
					cur.w = corr.get(1, j);
					cur.m2._sum = get(i, j) * (cur.w - 1);
					cur.mean._sum = corr.get(0, j);
					cur.m2._correction = corr.get(2, j);
					cur.mean._correction = corr.get(3, j);

					// partial state, same layout, stored in the rows below the value
					inc.w = part.get(i + 2, j);
					inc.m2._sum = part.get(i, j) * (inc.w - 1);
					inc.mean._sum = part.get(i + 1, j);
					inc.m2._correction = part.get(i + 3, j);
					inc.mean._correction = part.get(i + 4, j);

					// combine both states
					cur = (CM_COV_Object) aggOp.increOp.fn.execute(cur, inc);

					// write back: variance into this block, the rest into the correction
					final double variance = cur.getRequiredResult(AggregateOperationTypes.VARIANCE);
					set(i, j, variance);
					corr.set(0, j, cur.mean._sum);
					corr.set(1, j, cur.w);
					corr.set(2, j, cur.m2._correction);
					corr.set(3, j, cur.mean._correction);
				}
			}
			return;
		}

		if (where == CorrectionLocationType.LASTFOURCOLUMNS
			&& aggOp.increOp.fn instanceof CM
			&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE)
		{
			// reusable holders for the current and the incoming partial state
			CM_COV_Object cur = new CM_COV_Object();
			CM_COV_Object inc = new CM_COV_Object();

			for (int i = 0; i < rlen; i++) {
				for (int j = 0; j < clen; j++) {
					// current state: { var | mean, count, m2 correction, mean correction }
					// with m2 = var * (n - 1)
					cur.w = corr.get(i, 1);
					cur.m2._sum = get(i, j) * (cur.w - 1);
					cur.mean._sum = corr.get(i, 0);
					cur.m2._correction = corr.get(i, 2);
					cur.mean._correction = corr.get(i, 3);

					// partial state, same layout, stored in the columns right of the value
					inc.w = part.get(i, j + 2);
					inc.m2._sum = part.get(i, j) * (inc.w - 1);
					inc.mean._sum = part.get(i, j + 1);
					inc.m2._correction = part.get(i, j + 3);
					inc.mean._correction = part.get(i, j + 4);

					// combine both states
					cur = (CM_COV_Object) aggOp.increOp.fn.execute(cur, inc);

					// write back: variance into this block, the rest into the correction
					final double variance = cur.getRequiredResult(AggregateOperationTypes.VARIANCE);
					set(i, j, variance);
					corr.set(i, 0, cur.mean._sum);
					corr.set(i, 1, cur.w);
					corr.set(i, 2, cur.m2._correction);
					corr.set(i, 3, cur.mean._correction);
				}
			}
			return;
		}

		throw new DMLRuntimeException("unrecognized correctionLocation: "+where);
	}