public void incrementalAggregate()

in src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java [3384:3582]


	/**
	 * Folds a partial aggregate, including its correction cells, into this block in place.
	 * <p>
	 * Both this block and {@code newWithCorrection} are addressed with the same cell
	 * coordinates. Which cells hold values and which hold corrections is selected by
	 * {@code aggOp.correction}:
	 * <ul>
	 * <li>{@code LASTROW} / {@code LASTCOLUMN}: each value cell is paired with a correction
	 *     cell in the next row / column. If the increment function is a {@code KahanPlus}
	 *     the work is delegated to {@code LibMatrixAgg.aggregateBinaryMatrix}. For
	 *     {@code LASTCOLUMN} with a {@code MAXINDEX} / {@code MININDEX} builtin, column 0
	 *     holds the index and column 1 the value that index refers to.</li>
	 * <li>{@code LASTTWOROWS} / {@code LASTTWOCOLUMNS}: value, count, correction.</li>
	 * <li>{@code LASTFOURROWS} / {@code LASTFOURCOLUMNS}, only with a {@code CM} function of
	 *     type {@code VARIANCE}: variance, mean, count, m2 correction, mean correction.</li>
	 * </ul>
	 *
	 * @param aggOp aggregate operator; {@code aggOp.correction} picks the layout and
	 *              {@code aggOp.increOp.fn} performs the per-cell update
	 * @param newWithCorrection partial result with its correction cells; converted through
	 *              {@code checkType} before use
	 * @throws DMLRuntimeException if none of the layout / function combinations above applies
	 */
	public void incrementalAggregate(AggregateOperator aggOp, MatrixValue newWithCorrection) {
		final MatrixBlock partial = checkType(newWithCorrection);
		// scratch sum/correction pair; execute() may hand back a different instance,
		// hence the re-assignment after every call
		KahanObject kbuff = new KahanObject(0, 0);
		final CorrectionLocationType corrLoc = aggOp.correction;

		// ---- value row followed by one correction row ----
		if( corrLoc == CorrectionLocationType.LASTROW ) {
			if( !(aggOp.increOp.fn instanceof KahanPlus) ) {
				// generic path: update every (value, correction) pair cell by cell
				for( int i = 0; i < rlen - 1; i++ ) {
					for( int j = 0; j < clen; j++ ) {
						kbuff._sum = get(i, j);
						kbuff._correction = get(i + 1, j);
						kbuff = (KahanObject) aggOp.increOp.fn.execute(
							kbuff, partial.get(i, j), partial.get(i + 1, j));
						set(i, j, kbuff._sum);
						set(i + 1, j, kbuff._correction);
					}
				}
			}
			else {
				// KahanPlus is handled by the library routine
				LibMatrixAgg.aggregateBinaryMatrix(partial, this, aggOp);
			}
			return;
		}

		// ---- value column followed by one correction column ----
		if( corrLoc == CorrectionLocationType.LASTCOLUMN ) {
			final boolean indexBuiltin = aggOp.increOp.fn instanceof Builtin
				&& ( ((Builtin) aggOp.increOp.fn).bFunc == Builtin.BuiltinCode.MAXINDEX
				  || ((Builtin) aggOp.increOp.fn).bFunc == Builtin.BuiltinCode.MININDEX );

			if( indexBuiltin ) {
				// *** HACK ALERT ***
				// The row-index-of-extreme-value aggregates do not fit the regular
				// aggregate framework well. The "correction" column (column 1) is
				// (ab)used to carry the best value seen so far, while column 0
				// carries the index belonging to that value.
				//
				// execute() only compares the two candidate values; all index
				// bookkeeping (including breaking ties in favor of the higher
				// index) happens right here. Both versions of
				// incrementalAggregate() contain very similar special-case code:
				// if one is modified, the other needs to be changed to match.
				for( int i = 0; i < rlen; i++ ) {
					final double bestValue = get(i, 1);
					final long candIndex = (long) partial.get(i, 0);
					final double candValue = partial.get(i, 1);
					final double verdict = aggOp.increOp.fn.execute(candValue, bestValue);

					if( verdict == 2.0 ) {
						// 2 ==> both values are the same: keep the higher index
						final long bestIndex = (long) get(i, 0);
						set(i, 0, Math.max(bestIndex, candIndex));
					}
					else if( verdict == 1.0 ) {
						// 1 ==> the incoming value is better: take its index and value
						set(i, 0, candIndex);
						set(i, 1, candValue);
					}
					// any other verdict ==> the current answer stays
				}
				// *** END HACK ***
			}
			else if( !(aggOp.increOp.fn instanceof KahanPlus) ) {
				// generic path: update every (value, correction) pair cell by cell
				for( int i = 0; i < rlen; i++ ) {
					for( int j = 0; j < clen - 1; j++ ) {
						kbuff._sum = get(i, j);
						kbuff._correction = get(i, j + 1);
						kbuff = (KahanObject) aggOp.increOp.fn.execute(
							kbuff, partial.get(i, j), partial.get(i, j + 1));
						set(i, j, kbuff._sum);
						set(i, j + 1, kbuff._correction);
					}
				}
			}
			else {
				// KahanPlus is handled by the library routine
				LibMatrixAgg.aggregateBinaryMatrix(partial, this, aggOp);
			}
			return;
		}

		// ---- value row, count row, correction row ----
		// NOTE(review): the update below looks like an incremental mean merge
		// (value += (partValue - value) * partCount / totalCount) - confirm against callers.
		if( corrLoc == CorrectionLocationType.LASTTWOROWS ) {
			for( int i = 0; i < rlen - 2; i++ ) {
				for( int j = 0; j < clen; j++ ) {
					kbuff._sum = get(i, j);
					final double currCount = get(i + 1, j);
					kbuff._correction = get(i + 2, j);
					final double partValue = partial.get(i, j);
					final double partCount = partial.get(i + 1, j);
					final double totalCount = currCount + partCount;
					final double delta = (partValue - kbuff._sum) * partCount / totalCount;
					kbuff = (KahanObject) aggOp.increOp.fn.execute(kbuff, delta);
					set(i, j, kbuff._sum);
					set(i + 1, j, totalCount);
					set(i + 2, j, kbuff._correction);
				}
			}
			return;
		}

		// ---- value column, count column, correction column ----
		// Same update as for LASTTWOROWS, with the layout transposed.
		if( corrLoc == CorrectionLocationType.LASTTWOCOLUMNS ) {
			for( int i = 0; i < rlen; i++ ) {
				for( int j = 0; j < clen - 2; j++ ) {
					kbuff._sum = get(i, j);
					final double currCount = get(i, j + 1);
					kbuff._correction = get(i, j + 2);
					final double partValue = partial.get(i, j);
					final double partCount = partial.get(i, j + 1);
					final double totalCount = currCount + partCount;
					final double delta = (partValue - kbuff._sum) * partCount / totalCount;
					kbuff = (KahanObject) aggOp.increOp.fn.execute(kbuff, delta);
					set(i, j, kbuff._sum);
					set(i, j + 1, totalCount);
					set(i, j + 2, kbuff._correction);
				}
			}
			return;
		}

		// ---- variance with four trailing rows ----
		if( corrLoc == CorrectionLocationType.LASTFOURROWS
			&& aggOp.increOp.fn instanceof CM
			&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE )
		{
			// reusable holders for the state of this block and of the partial result
			CM_COV_Object current = new CM_COV_Object();
			CM_COV_Object incoming = new CM_COV_Object();

			for( int i = 0; i < rlen - 4; i++ ) {
				for( int j = 0; j < clen; j++ ) {
					// rows: { var | mean, count, m2 correction, mean correction }
					// m2 is rebuilt from the stored variance: m2 = var * (n - 1)
					current.w = get(i + 2, j);
					current.m2._sum = get(i, j) * (current.w - 1);
					current.mean._sum = get(i + 1, j);
					current.m2._correction = get(i + 3, j);
					current.mean._correction = get(i + 4, j);

					// same layout and m2 reconstruction for the partial result
					incoming.w = partial.get(i + 2, j);
					incoming.m2._sum = partial.get(i, j) * (incoming.w - 1);
					incoming.mean._sum = partial.get(i + 1, j);
					incoming.m2._correction = partial.get(i + 3, j);
					incoming.mean._correction = partial.get(i + 4, j);

					// merge the partial state into the current state
					current = (CM_COV_Object) aggOp.increOp.fn.execute(current, incoming);

					// write back: { var | mean, count, m2 correction, mean correction }
					final double variance = current.getRequiredResult(AggregateOperationTypes.VARIANCE);
					set(i, j, variance);
					set(i + 1, j, current.mean._sum);
					set(i + 2, j, current.w);
					set(i + 3, j, current.m2._correction);
					set(i + 4, j, current.mean._correction);
				}
			}
			return;
		}

		// ---- variance with four trailing columns ----
		if( corrLoc == CorrectionLocationType.LASTFOURCOLUMNS
			&& aggOp.increOp.fn instanceof CM
			&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE )
		{
			// reusable holders for the state of this block and of the partial result
			CM_COV_Object current = new CM_COV_Object();
			CM_COV_Object incoming = new CM_COV_Object();

			for( int i = 0; i < rlen; i++ ) {
				for( int j = 0; j < clen - 4; j++ ) {
					// columns: { var | mean, count, m2 correction, mean correction }
					// m2 is rebuilt from the stored variance: m2 = var * (n - 1)
					current.w = get(i, j + 2);
					current.m2._sum = get(i, j) * (current.w - 1);
					current.mean._sum = get(i, j + 1);
					current.m2._correction = get(i, j + 3);
					current.mean._correction = get(i, j + 4);

					// same layout and m2 reconstruction for the partial result
					incoming.w = partial.get(i, j + 2);
					incoming.m2._sum = partial.get(i, j) * (incoming.w - 1);
					incoming.mean._sum = partial.get(i, j + 1);
					incoming.m2._correction = partial.get(i, j + 3);
					incoming.mean._correction = partial.get(i, j + 4);

					// merge the partial state into the current state
					current = (CM_COV_Object) aggOp.increOp.fn.execute(current, incoming);

					// write back: { var | mean, count, m2 correction, mean correction }
					final double variance = current.getRequiredResult(AggregateOperationTypes.VARIANCE);
					set(i, j, variance);
					set(i, j + 1, current.mean._sum);
					set(i, j + 2, current.w);
					set(i, j + 3, current.m2._correction);
					set(i, j + 4, current.mean._correction);
				}
			}
			return;
		}

		// nothing matched: unknown layout, or a four-row/column layout without a variance function
		throw new DMLRuntimeException("unrecognized correctionLocation: " + corrLoc);
	}