public static DnnGPUInstruction parseInstruction()

in src/main/java/org/apache/sysds/runtime/instructions/gpu/DnnGPUInstruction.java [166:377]


	/**
	 * Parses a serialized GPU DNN instruction into a {@link DnnGPUInstruction}.
	 *
	 * <p>The string is split via {@code InstructionUtils.getInstructionPartsWithValueType};
	 * {@code parts[0]} is the opcode (matched case-insensitively) and the remaining
	 * fields are interpreted per opcode:
	 * <ul>
	 *   <li>{@code conv2d}, {@code conv2d_backward_filter}, {@code conv2d_backward_data}:
	 *       in1, in2, stride[2], padding[2], input_shape[4], filter_shape[4], out, memory budget.</li>
	 *   <li>{@code maxpooling_backward}, {@code avgpooling_backward}: same layout, optionally
	 *       with one extra input operand between filter_shape and out (when there are 18 parts).</li>
	 *   <li>{@code conv2d_bias_add}: in1, in2, in3, stride[2], padding[2], input_shape[4],
	 *       filter_shape[4], out, memory budget.</li>
	 *   <li>{@code maxpooling}, {@code avgpooling}: in1, stride[2], padding[2], input_shape[4],
	 *       filter_shape[4], out, memory budget.</li>
	 *   <li>{@code bias_add}, {@code relu_backward}, {@code bias_multiply}: in1, in2, out, memory budget.</li>
	 *   <li>{@code channel_sums}, {@code update_nesterov_x}, {@code lstm}, {@code batch_norm2d},
	 *       {@code lstm_backward}, {@code batch_norm2d_backward}, {@code batch_norm2d_test},
	 *       {@code batch_norm2d_train}: a fixed sequence of input operands followed by output
	 *       operands; the last constructor argument is passed as 0.</li>
	 * </ul>
	 * The expected field count of each form is checked with {@code InstructionUtils.checkNumFields}.
	 *
	 * @param str the serialized instruction
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if the opcode is not one of the opcodes listed above
	 */
	public static DnnGPUInstruction parseInstruction(String str) {
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		String opcode = parts[0];
		if( opcode.equalsIgnoreCase("conv2d")
			|| opcode.equalsIgnoreCase("conv2d_backward_filter")
			|| opcode.equalsIgnoreCase("conv2d_backward_data") ) {
			InstructionUtils.checkNumFields(parts, 16);
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand out = new CPOperand(parts[15]);
			ArrayList<CPOperand> stride = parseConsecutiveOperands(parts, 3, 2);
			ArrayList<CPOperand> padding = parseConsecutiveOperands(parts, 5, 2);
			ArrayList<CPOperand> input_shape = parseConsecutiveOperands(parts, 7, 4);
			ArrayList<CPOperand> filter_shape = parseConsecutiveOperands(parts, 11, 4);

			return new DnnGPUInstruction(in1, in2, out, opcode, str, stride,
					padding, input_shape, filter_shape, Double.parseDouble(parts[16]));
		}
		else if( opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") ) {
			// With 18 parts an extra input operand (in3) sits at parts[15], shifting out and
			// the memory budget one position right. NOTE(review): the flag name suggests this
			// is the forward max-pooling output — confirm against the instruction generator.
			boolean withMaxPoolOut = parts.length == 18;
			if( !withMaxPoolOut )
				InstructionUtils.checkNumFields(parts, 16);
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null;
			CPOperand out = new CPOperand(parts[withMaxPoolOut ? 16 : 15]);
			double memBudget = Double.parseDouble(parts[withMaxPoolOut ? 17 : 16]);

			ArrayList<CPOperand> stride = parseConsecutiveOperands(parts, 3, 2);
			ArrayList<CPOperand> padding = parseConsecutiveOperands(parts, 5, 2);
			ArrayList<CPOperand> input_shape = parseConsecutiveOperands(parts, 7, 4);
			ArrayList<CPOperand> filter_shape = parseConsecutiveOperands(parts, 11, 4);

			return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride,
					padding, input_shape, filter_shape, memBudget);
		}
		else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
			InstructionUtils.checkNumFields(parts, 17);
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = new CPOperand(parts[3]);
			CPOperand out = new CPOperand(parts[16]);

			// Same parameter layout as conv2d, shifted by one for the extra input in3.
			ArrayList<CPOperand> stride = parseConsecutiveOperands(parts, 4, 2);
			ArrayList<CPOperand> padding = parseConsecutiveOperands(parts, 6, 2);
			ArrayList<CPOperand> input_shape = parseConsecutiveOperands(parts, 8, 4);
			ArrayList<CPOperand> filter_shape = parseConsecutiveOperands(parts, 12, 4);

			return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride,
					padding, input_shape, filter_shape, Double.parseDouble(parts[17]));
		}
		else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
			InstructionUtils.checkNumFields(parts, 15);
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand out = new CPOperand(parts[14]);

			// Single-input form: parameters start one field earlier than in conv2d.
			ArrayList<CPOperand> stride = parseConsecutiveOperands(parts, 2, 2);
			ArrayList<CPOperand> padding = parseConsecutiveOperands(parts, 4, 2);
			ArrayList<CPOperand> input_shape = parseConsecutiveOperands(parts, 6, 4);
			ArrayList<CPOperand> filter_shape = parseConsecutiveOperands(parts, 10, 4);

			return new DnnGPUInstruction(in1, null, out, opcode, str, stride,
					padding, input_shape, filter_shape, Double.parseDouble(parts[15]));
		}
		else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")  ) {
			InstructionUtils.checkNumFields(parts, 4);
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand out = new CPOperand(parts[3]);
			return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
		}
		else if (opcode.equalsIgnoreCase("channel_sums")) {
			InstructionUtils.checkNumFields(parts, 4);
			CPOperand in = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = new CPOperand(parts[3]);
			CPOperand out = new CPOperand(parts[4]);
			return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0);
		}
		else if (opcode.equalsIgnoreCase("update_nesterov_x")) {
			InstructionUtils.checkNumFields(parts, 5);
			CPOperand in = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = new CPOperand(parts[3]);
			CPOperand in4 = new CPOperand(parts[4]);
			CPOperand out = new CPOperand(parts[5]);
			return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0);
		}
		else if (opcode.equalsIgnoreCase("lstm")) {
			InstructionUtils.checkNumFields(parts, 8);
			CPOperand in1 = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = new CPOperand(parts[3]);
			CPOperand in4 = new CPOperand(parts[4]);
			CPOperand in5 = new CPOperand(parts[5]);
			CPOperand in6 = new CPOperand(parts[6]);
			CPOperand out = new CPOperand(parts[7]);
			CPOperand out2 = new CPOperand(parts[8]);
			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0);
		}
		else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
			// 8 inputs + 5 outputs. The per-operand names below are the batch_norm2d meanings;
			// NOTE(review): lstm_backward reuses this layout, so for it the names do not apply.
			InstructionUtils.checkNumFields(parts, 13);
			CPOperand in1 = new CPOperand(parts[1]); // image
			CPOperand in2 = new CPOperand(parts[2]); // scale
			CPOperand in3 = new CPOperand(parts[3]); // bias
			CPOperand in4 = new CPOperand(parts[4]); // runningMean
			CPOperand in5 = new CPOperand(parts[5]); // runningVar
			CPOperand in6 = new CPOperand(parts[6]); // mode
			CPOperand in7 = new CPOperand(parts[7]); // epsilon
			CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
			// The first output is built with an explicit FP64 MATRIX type instead of the type
			// encoded in the field. NOTE(review): the reason is not visible here.
			CPOperand out = new CPOperand(parts[9], Types.ValueType.FP64, Types.DataType.MATRIX);  // ret
			CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
			CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
			CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
			CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
		}
		else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
			InstructionUtils.checkNumFields(parts, 9);
			CPOperand in1 = new CPOperand(parts[1]); // image
			CPOperand in2 = new CPOperand(parts[2]); // dout
			CPOperand in3 = new CPOperand(parts[3]); // scale
			CPOperand in4 = new CPOperand(parts[4]); // epsilon
			CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
			CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
			CPOperand out = new CPOperand(parts[7]);  // dX
			CPOperand out2 = new CPOperand(parts[8]); // dScale
			CPOperand out3 = new CPOperand(parts[9]); // dBias
			// Reuses the 8-input/5-output constructor; unused slots are passed as null.
			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
		}
		else if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
			InstructionUtils.checkNumFields(parts, 7);
			CPOperand in = new CPOperand(parts[1]);
			CPOperand in2 = new CPOperand(parts[2]);
			CPOperand in3 = new CPOperand(parts[3]);
			CPOperand in4 = new CPOperand(parts[4]);
			CPOperand in5 = new CPOperand(parts[5]);
			CPOperand in6 = new CPOperand(parts[6]);
			CPOperand out = new CPOperand(parts[7]);
			return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0);
		}
		else if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
			InstructionUtils.checkNumFields(parts, 12);
			CPOperand in1 = new CPOperand(parts[1]); // image
			CPOperand in2 = new CPOperand(parts[2]); // gamma
			CPOperand in3 = new CPOperand(parts[3]); // beta
			CPOperand in4 = new CPOperand(parts[4]); // ema_mean
			CPOperand in5 = new CPOperand(parts[5]); // ema_var
			CPOperand in6 = new CPOperand(parts[6]); // eps
			CPOperand in7 = new CPOperand(parts[7]); // mu
			CPOperand out = new CPOperand(parts[8]);  // out
			CPOperand out2 = new CPOperand(parts[9]); // ema_mean_upd
			CPOperand out3 = new CPOperand(parts[10]); // ema_var_upd
			CPOperand out4 = new CPOperand(parts[11]); // cache_mean
			CPOperand out5 = new CPOperand(parts[12]); // cache_inv_var
			return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0);
		}
		else {
			throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
		}
	}

	/**
	 * Builds {@code count} operands from consecutive instruction fields
	 * {@code parts[start] .. parts[start + count - 1]}, in field order.
	 *
	 * @param parts instruction fields produced by the instruction parser
	 * @param start index of the first field to convert
	 * @param count number of consecutive fields to convert
	 * @return a new mutable list holding the operands in field order
	 */
	private static ArrayList<CPOperand> parseConsecutiveOperands(String[] parts, int start, int count) {
		ArrayList<CPOperand> operands = new ArrayList<>(count);
		for( int i = start; i < start + count; i++ )
			operands.add(new CPOperand(parts[i]));
		return operands;
	}