in src/main/java/org/apache/sysds/runtime/instructions/gpu/DnnGPUInstruction.java [720:888]
public void processInstruction(ExecutionContext ec) {
	// Executes this DNN instruction on the GPU.
	//
	// Opcodes that have a dedicated handler (bias ops, relu_backward, channel_sums,
	// nesterov update, lstm, batch_norm2d variants) are delegated first and return
	// immediately. All remaining opcodes are convolution / pooling operations that
	// share the same scalar shape metadata (padding, stride, input shape, filter
	// shape), validate their matrix inputs against it, invoke LibMatrixCuDNN, and
	// finally release every input/output that was acquired.
	//
	// Throws DMLRuntimeException if an input's dimensions do not match the shape
	// metadata or if the opcode is not handled here.
	if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) {
		processBiasInstruction(instOpcode, ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("relu_backward")) {
		processReLUBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("channel_sums")) {
		processChannelSumsInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("update_nesterov_x")) {
		processNesterovUpdateInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("lstm")) {
		processLstmInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
		processLstmBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
		processBatchNorm2dInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
		processBatchNorm2dBackwardInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_test")) {
		processBatchNorm2dTestInstruction(ec);
		return;
	}
	else if (instOpcode.equalsIgnoreCase("batch_norm2d_train")) {
		processBatchNorm2dTrainInstruction(ec);
		return;
	}

	// From here on: conv2d*, maxpooling*, avgpooling* (anything else is rejected below).
	GPUStatistics.incrementNoOfExecutedGPUInst();

	// Scalar shape metadata shared by all convolution/pooling opcodes.
	int pad_h = getScalarInput(ec, _padding, 0);
	int pad_w = getScalarInput(ec, _padding, 1);
	int stride_h = getScalarInput(ec, _stride, 0);
	int stride_w = getScalarInput(ec, _stride, 1);
	int N = getScalarInput(ec, _input_shape, 0); //N
	int C = getScalarInput(ec, _input_shape, 1); //C
	int H = getScalarInput(ec, _input_shape, 2); //Hin
	int W = getScalarInput(ec, _input_shape, 3); //Win
	// Note: index 1 of the filter shape is not read; C is taken from the input shape.
	int K = getScalarInput(ec, _filter_shape, 0); //F = nrow(W)
	int R = getScalarInput(ec, _filter_shape, 2); //Hf
	int S = getScalarInput(ec, _filter_shape, 3); //Wf
	// P and Q are derived from input size, filter size, stride and padding; they
	// size the output (K*P*Q columns for conv2d, C*P*Q columns for pooling).
	int P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
	int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);

	// Computed once: used both for dispatch and for deciding which inputs to release.
	boolean isPool = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling");
	boolean isPoolBackward = instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward");

	if (instOpcode.equalsIgnoreCase("conv2d")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input2.getName());
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
		if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
			throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q);
		LibMatrixCuDNN.conv2d(ec.getGPUContext(0), getExtendedOpcode(), image, filter, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
		// Operand order differs from conv2d: the bias is input2 and the filter is input3.
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); //X
		MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); //b
		MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input3.getName()); //W
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
		if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
			throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q);
		LibMatrixCuDNN.conv2dBiasAdd(ec.getGPUContext(0), getExtendedOpcode(), image, bias, filter, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter");
		if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q)
			throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: " +
				dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q);
		// Output has the filter's dimensions.
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), K, C * R * S);
		LibMatrixCuDNN.conv2dBackwardFilter(ec.getGPUContext(0), getExtendedOpcode(), image, dout, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
		// TODO: For now always copy the device data to host
		// ec.gpuCtx.copyDeviceToHost(outputBlock);
	}
	else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
		MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
		if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S)
			throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data");
		if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q)
			throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: " +
				dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q);
		// Output has the image's dimensions.
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
		LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), getExtendedOpcode(), filter, dout, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget);
	}
	else if (isPool) {
		// Forward pooling has a single matrix input (the image).
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " +
				image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W);
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * P * Q);
		PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling") ? PoolingType.MAX : PoolingType.AVG;
		LibMatrixCuDNN.pooling(ec.getGPUContext(0), getExtendedOpcode(), image, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget);
	}
	else if (isPoolBackward) {
		MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName());
		MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName());
		// The third input is optional; it is forwarded as-is (possibly null) to poolingBackward.
		MatrixObject maxPoolOutput = _input3 != null ? getMatrixInputForGPUInstruction(ec, _input3.getName()) : null;
		if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q)
			throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward");
		// The reported expected column count must match the checked value C*H*W
		// (it previously printed K*P*Q, which is unrelated to the image shape).
		if(image.getNumRows() != N || image.getNumColumns() != C*H*W)
			throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " +
				image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W);
		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
		PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling_backward") ? PoolingType.MAX : PoolingType.AVG;
		LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
			K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget);
	}
	else {
		throw new DMLRuntimeException("Unsupported GPU context for " + instOpcode);
	}

	// release inputs/outputs
	// input1 is acquired by every branch above.
	ec.releaseMatrixInputForGPUInstruction(_input1.getName());
	// input2 is acquired by every branch except forward pooling.
	if ( !isPool )
		ec.releaseMatrixInputForGPUInstruction(_input2.getName());
	// input3 is acquired only by conv2d_bias_add and, when present, by pooling backward.
	if (instOpcode.equalsIgnoreCase("conv2d_bias_add") ||
		(isPoolBackward && _input3 != null))
		ec.releaseMatrixInputForGPUInstruction(_input3.getName());
	ec.releaseMatrixOutputForGPUInstruction(_output.getName());
}