in src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java [3384:3582]
/**
 * Incrementally aggregates the partial result {@code newWithCorrection} into this block, in place.
 * Both blocks carry auxiliary values (Kahan corrections, counts, ...) in the trailing rows or
 * columns selected by {@code aggOp.correction}:
 * <ul>
 *   <li>LASTROW / LASTCOLUMN: { value | correction }. {@code KahanPlus} is delegated to
 *       {@link LibMatrixAgg#aggregateBinaryMatrix}. Any other function is applied cell-wise via
 *       {@code execute(KahanObject, double, double)}. For LASTCOLUMN, MAXINDEX/MININDEX builtins
 *       are special-cased as { index | extreme value }.</li>
 *   <li>LASTTWOROWS / LASTTWOCOLUMNS: { value | count, correction }. The value is moved by the
 *       weighted delta {@code (mu2 - mu1) * n2 / (n1 + n2)}.</li>
 *   <li>LASTFOURROWS / LASTFOURCOLUMNS (CM with VARIANCE only):
 *       { var | mean, count, m2 correction, mean correction }.</li>
 * </ul>
 *
 * @param aggOp aggregate operator; its {@code correction} field selects the layout and
 *              {@code increOp.fn} the incremental function
 * @param newWithCorrection partial aggregate with the same layout as this block
 * @throws DMLRuntimeException if the correction location is not supported, or a LASTFOUR* layout is
 *                             used with anything other than a variance CM function
 */
public void incrementalAggregate(AggregateOperator aggOp, MatrixValue newWithCorrection) {
	MatrixBlock newWithCor = checkType(newWithCorrection);
	if( aggOp.correction == CorrectionLocationType.LASTROW ) {
		if( aggOp.increOp.fn instanceof KahanPlus )
			LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp);
		else
			incrAggCorrKahan(aggOp, newWithCor, 1, 0);
	}
	else if( aggOp.correction == CorrectionLocationType.LASTCOLUMN ) {
		if( isIndexExtremeBuiltin(aggOp) )
			incrAggCorrIndexExtreme(aggOp, newWithCor);
		else if( aggOp.increOp.fn instanceof KahanPlus )
			LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp);
		else
			incrAggCorrKahan(aggOp, newWithCor, 0, 1);
	}
	else if( aggOp.correction == CorrectionLocationType.LASTTWOROWS )
		incrAggCorrMean(aggOp, newWithCor, 1, 0);
	else if( aggOp.correction == CorrectionLocationType.LASTTWOCOLUMNS )
		incrAggCorrMean(aggOp, newWithCor, 0, 1);
	else if( aggOp.correction == CorrectionLocationType.LASTFOURROWS && isVarianceCM(aggOp) )
		incrAggCorrVariance(aggOp, newWithCor, 1, 0);
	else if( aggOp.correction == CorrectionLocationType.LASTFOURCOLUMNS && isVarianceCM(aggOp) )
		incrAggCorrVariance(aggOp, newWithCor, 0, 1);
	else
		throw new DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correction);
}

/**
 * Returns true if the incremental function is the MAXINDEX or MININDEX builtin.
 * Those functions need the special index/value bookkeeping of
 * {@link #incrAggCorrIndexExtreme}.
 */
private static boolean isIndexExtremeBuiltin(AggregateOperator aggOp) {
	if( !(aggOp.increOp.fn instanceof Builtin) )
		return false;
	Builtin.BuiltinCode code = ((Builtin) aggOp.increOp.fn).bFunc;
	return code == Builtin.BuiltinCode.MAXINDEX || code == Builtin.BuiltinCode.MININDEX;
}

/** Returns true if the incremental function is a CM function computing VARIANCE. */
private static boolean isVarianceCM(AggregateOperator aggOp) {
	return aggOp.increOp.fn instanceof CM
		&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE;
}

/**
 * Generic Kahan-style aggregation with one trailing correction slot per value.
 * For value (r,c), the correction lives at (r+dr, c+dc). Use (1,0) for LASTROW and
 * (0,1) for LASTCOLUMN.
 */
private void incrAggCorrKahan(AggregateOperator aggOp, MatrixBlock newWithCor, int dr, int dc) {
	KahanObject buffer = new KahanObject(0, 0);
	for(int r = 0; r < rlen - dr; r++)
		for(int c = 0; c < clen - dc; c++) {
			buffer._sum = get(r, c);
			buffer._correction = get(r + dr, c + dc);
			buffer = (KahanObject) aggOp.increOp.fn.execute(buffer,
				newWithCor.get(r, c), newWithCor.get(r + dr, c + dc));
			set(r, c, buffer._sum);
			set(r + dr, c + dc, buffer._correction);
		}
}

/**
 * Count-weighted aggregation with two trailing slots per value: { value | count, correction }.
 * Slot k of (r,c) lives at (r+k*dr, c+k*dc). Use (1,0) for LASTTWOROWS and (0,1) for
 * LASTTWOCOLUMNS.
 */
private void incrAggCorrMean(AggregateOperator aggOp, MatrixBlock newWithCor, int dr, int dc) {
	KahanObject buffer = new KahanObject(0, 0);
	for(int r = 0; r < rlen - 2*dr; r++)
		for(int c = 0; c < clen - 2*dc; c++) {
			buffer._sum = get(r, c);
			double n1 = get(r + dr, c + dc);
			buffer._correction = get(r + 2*dr, c + 2*dc);
			double mu2 = newWithCor.get(r, c);
			double n2 = newWithCor.get(r + dr, c + dc);
			double n = n1 + n2;
			// Both sides empty (n == 0): the delta formula would be 0/0 = NaN and poison the
			// running value. There is nothing to add in that case.
			double toadd = (n == 0) ? 0 : (mu2 - buffer._sum) * n2 / n;
			buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, toadd);
			set(r, c, buffer._sum);
			set(r + dr, c + dc, n);
			set(r + 2*dr, c + 2*dc, buffer._correction);
		}
}

/**
 * Variance aggregation with four trailing slots per value:
 * { var | mean, count, m2 correction, mean correction }.
 * Slot k of (r,c) lives at (r+k*dr, c+k*dc). Use (1,0) for LASTFOURROWS and (0,1) for
 * LASTFOURCOLUMNS.
 */
private void incrAggCorrVariance(AggregateOperator aggOp, MatrixBlock newWithCor, int dr, int dc) {
	CM_COV_Object curr = new CM_COV_Object();
	CM_COV_Object part = new CM_COV_Object();
	for(int r = 0; r < rlen - 4*dr; r++)
		for(int c = 0; c < clen - 4*dc; c++) {
			incrAggReadVarianceState(this, r, c, dr, dc, curr);
			incrAggReadVarianceState(newWithCor, r, c, dr, dc, part);
			curr = (CM_COV_Object) aggOp.increOp.fn.execute(curr, part);
			set(r, c, curr.getRequiredResult(AggregateOperationTypes.VARIANCE));
			set(r + dr, c + dc, curr.mean._sum);
			set(r + 2*dr, c + 2*dc, curr.w);
			set(r + 3*dr, c + 3*dc, curr.m2._correction);
			set(r + 4*dr, c + 4*dc, curr.mean._correction);
		}
}

/**
 * Loads the variance state of element (r,c) of {@code src} into {@code out}.
 * The stored variance is converted back to m2 via m2 = var * (n - 1).
 */
private static void incrAggReadVarianceState(MatrixBlock src, int r, int c, int dr, int dc, CM_COV_Object out) {
	out.w = src.get(r + 2*dr, c + 2*dc); // count
	out.m2._sum = src.get(r, c) * (out.w - 1); // m2
	out.mean._sum = src.get(r + dr, c + dc); // mean
	out.m2._correction = src.get(r + 3*dr, c + 3*dc);
	out.mean._correction = src.get(r + 4*dr, c + 4*dc);
}

/**
 * Special case for rowIndexMax()/rowIndexMin()-style aggregates in LASTCOLUMN layout.
 * The block holds { index | extreme value } per row, i.e. the "correction" column is (ab)used to
 * carry the current extreme value.
 * {@code execute(newValue, currValue)} returns 2.0 on a tie (the higher index wins) and 1.0 if the
 * new value is better. Any other result keeps the current entry.
 * NOTE: the other incrementalAggregate() overload contains very similar special-case code. If
 * this logic is modified, that code needs to be changed to match.
 */
private void incrAggCorrIndexExtreme(AggregateOperator aggOp, MatrixBlock newWithCor) {
	for(int r = 0; r < rlen; r++) {
		double currMaxValue = get(r, 1);
		long newMaxIndex = (long) newWithCor.get(r, 0);
		double newMaxValue = newWithCor.get(r, 1);
		double update = aggOp.increOp.fn.execute(newMaxValue, currMaxValue);
		if( update == 2.0 ) {
			// tie: keep the higher of the two indexes
			long curMaxIndex = (long) get(r, 0);
			set(r, 0, Math.max(curMaxIndex, newMaxIndex));
		}
		else if( update == 1.0 ) {
			// new value is better: take over both its index and its value
			set(r, 0, newMaxIndex);
			set(r, 1, newMaxValue);
		}
		// any other result: the current answer is best, nothing to update
	}
}