in src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java [540:717]
public void processInstruction(ExecutionContext ec) {
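// dispatch opcodes with dedicated handlers; each manages its own inputs/outputs and returns early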
if (instOpcode.equalsIgnoreCase(Opcodes.BIAS_ADD.toString())) {
processBiasAddInstruction(ec);
return;
}
else if (instOpcode.equalsIgnoreCase(Opcodes.BIAS_MULTIPLY.toString())) {
processBiasMultiplyInstruction(ec);
return;
}
else if (instOpcode.equalsIgnoreCase(Opcodes.RELU_BACKWARD.toString())) {
processReluBackwardInstruction(ec);
return;
}
else if (instOpcode.equalsIgnoreCase(Opcodes.BATCH_NORM2D.toString())) {
processBatchNorm2dInstruction(ec);
return;
}
else if (instOpcode.equalsIgnoreCase(Opcodes.BATCH_NORM2D_BACKWARD.toString())) {
processBatchNorm2dBackwardInstruction(ec);
return;
}
else if (instOpcode.equalsIgnoreCase(Opcodes.LSTM.toString())) {
processLSTMInstruction(ec, false);
return;
}
else if (instOpcode.equalsIgnoreCase(Opcodes.LSTM_BACKWARD.toString())) {
processLSTMInstruction(ec, true);
return;
}
// acquire inputs (avgpooling_backward reads only the gradient, so input1 is not pinned here)
MatrixBlock outputBlock = null;
MatrixBlock matBlock = instOpcode.equalsIgnoreCase(Opcodes.AVGPOOLING_BACKWARD.toString()) ? null : ec.getMatrixInput(input1.getName());
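// decode stride, padding, input shape (N x C x H x W) and filter shape (K x C x R x S) from scalar operands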
int pad_h = getScalarInput(ec, _padding, 0);
int pad_w = getScalarInput(ec, _padding, 1);
int stride_h = getScalarInput(ec, _stride, 0);
int stride_w = getScalarInput(ec, _stride, 1);
int N = getScalarInput(ec, _input_shape, 0);
int C = getScalarInput(ec, _input_shape, 1);
int H = getScalarInput(ec, _input_shape, 2);
int W = getScalarInput(ec, _input_shape, 3);
int K = getScalarInput(ec, _filter_shape, 0);
int R = getScalarInput(ec, _filter_shape, 2);
int S = getScalarInput(ec, _filter_shape, 3);
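// P x Q: output height/width implied by kernel size, stride, and padding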
int P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
DnnParameters params = new DnnParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, _numThreads);
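// use native kernels (LibMatrixNative) only if the native library is loaded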
params.enableNative = NativeHelper.isNativeLibraryLoaded();
if (instOpcode.equalsIgnoreCase(Opcodes.MAXPOOLING.toString()) || instOpcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING.toString()) ||
instOpcode.equalsIgnoreCase(Opcodes.AVGPOOLING.toString())) {
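// forward pooling: an empty input directly yields an empty N x C*P*Q output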
if(matBlock.isEmpty()) {
outputBlock = new MatrixBlock(N, C*P*Q, true);
}
else {
outputBlock = new MatrixBlock(N, C*P*Q, false).allocateBlock();
PoolingType poolType = (instOpcode.equalsIgnoreCase(Opcodes.MAXPOOLING.toString()) || instOpcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING.toString())) ? PoolingType.MAX : PoolingType.AVG;
if(instOpcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING.toString()))
params.minValForMaxPoolOperations = 0;
LibMatrixDNN.pooling(matBlock, outputBlock, params, poolType);
}
}
else if (instOpcode.equalsIgnoreCase(Opcodes.MAXPOOLING_BACKWARD.toString()) || instOpcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING_BACKWARD.toString()) ||
instOpcode.equalsIgnoreCase(Opcodes.AVGPOOLING_BACKWARD.toString())) {
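// avg-pooling backward needs only the incoming gradient dout; max-pooling backward also needs the forward input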
MatrixBlock dout = ec.getMatrixInput(_in2.getName());
boolean isEmpty = instOpcode.equalsIgnoreCase(Opcodes.AVGPOOLING_BACKWARD.toString()) ? dout.isEmpty() : (matBlock.isEmpty() || dout.isEmpty());
if(isEmpty) {
outputBlock = new MatrixBlock(N, C*H*W, true);
}
else {
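// max-pooling gradients are nonzero only at the argmax positions, so MAX allocates a sparse output; AVG allocates dense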
PoolingType poolType = (instOpcode.equalsIgnoreCase(Opcodes.MAXPOOLING_BACKWARD.toString()) || instOpcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING_BACKWARD.toString())) ? PoolingType.MAX : PoolingType.AVG;
outputBlock = (poolType == PoolingType.MAX) ? new MatrixBlock(N, C*H*W, true).allocateBlock() : new MatrixBlock(N, C*H*W, false).allocateBlock();
boolean performReLUBackward = instOpcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING_BACKWARD.toString());
if(performReLUBackward)
params.minValForMaxPoolOperations = 0;
LibMatrixDNN.poolingBackward(matBlock, dout, outputBlock, params, performReLUBackward, poolType);
}
ec.releaseMatrixInput(_in2.getName());
}
else if (instOpcode.equalsIgnoreCase(Opcodes.CONV2D.toString())) {
// compute input sparsity in floating point to avoid integer-division truncation
resetNumThreads(params, C*R*S, P*Q, (double) matBlock.getNonZeros() / ((long) matBlock.getNumRows() * matBlock.getNumColumns()));
MatrixBlock filter = ec.getMatrixInput(_in2.getName());
if(filter.isEmpty() || matBlock.isEmpty()) {
outputBlock = new MatrixBlock(N, K*P*Q, true);
}
else {
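// allocate a sparse output only for ultra-sparse inputs without bias, and only if it undercuts the dense size estimate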
boolean sparse = matBlock.isUltraSparse(false) && params.bias == null
&& matBlock.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K*P*Q);
outputBlock = new MatrixBlock(N, K*P*Q, sparse).allocateBlock();
if(params.enableNative && matBlock.isInSparseFormat())
matBlock.sparseToDense();
if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
else
LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
}
ec.releaseMatrixInput(_in2.getName());
}
else if (instOpcode.equalsIgnoreCase(Opcodes.CONV2D_BIAS_ADD.toString())) {
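// fused conv2d + bias_add: validate the bias shape, then apply the bias as part of the convolution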
// compute input sparsity in floating point to avoid integer-division truncation
resetNumThreads(params, C*R*S, P*Q, (double) matBlock.getNonZeros() / ((long) matBlock.getNumRows() * matBlock.getNumColumns()));
MatrixBlock filter = ec.getMatrixInput(_in3.getName());
MatrixBlock bias = ec.getMatrixInput(_in2.getName());
if(bias.getNumRows() != params.K || bias.getNumColumns() != 1) {
throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. "
+ "Expected: [" + params.K + ", 1]");
}
boolean isOutputConvEmpty = filter.isEmpty() || matBlock.isEmpty();
if(isOutputConvEmpty && bias.isEmpty()) {
// bias_add(empty mb, empty mb) = empty mb
outputBlock = new MatrixBlock(N, K*P*Q, true);
}
else if(isOutputConvEmpty && !bias.isEmpty()) {
// bias_add(empty mb, bias): broadcast the bias into an otherwise empty dense output
outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock();
for(int n = 0; n < params.N; n++)
DnnUtils.fillBias(bias, outputBlock.getDenseBlockValues(),
n, n+1, params.N, params.K, params.P*params.Q);
}
else {
outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock();
if(!bias.isEmpty()) {
// attach the bias only if non-empty; adding an all-zero bias is a no-op
params.bias = bias;
}
if(params.enableNative && matBlock.isInSparseFormat())
matBlock.sparseToDense();
if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
else
LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params);
}
ec.releaseMatrixInput(_in3.getName(), _in2.getName());
}
else if (instOpcode.equalsIgnoreCase(Opcodes.CONV2D_BACKWARD_FILTER.toString())) {
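// conv2d backward w.r.t. the filter: gradient has shape K x C*R*S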
MatrixBlock dout = ec.getMatrixInput(_in2.getName());
if(dout.isEmpty() || matBlock.isEmpty()) {
outputBlock = new MatrixBlock(K, C*R*S, true);
}
else {
outputBlock = new MatrixBlock(K, C*R*S, false).allocateBlock();
if(params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat())
LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
else
LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
}
ec.releaseMatrixInput(_in2.getName());
}
else if (instOpcode.equalsIgnoreCase(Opcodes.CONV2D_BACKWARD_DATA.toString())) {
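// conv2d backward w.r.t. the input data: gradient has shape N x C*H*W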
MatrixBlock dout = ec.getMatrixInput(_in2.getName());
if(dout.isEmpty() || matBlock.isEmpty()) {
outputBlock = new MatrixBlock(N, C * H * W, true);
}
else {
outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
if(params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat())
LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
else
LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
}
ec.releaseMatrixInput(_in2.getName());
}
else {
throw new DMLRuntimeException("Unsupported op code " + instOpcode);
}
// release pinned inputs and set the output (input1 was never pinned for avgpooling_backward)
if(!instOpcode.equalsIgnoreCase(Opcodes.AVGPOOLING_BACKWARD.toString()))
ec.releaseMatrixInput(input1.getName());
ec.setMatrixOutput(getOutputVariableName(), outputBlock);
}