in src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java [143:323]
/**
 * Parses a CP DNN instruction string into a {@link DnnCPInstruction}.
 *
 * <p>The string is split with {@link InstructionUtils#getInstructionPartsWithValueType(String)};
 * {@code parts[0]} is the opcode, and the remaining fields are interpreted according to the
 * opcode family (pooling, conv2d family, conv2d_bias_add, bias/relu-backward, batch_norm2d,
 * batch_norm2d_backward, lstm, lstm_backward). Opcodes are matched case-insensitively.
 * Field-count validation is delegated to {@link InstructionUtils#checkNumFields}.
 *
 * @param str the serialized instruction
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is not a DNN opcode handled here
 */
public static DnnCPInstruction parseInstruction(String str) {
	String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
	String opcode = parts[0];
	if (opcode.equalsIgnoreCase(Opcodes.MAXPOOLING.toString())
		|| opcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING.toString())
		|| opcode.equalsIgnoreCase(Opcodes.AVGPOOLING.toString())) {
		InstructionUtils.checkNumFields(parts, 16);
		// Layout: in(1), stride1-2(2-3), padding1-2(4-5), input_shape1-4(6-9),
		// filter_shape1-4(10-13), out(14), k(15), trailing double(16)
		CPOperand in = new CPOperand(parts[1]);
		CPOperand out = new CPOperand(parts[14]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 2, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 4, 2);
		ArrayList<CPOperand> input_shape = parseOperandRange(parts, 6, 4);
		ArrayList<CPOperand> filter_shape = parseOperandRange(parts, 10, 4);
		int k = Integer.parseInt(parts[15]);
		// NOTE(review): the trailing double is presumably an intermediate memory budget — confirm
		// against the constructor signature.
		return new DnnCPInstruction(in, out, opcode, str, stride,
			padding, input_shape, filter_shape, k, Double.parseDouble(parts[16]));
	}
	else if (opcode.equalsIgnoreCase(Opcodes.MAXPOOLING_BACKWARD.toString())
		|| opcode.equalsIgnoreCase(Opcodes.RELU_MAXPOOLING_BACKWARD.toString())
		|| opcode.equalsIgnoreCase(Opcodes.AVGPOOLING_BACKWARD.toString())
		|| opcode.equalsIgnoreCase(Opcodes.CONV2D.toString())
		|| opcode.equalsIgnoreCase(Opcodes.CONV2D_BACKWARD_FILTER.toString())
		|| opcode.equalsIgnoreCase(Opcodes.CONV2D_BACKWARD_DATA.toString())) {
		InstructionUtils.checkNumFields(parts, 17);
		// Layout: in(1), in2(2), stride1-2(3-4), padding1-2(5-6), input_shape1-4(7-10),
		// filter_shape1-4(11-14), out(15), k(16), trailing double(17)
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[15]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 3, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 5, 2);
		ArrayList<CPOperand> input_shape = parseOperandRange(parts, 7, 4);
		ArrayList<CPOperand> filter_shape = parseOperandRange(parts, 11, 4);
		int k = Integer.parseInt(parts[16]);
		return new DnnCPInstruction(in, in2, out, opcode, str, stride,
			padding, input_shape, filter_shape, k, Double.parseDouble(parts[17]));
	}
	else if (opcode.equalsIgnoreCase(Opcodes.CONV2D_BIAS_ADD.toString())) {
		InstructionUtils.checkNumFields(parts, 18);
		// Layout: in(1), in2(2), in3(3), stride1-2(4-5), padding1-2(6-7), input_shape1-4(8-11),
		// filter_shape1-4(12-15), out(16), k(17), trailing double(18)
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand out = new CPOperand(parts[16]);
		ArrayList<CPOperand> stride = parseOperandRange(parts, 4, 2);
		ArrayList<CPOperand> padding = parseOperandRange(parts, 6, 2);
		ArrayList<CPOperand> input_shape = parseOperandRange(parts, 8, 4);
		ArrayList<CPOperand> filter_shape = parseOperandRange(parts, 12, 4);
		int k = Integer.parseInt(parts[17]);
		return new DnnCPInstruction(in, in2, in3, out, opcode, str, stride,
			padding, input_shape, filter_shape, k, Double.parseDouble(parts[18]));
	}
	else if (opcode.equalsIgnoreCase(Opcodes.BIAS_ADD.toString())
		|| opcode.equalsIgnoreCase(Opcodes.RELU_BACKWARD.toString())
		|| opcode.equalsIgnoreCase(Opcodes.BIAS_MULTIPLY.toString())) {
		// Case-insensitive like every other opcode check in this method (previously RELU_BACKWARD
		// alone used case-sensitive equals()).
		InstructionUtils.checkNumFields(parts, 5);
		// Layout: in(1), in2(2), out(3), k(4), trailing double(5)
		CPOperand in = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand out = new CPOperand(parts[3]);
		int k = Integer.parseInt(parts[4]);
		return new DnnCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5]));
	}
	else if (opcode.equalsIgnoreCase(Opcodes.BATCH_NORM2D.toString())) {
		InstructionUtils.checkNumFields(parts, 14);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // scale
		CPOperand in3 = new CPOperand(parts[3]); // bias
		CPOperand in4 = new CPOperand(parts[4]); // runningMean
		CPOperand in5 = new CPOperand(parts[5]); // runningVar
		CPOperand in6 = new CPOperand(parts[6]); // mode
		CPOperand in7 = new CPOperand(parts[7]); // epsilon
		CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
		CPOperand out = new CPOperand(parts[9]); // ret
		CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
		CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
		CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
		CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
		// parts[14] (a thread count, per earlier code) is intentionally not read; 0 is passed instead.
		return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
	}
	else if (opcode.equalsIgnoreCase(Opcodes.BATCH_NORM2D_BACKWARD.toString())) {
		InstructionUtils.checkNumFields(parts, 10);
		CPOperand in1 = new CPOperand(parts[1]); // image
		CPOperand in2 = new CPOperand(parts[2]); // dout
		CPOperand in3 = new CPOperand(parts[3]); // scale
		CPOperand in4 = new CPOperand(parts[4]); // epsilon
		CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
		CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
		CPOperand out = new CPOperand(parts[7]); // dX
		CPOperand out2 = new CPOperand(parts[8]); // dScale
		CPOperand out3 = new CPOperand(parts[9]); // dBias
		// parts[10] (a thread count, per earlier code) is intentionally not read; 0 is passed instead.
		return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
	}
	else if (opcode.equalsIgnoreCase(Opcodes.LSTM.toString())) {
		InstructionUtils.checkNumFields(parts, 12);
		// Layout: six inputs (1-6), five outputs (7-11); parts[12] is not read.
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand in5 = new CPOperand(parts[5]);
		CPOperand in6 = new CPOperand(parts[6]);
		CPOperand out1 = new CPOperand(parts[7]);
		CPOperand out2 = new CPOperand(parts[8]);
		CPOperand out3 = new CPOperand(parts[9]);
		CPOperand out4 = new CPOperand(parts[10]);
		CPOperand out5 = new CPOperand(parts[11]);
		return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out1, out2, out3, out4, out5, opcode, str, 0);
	}
	else if (opcode.equalsIgnoreCase(Opcodes.LSTM_BACKWARD.toString())) {
		InstructionUtils.checkNumFields(parts, 17);
		// Layout: eleven inputs (1-11), five outputs (12-16); parts[17] is not read.
		CPOperand in1 = new CPOperand(parts[1]);
		CPOperand in2 = new CPOperand(parts[2]);
		CPOperand in3 = new CPOperand(parts[3]);
		CPOperand in4 = new CPOperand(parts[4]);
		CPOperand in5 = new CPOperand(parts[5]);
		CPOperand in6 = new CPOperand(parts[6]);
		CPOperand in7 = new CPOperand(parts[7]);
		CPOperand in8 = new CPOperand(parts[8]);
		CPOperand in9 = new CPOperand(parts[9]);
		CPOperand in10 = new CPOperand(parts[10]);
		CPOperand in11 = new CPOperand(parts[11]);
		CPOperand out1 = new CPOperand(parts[12]);
		CPOperand out2 = new CPOperand(parts[13]);
		CPOperand out3 = new CPOperand(parts[14]);
		CPOperand out4 = new CPOperand(parts[15]);
		CPOperand out5 = new CPOperand(parts[16]);
		return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, in9, in10, in11, out1, out2, out3, out4, out5, opcode, str, 0);
	}
	else {
		throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str);
	}
}

/**
 * Builds an operand list from {@code count} consecutive instruction fields,
 * {@code parts[start]} through {@code parts[start + count - 1]}, in order.
 *
 * @param parts the split instruction fields
 * @param start index of the first field to convert
 * @param count number of fields to convert
 * @return a new mutable list of {@code count} operands
 */
private static ArrayList<CPOperand> parseOperandRange(String[] parts, int start, int count) {
	ArrayList<CPOperand> operands = new ArrayList<>(count);
	for (int i = start; i < start + count; i++)
		operands.add(new CPOperand(parts[i]));
	return operands;
}