in src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java [3151:3381]
/**
 * Folds a partial aggregate into this block while the bookkeeping values of
 * the running aggregate are held in a separate correction block. Both this
 * block and {@code correction} are updated in place.
 *
 * <p>Where the bookkeeping values are read and written depends on
 * {@code aggOp.correction}:
 * <ul>
 * <li>{@code LASTROW}: row 0 of {@code correction}; the partial's correction
 *     is read from the row below its value in {@code newWithCorrection}.</li>
 * <li>{@code LASTCOLUMN}: column 0 of {@code correction}; the partial's
 *     correction is read from the column right of its value. The
 *     {@code MAXINDEX}/{@code MININDEX} builtins are special-cased.</li>
 * <li>{@code NONE}: {@code correction} is addressed with the same (row, column)
 *     as this block; {@code KahanPlus} is delegated to
 *     {@code LibMatrixAgg.aggregateBinaryMatrix}.</li>
 * <li>{@code LASTTWOROWS} / {@code LASTTWOCOLUMNS}: rows / columns 0 and 1 of
 *     {@code correction}.</li>
 * <li>{@code LASTFOURROWS} / {@code LASTFOURCOLUMNS}: rows / columns 0 to 3 of
 *     {@code correction}; only handled when the increment function is a
 *     {@code CM} of type {@code VARIANCE}.</li>
 * </ul>
 *
 * @param aggOp aggregate operator providing the correction location, the
 *        increment function ({@code increOp.fn}) and the {@code sparseSafe} flag
 * @param correction block holding the bookkeeping values of the running
 *        aggregate; converted via {@code checkType} and modified in place
 * @param newWithCorrection partial aggregate together with its own
 *        bookkeeping values; converted via {@code checkType}
 * @param deep only forwarded to {@code LibMatrixAgg.aggregateBinaryMatrix}
 *        in the {@code NONE}/{@code KahanPlus} path; not used otherwise
 * @throws DMLRuntimeException if no branch applies, which includes the
 *         four-row/four-column locations combined with a function other than
 *         a {@code CM} variance
 */
public void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, MatrixValue newWithCorrection, boolean deep) {
	final MatrixBlock corr = checkType(correction);
	final MatrixBlock part = checkType(newWithCorrection);
	// shared scratch object; every execute(...) result is written back into it
	KahanObject acc = new KahanObject(0, 0);
	final CorrectionLocationType loc = aggOp.correction;

	if (loc == CorrectionLocationType.LASTROW) {
		for (int i = 0; i < rlen; i++) {
			for (int j = 0; j < clen; j++) {
				acc._sum = get(i, j);
				acc._correction = corr.get(0, j);
				// partial value at (i, j), its correction one row below
				acc = (KahanObject) aggOp.increOp.fn.execute(acc, part.get(i, j), part.get(i + 1, j));
				set(i, j, acc._sum);
				corr.set(0, j, acc._correction);
			}
		}
	}
	else if (loc == CorrectionLocationType.LASTCOLUMN) {
		boolean indexAgg = false;
		if (aggOp.increOp.fn instanceof Builtin) {
			Builtin.BuiltinCode code = ((Builtin) (aggOp.increOp.fn)).bFunc;
			indexAgg = (code == Builtin.BuiltinCode.MAXINDEX || code == Builtin.BuiltinCode.MININDEX);
		}

		if (indexAgg) {
			// *** HACK ALERT ***
			// rowIndexMax() and its siblings do not fit the standard aggregate
			// framework well. The "correction" slot is (ab)used to hold the best
			// value seen so far per row, while column 0 of this block holds the
			// index of that value. execute() only compares two candidate values;
			// the index bookkeeping (ties are broken in favor of the higher
			// index) is done here.
			// Note: both versions of incrementalAggregate() contain very similar
			// special-case code. If one is modified, the other must be changed
			// to match.
			for (int i = 0; i < rlen; i++) {
				final double bestValue = corr.get(i, 0);
				final long candidateIndex = (long) part.get(i, 0);
				final double candidateValue = part.get(i, 1);
				final double verdict = aggOp.increOp.fn.execute(candidateValue, bestValue);

				if (verdict == 2.0) {
					// 2 ==> both values are the same: keep the higher index
					final long currentIndex = (long) get(i, 0);
					set(i, 0, Math.max(currentIndex, candidateIndex));
				}
				else if (verdict == 1.0) {
					// 1 ==> the new value is better: take its index and value
					set(i, 0, candidateIndex);
					corr.set(i, 0, candidateValue);
				}
				// any other result ==> the current answer remains the best
			}
			// *** END HACK ***
		}
		else {
			for (int i = 0; i < rlen; i++) {
				for (int j = 0; j < clen; j++) {
					acc._sum = get(i, j);
					acc._correction = corr.get(i, 0);
					// partial value at (i, j), its correction one column to the right
					acc = (KahanObject) aggOp.increOp.fn.execute(acc, part.get(i, j), part.get(i, j + 1));
					set(i, j, acc._sum);
					corr.set(i, 0, acc._correction);
				}
			}
		}
	}
	else if (loc == CorrectionLocationType.NONE) {
		// e.g., ak+ (kahan plus) as used in sum, mapmult, mmcj and tsmm
		if (aggOp.increOp.fn instanceof KahanPlus) {
			LibMatrixAgg.aggregateBinaryMatrix(part, this, corr, deep);
		}
		else {
			if (part.isInSparseFormat() && aggOp.sparseSafe) {
				// SPARSE: visit only the stored entries of the partial block
				final SparseBlock sblock = part.getSparseBlock();
				if (sblock == null) {
					// early abort on empty block
					return;
				}
				for (int i = 0; i < Math.min(rlen, sblock.numRows()); i++) {
					if (sblock.isEmpty(i)) {
						continue;
					}
					final int first = sblock.pos(i);
					final int count = sblock.size(i);
					final int[] colIndexes = sblock.indexes(i);
					final double[] rowValues = sblock.values(i);
					final int end = first + count;
					for (int k = first; k < end; k++) {
						final int j = colIndexes[k];
						acc._sum = get(i, j);
						acc._correction = corr.get(i, j);
						acc = (KahanObject) aggOp.increOp.fn.execute(acc, rowValues[k]);
						set(i, j, acc._sum);
						corr.set(i, j, acc._correction);
					}
				}
			}
			else {
				// DENSE, or SPARSE with an operator that is not sparse-safe
				for (int i = 0; i < rlen; i++) {
					for (int j = 0; j < clen; j++) {
						acc._sum = get(i, j);
						acc._correction = corr.get(i, j);
						acc = (KahanObject) aggOp.increOp.fn.execute(acc, part.get(i, j));
						set(i, j, acc._sum);
						corr.set(i, j, acc._correction);
					}
				}
			}

			// change representation if required
			// (note: since ak+ on blocks is currently only applied in MR, there
			// is no need to account for this in memory estimates)
			examSparsity();
		}
	}
	else if (loc == CorrectionLocationType.LASTTWOROWS) {
		// corr row 0 holds n, corr row 1 holds the correction term
		for (int i = 0; i < rlen; i++) {
			for (int j = 0; j < clen; j++) {
				acc._sum = get(i, j);
				double total = corr.get(0, j);
				acc._correction = corr.get(1, j);
				final double partValue = part.get(i, j);
				final double partN = part.get(i + 1, j);
				total = total + partN;
				// delta handed to the increment function: (mu2 - current) * n2 / (n + n2)
				final double delta = (partValue - acc._sum) * partN / total;
				acc = (KahanObject) aggOp.increOp.fn.execute(acc, delta);
				set(i, j, acc._sum);
				corr.set(0, j, total);
				corr.set(1, j, acc._correction);
			}
		}
	}
	else if (loc == CorrectionLocationType.LASTTWOCOLUMNS) {
		// corr column 0 holds n, corr column 1 holds the correction term
		for (int i = 0; i < rlen; i++) {
			for (int j = 0; j < clen; j++) {
				acc._sum = get(i, j);
				double total = corr.get(i, 0);
				acc._correction = corr.get(i, 1);
				final double partValue = part.get(i, j);
				final double partN = part.get(i, j + 1);
				total = total + partN;
				// delta handed to the increment function: (mu2 - current) * n2 / (n + n2)
				final double delta = (partValue - acc._sum) * partN / total;
				acc = (KahanObject) aggOp.increOp.fn.execute(acc, delta);
				set(i, j, acc._sum);
				corr.set(i, 0, total);
				corr.set(i, 1, acc._correction);
			}
		}
	}
	else if (loc == CorrectionLocationType.LASTFOURROWS
		&& aggOp.increOp.fn instanceof CM
		&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
		// buffers for the running state and the incoming partial state
		CM_COV_Object running = new CM_COV_Object();
		CM_COV_Object incoming = new CM_COV_Object();

		for (int i = 0; i < rlen; i++) {
			for (int j = 0; j < clen; j++) {
				// running state: { var | mean, count, m2 correction, mean correction }
				// note: m2 = var * (n - 1)
				running.w = corr.get(1, j); // count
				running.m2._sum = get(i, j) * (running.w - 1); // m2
				running.mean._sum = corr.get(0, j); // mean
				running.m2._correction = corr.get(2, j);
				running.mean._correction = corr.get(3, j);

				// partial state, stored in the rows below the partial variance
				// note: m2 = var * (n - 1)
				incoming.w = part.get(i + 2, j); // count
				incoming.m2._sum = part.get(i, j) * (incoming.w - 1); // m2
				incoming.mean._sum = part.get(i + 1, j); // mean
				incoming.m2._correction = part.get(i + 3, j);
				incoming.mean._correction = part.get(i + 4, j);

				// combine both states
				running = (CM_COV_Object) aggOp.increOp.fn.execute(running, incoming);

				// write back: { var | mean, count, m2 correction, mean correction }
				final double variance = running.getRequiredResult(AggregateOperationTypes.VARIANCE);
				set(i, j, variance);
				corr.set(0, j, running.mean._sum); // mean
				corr.set(1, j, running.w); // count
				corr.set(2, j, running.m2._correction);
				corr.set(3, j, running.mean._correction);
			}
		}
	}
	else if (loc == CorrectionLocationType.LASTFOURCOLUMNS
		&& aggOp.increOp.fn instanceof CM
		&& ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
		// buffers for the running state and the incoming partial state
		CM_COV_Object running = new CM_COV_Object();
		CM_COV_Object incoming = new CM_COV_Object();

		for (int i = 0; i < rlen; i++) {
			for (int j = 0; j < clen; j++) {
				// running state: { var | mean, count, m2 correction, mean correction }
				// note: m2 = var * (n - 1)
				running.w = corr.get(i, 1); // count
				running.m2._sum = get(i, j) * (running.w - 1); // m2
				running.mean._sum = corr.get(i, 0); // mean
				running.m2._correction = corr.get(i, 2);
				running.mean._correction = corr.get(i, 3);

				// partial state, stored in the columns right of the partial variance
				// note: m2 = var * (n - 1)
				incoming.w = part.get(i, j + 2); // count
				incoming.m2._sum = part.get(i, j) * (incoming.w - 1); // m2
				incoming.mean._sum = part.get(i, j + 1); // mean
				incoming.m2._correction = part.get(i, j + 3);
				incoming.mean._correction = part.get(i, j + 4);

				// combine both states
				running = (CM_COV_Object) aggOp.increOp.fn.execute(running, incoming);

				// write back: { var | mean, count, m2 correction, mean correction }
				final double variance = running.getRequiredResult(AggregateOperationTypes.VARIANCE);
				set(i, j, variance);
				corr.set(i, 0, running.mean._sum); // mean
				corr.set(i, 1, running.w); // count
				corr.set(i, 2, running.m2._correction);
				corr.set(i, 3, running.mean._correction);
			}
		}
	}
	else {
		throw new DMLRuntimeException("unrecognized correctionLocation: " + loc);
	}
}