in src/main/java/org/apache/sysds/runtime/instructions/gpu/DnnGPUInstruction.java [166:377]
/**
 * Parses a serialized DNN GPU instruction string into a {@link DnnGPUInstruction}.
 *
 * <p>The string is split with {@code InstructionUtils.getInstructionPartsWithValueType};
 * {@code parts[0]} is the opcode and the remaining entries are operand fields whose number
 * and positions depend on the opcode. The expected field count is checked with
 * {@code InstructionUtils.checkNumFields} before operands are read.
 *
 * <p>The sliding-window opcodes ({@code conv2d}, {@code conv2d_backward_filter},
 * {@code conv2d_backward_data}, {@code conv2d_bias_add}, {@code maxpooling},
 * {@code avgpooling}, {@code maxpooling_backward}, {@code avgpooling_backward}) carry twelve
 * consecutive window operands after their inputs: stride (2), padding (2), input shape (4)
 * and filter shape (4). These are followed by the output operand and a trailing numeric field
 * parsed with {@link Double#parseDouble(String)}. {@code bias_add}, {@code relu_backward} and
 * {@code bias_multiply} also parse a trailing numeric field; all remaining opcodes pass
 * {@code 0} for that constructor argument.
 *
 * @param str the serialized instruction
 * @return the parsed GPU instruction
 * @throws DMLRuntimeException if the opcode is not one handled by this parser
 * @throws NumberFormatException if a trailing numeric field is not a valid double
 */
public static DnnGPUInstruction parseInstruction(String str) {
	String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
	String opcode = parts[0];

	if( opcode.equalsIgnoreCase("conv2d")
		|| opcode.equalsIgnoreCase("conv2d_backward_filter")
		|| opcode.equalsIgnoreCase("conv2d_backward_data") ) {
		// Layout: in1, in2, 12 window operands (parts[3..14]), out, numeric field.
		InstructionUtils.checkNumFields(parts, 16);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[15]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 3, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 5, 2);
		ArrayList<CPOperand> inputShape = parseOperandRange(parts, 7, 4);
		ArrayList<CPOperand> filterShape = parseOperandRange(parts, 11, 4);
		return new DnnGPUInstruction(in1, in2, out, opcode, str, stride,
			padding, inputShape, filterShape, Double.parseDouble(parts[16]));
	}

	if( opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") ) {
		// An 18-entry instruction (opcode included) carries one extra operand at parts[15],
		// passed as the third input; it shifts the output and memory budget back by one slot.
		// Any other length must pass the 16-field check.
		boolean withMaxPoolOut = (parts.length == 18);
		if( !withMaxPoolOut )
			InstructionUtils.checkNumFields(parts, 16);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null;
		CPOperand out = new CPOperand(parts[withMaxPoolOut ? 16 : 15]);
		double memBudget = Double.parseDouble(parts[withMaxPoolOut ? 17 : 16]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 3, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 5, 2);
		ArrayList<CPOperand> inputShape = parseOperandRange(parts, 7, 4);
		ArrayList<CPOperand> filterShape = parseOperandRange(parts, 11, 4);
		return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride,
			padding, inputShape, filterShape, memBudget);
	}

	if( opcode.equalsIgnoreCase("conv2d_bias_add") ) {
		// Layout: in1, in2, in3, 12 window operands (parts[4..15]), out, numeric field.
		InstructionUtils.checkNumFields(parts, 17);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand out = new CPOperand(parts[16]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 4, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 6, 2);
		ArrayList<CPOperand> inputShape = parseOperandRange(parts, 8, 4);
		ArrayList<CPOperand> filterShape = parseOperandRange(parts, 12, 4);
		return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride,
			padding, inputShape, filterShape, Double.parseDouble(parts[17]));
	}

	if( opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling") ) {
		// Layout: in1, 12 window operands (parts[2..13]), out, numeric field; no second input.
		InstructionUtils.checkNumFields(parts, 15);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand out = new CPOperand(parts[14]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 2, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 4, 2);
		ArrayList<CPOperand> inputShape = parseOperandRange(parts, 6, 4);
		ArrayList<CPOperand> filterShape = parseOperandRange(parts, 10, 4);
		return new DnnGPUInstruction(in1, null, out, opcode, str, stride,
			padding, inputShape, filterShape, Double.parseDouble(parts[15]));
	}

	if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) {
		InstructionUtils.checkNumFields(parts, 4);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[3]);
		return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
	}

	if( opcode.equalsIgnoreCase("channel_sums") ) {
		InstructionUtils.checkNumFields(parts, 4);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand out = new CPOperand(parts[4]);
		return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0);
	}

	if( opcode.equalsIgnoreCase("update_nesterov_x") ) {
		InstructionUtils.checkNumFields(parts, 5);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand out = new CPOperand(parts[5]);
		return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
	}

	if( opcode.equalsIgnoreCase("lstm") ) {
		// Layout: six inputs (parts[1..6]) followed by two outputs (parts[7..8]).
		InstructionUtils.checkNumFields(parts, 8);
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand in5 = new CPOperand(parts[5]);
		CPOperand in6 = new CPOperand(parts[6]);
		CPOperand out = new CPOperand(parts[7]);
		CPOperand out2 = new CPOperand(parts[8]);
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
	}

	if( opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward") ) {
		// Layout: eight inputs (parts[1..8]) followed by five outputs (parts[9..13]).
		// The per-field comments name the batch_norm2d operands; lstm_backward shares this
		// positional layout.
		InstructionUtils.checkNumFields(parts, 13);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // scale
		CPOperand in3 = new CPOperand(parts[3]); // bias
		CPOperand in4 = new CPOperand(parts[4]); // runningMean
		CPOperand in5 = new CPOperand(parts[5]); // runningVar
		CPOperand in6 = new CPOperand(parts[6]); // mode
		CPOperand in7 = new CPOperand(parts[7]); // epsilon
		CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
		// ret: parsed from its instruction field exactly like every other operand, so the
		// serialized field is split into name/type information instead of used verbatim.
		CPOperand out = new CPOperand(parts[9]);
		CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
		CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
		CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
		CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
	}

	if( opcode.equalsIgnoreCase("batch_norm2d_backward") ) {
		InstructionUtils.checkNumFields(parts, 9);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // dout
		CPOperand in3 = new CPOperand(parts[3]); // scale
		CPOperand in4 = new CPOperand(parts[4]); // epsilon
		CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
		CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
		CPOperand out = new CPOperand(parts[7]); // dX
		CPOperand out2 = new CPOperand(parts[8]); // dScale
		CPOperand out3 = new CPOperand(parts[9]); // dBias
		// Unused input/output slots of the shared 13-operand constructor are passed as null.
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
	}

	if( opcode.equalsIgnoreCase("batch_norm2d_test") ) {
		InstructionUtils.checkNumFields(parts, 7);
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand in5 = new CPOperand(parts[5]);
		CPOperand in6 = new CPOperand(parts[6]);
		CPOperand out = new CPOperand(parts[7]);
		return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
	}

	if( opcode.equalsIgnoreCase("batch_norm2d_train") ) {
		InstructionUtils.checkNumFields(parts, 12);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // gamma
		CPOperand in3 = new CPOperand(parts[3]); // beta
		CPOperand in4 = new CPOperand(parts[4]); // ema_mean
		CPOperand in5 = new CPOperand(parts[5]); // ema_var
		CPOperand in6 = new CPOperand(parts[6]); // eps
		CPOperand in7 = new CPOperand(parts[7]); // mu
		CPOperand out = new CPOperand(parts[8]); // out
		CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
		CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
		CPOperand out4 = new CPOperand(parts[11]); // cache_mean
		CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
		return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
	}

	throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
}

/**
 * Parses {@code count} consecutive instruction fields, starting at {@code parts[start]},
 * into operands, preserving field order.
 *
 * @param parts the split instruction fields
 * @param start index of the first field to parse
 * @param count number of consecutive fields to parse
 * @return a new list holding one operand per parsed field
 */
private static ArrayList<CPOperand> parseOperandRange(String[] parts, int start, int count) {
	ArrayList<CPOperand> operands = new ArrayList<>(count);
	for( int i = start; i < start + count; i++ )
		operands.add(new CPOperand(parts[i]));
	return operands;
}