in src/main/java/org/apache/sysds/lops/compile/Dag.java [467:725]
private void generateControlProgramJobs(List<Lop> execNodes,
List<Instruction> inst, List<Instruction> writeInst, List<Instruction> deleteInst) {
// nodes to be deleted from execnodes
ArrayList<Lop> markedNodes = new ArrayList<>();
// variable names to be deleted
ArrayList<String> var_deletions = new ArrayList<>();
HashMap<String, Lop> var_deletionsLineNum = new HashMap<>();
boolean doRmVar = false;
for (int i = 0; i < execNodes.size(); i++) {
Lop node = execNodes.get(i);
doRmVar = false;
// mark input scalar read nodes for deletion
if (node.isDataExecLocation()
&& ((Data) node).getOperationType().isRead()
&& ((Data) node).getDataType() == DataType.SCALAR
&& node.getOutputParameters().getFile_name() == null ) {
markedNodes.add(node);
continue;
}
// output scalar instructions and mark nodes for deletion
if (!node.isDataExecLocation()) {
if (node.getDataType() == DataType.SCALAR) {
// Output from lops with SCALAR data type must
// go into Temporary Variables (Var0, Var1, etc.)
NodeOutput out = setupNodeOutputs(node, ExecType.CP, false, false);
inst.addAll(out.getPreInstructions()); // dummy
deleteInst.addAll(out.getLastInstructions());
} else {
// Output from lops with non-SCALAR data type must
// go into Temporary Files (temp0, temp1, etc.)
NodeOutput out = setupNodeOutputs(node, ExecType.CP, false, false);
inst.addAll(out.getPreInstructions());
boolean hasTransientWriteParent = false;
for ( Lop parent : node.getOutputs() ) {
if ( parent.isDataExecLocation()
&& ((Data)parent).getOperationType().isWrite()
&& ((Data)parent).getOperationType().isTransient() ) {
hasTransientWriteParent = true;
break;
}
}
if ( !hasTransientWriteParent ) {
deleteInst.addAll(out.getLastInstructions());
}
else {
var_deletions.add(node.getOutputParameters().getLabel());
var_deletionsLineNum.put(node.getOutputParameters().getLabel(), node);
}
}
String inst_string = "";
// Lops with arbitrary number of inputs (ParameterizedBuiltin, GroupedAggregate, DataGen)
// are handled separately, by simply passing ONLY the output variable to getInstructions()
if (node.getType() == Lop.Type.ParameterizedBuiltin
|| node.getType() == Lop.Type.GroupedAgg
|| node.getType() == Lop.Type.DataGen){
inst_string = node.getInstructions(node.getOutputParameters().getLabel());
}
// Lops with arbitrary number of inputs and outputs are handled
// separately as well by passing arrays of inputs and outputs
else if ( node.getType() == Lop.Type.FunctionCallCP )
{
String[] inputs = new String[node.getInputs().size()];
String[] outputs = new String[node.getOutputs().size()];
int count = 0;
for( Lop in : node.getInputs() )
inputs[count++] = in.getOutputParameters().getLabel();
count = 0;
for( Lop out : node.getOutputs() )
outputs[count++] = out.getOutputParameters().getLabel();
inst_string = node.getInstructions(inputs, outputs);
}
else if (node.getType() == Lop.Type.Nary) {
String[] inputs = new String[node.getInputs().size()];
int count = 0;
for( Lop in : node.getInputs() )
inputs[count++] = in.getOutputParameters().getLabel();
inst_string = node.getInstructions(inputs,
node.getOutputParameters().getLabel());
}
else {
if ( node.getInputs().isEmpty() ) {
// currently, such a case exists only for Rand lop
inst_string = node.getInstructions(node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 1) {
inst_string = node.getInstructions(node.getInputs()
.get(0).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 2) {
inst_string = node.getInstructions(
node.getInputs().get(0).getOutputParameters().getLabel(),
node.getInputs().get(1).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 3 || node.getType() == Type.Ctable) {
inst_string = node.getInstructions(
node.getInputs().get(0).getOutputParameters().getLabel(),
node.getInputs().get(1).getOutputParameters().getLabel(),
node.getInputs().get(2).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 4) {
inst_string = node.getInstructions(
node.getInputs().get(0).getOutputParameters().getLabel(),
node.getInputs().get(1).getOutputParameters().getLabel(),
node.getInputs().get(2).getOutputParameters().getLabel(),
node.getInputs().get(3).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 5) {
inst_string = node.getInstructions(
node.getInputs().get(0).getOutputParameters().getLabel(),
node.getInputs().get(1).getOutputParameters().getLabel(),
node.getInputs().get(2).getOutputParameters().getLabel(),
node.getInputs().get(3).getOutputParameters().getLabel(),
node.getInputs().get(4).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 6) {
inst_string = node.getInstructions(
node.getInputs().get(0).getOutputParameters().getLabel(),
node.getInputs().get(1).getOutputParameters().getLabel(),
node.getInputs().get(2).getOutputParameters().getLabel(),
node.getInputs().get(3).getOutputParameters().getLabel(),
node.getInputs().get(4).getOutputParameters().getLabel(),
node.getInputs().get(5).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else if (node.getInputs().size() == 7) {
inst_string = node.getInstructions(
node.getInputs().get(0).getOutputParameters().getLabel(),
node.getInputs().get(1).getOutputParameters().getLabel(),
node.getInputs().get(2).getOutputParameters().getLabel(),
node.getInputs().get(3).getOutputParameters().getLabel(),
node.getInputs().get(4).getOutputParameters().getLabel(),
node.getInputs().get(5).getOutputParameters().getLabel(),
node.getInputs().get(6).getOutputParameters().getLabel(),
node.getOutputParameters().getLabel());
}
else {
String[] inputs = new String[node.getInputs().size()];
for( int j=0; j<node.getInputs().size(); j++ )
inputs[j] = node.getInputs().get(j).getOutputParameters().getLabel();
inst_string = node.getInstructions(inputs,
node.getOutputParameters().getLabel());
}
}
try {
if( LOG.isTraceEnabled() )
LOG.trace("Generating instruction - "+ inst_string);
Instruction currInstr = InstructionParser.parseSingleInstruction(inst_string);
if(currInstr == null) {
throw new LopsException("Error parsing the instruction:" + inst_string);
}
if (node._beginLine != 0)
currInstr.setLocation(node);
else if ( !node.getOutputs().isEmpty() )
currInstr.setLocation(node.getOutputs().get(0));
else if ( !node.getInputs().isEmpty() )
currInstr.setLocation(node.getInputs().get(0));
inst.add(currInstr);
} catch (Exception e) {
throw new LopsException(node.printErrorLocation() + "Problem generating simple inst - "
+ inst_string, e);
}
markedNodes.add(node);
doRmVar = true;
}
else if (node.isDataExecLocation() ) {
Data dnode = (Data)node;
OpOpData op = dnode.getOperationType();
if ( op.isWrite() ) {
NodeOutput out = null;
out = setupNodeOutputs(node, ExecType.CP, false, false);
if ( dnode.getDataType() == DataType.SCALAR ) {
// processing is same for both transient and persistent scalar writes
writeInst.addAll(out.getLastInstructions());
doRmVar = false;
}
else {
// setupNodeOutputs() handles both transient and persistent matrix writes
if ( dnode.getOperationType().isTransient() ) {
deleteInst.addAll(out.getLastInstructions());
doRmVar = false;
}
else {
// In case of persistent write lop, write instruction will be generated
// and that instruction must be added to <code>inst</code> so that it gets
// executed immediately. If it is added to <code>deleteInst</code> then it
// gets executed at the end of program block's execution
inst.addAll(out.getLastInstructions());
doRmVar = true;
}
}
markedNodes.add(node);
}
else {
// generate a temp label to hold the value that is read from HDFS
if ( node.getDataType() == DataType.SCALAR ) {
node.getOutputParameters().setLabel(Lop.SCALAR_VAR_NAME_PREFIX + var_index.getNextID());
String io_inst = node.getInstructions(node.getOutputParameters().getLabel(),
node.getOutputParameters().getFile_name());
CPInstruction currInstr = CPInstructionParser.parseSingleInstruction(io_inst);
currInstr.setLocation(node);
inst.add(currInstr);
Instruction tempInstr = VariableCPInstruction.prepareRemoveInstruction(node.getOutputParameters().getLabel());
tempInstr.setLocation(node);
deleteInst.add(tempInstr);
}
else {
throw new LopsException("Matrix READs are not handled in CP yet!");
}
markedNodes.add(node);
doRmVar = true;
}
}
// see if rmvar instructions can be generated for node's inputs
if(doRmVar)
processConsumersForInputs(node, inst, deleteInst);
doRmVar = false;
}
for ( String var : var_deletions ) {
Instruction rmInst = VariableCPInstruction.prepareRemoveInstruction(var);
if( LOG.isTraceEnabled() )
LOG.trace(" Adding var_deletions: " + rmInst.toString());
rmInst.setLocation(var_deletionsLineNum.get(var));
deleteInst.add(rmInst);
}
// delete all marked nodes
for ( Lop node : markedNodes ) {
execNodes.remove(node);
}
}