in src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java [592:865]
/**
 * Recursively constructs hops for the lineage DAG rooted at the given item.
 * The inputs of an item are processed before the item itself, and the hop
 * created for an item is registered in {@code operands} under the item's ID.
 * Items that are already marked as visited are skipped; every processed item
 * is marked as visited at the end.
 *
 * @param item         lineage item to process
 * @param operands     map from lineage item ID (or placeholder ID) to hop, updated in place
 * @param partDagRoots roots of partial hop DAGs by variable name; a dedup item
 *                     adds its function call here before the DAG is compiled
 * @param prog         program handed on to the basic block and dedup patch construction
 * @throws DMLRuntimeException if an instruction item has an unsupported instruction
 *                             type or opcode, or a ctable item has an unexpected
 *                             number of inputs
 */
private static void rConstructHops(LineageItem item, Map<Long, Hop> operands, Map<String, Hop> partDagRoots, Program prog)
{
	if (item.isVisited())
		return;

	//recursively process children (ordering by data dependencies)
	if (!item.isLeaf())
		for (LineageItem c : item.getInputs())
			rConstructHops(c, operands, partDagRoots, prog);

	//process current lineage item
	//NOTE: we generate instructions from hops (but without rewrites) to automatically
	//handle execution types, rmvar instructions, and rewiring of inputs/outputs
	switch (item.getType()) {
		case Creation: {
			if (item.getData().startsWith(LPLACEHOLDER)) {
				//the placeholder ID follows the placeholder prefix (e.g. IN#0 -> 0)
				long phId = Long.parseLong(item.getData().substring(LPLACEHOLDER.length()));
				//re-register the hop of the placeholder under the ID of this item
				Hop input = operands.remove(phId);
				// Replace the placeholders with TReads
				operands.put(item.getId(), input); // order preserving
				break;
			}
			//NOTE(review): creation instructions other than the three kinds handled
			//below register no hop for this item - confirm this is intended
			Instruction inst = InstructionParser.parseSingleInstruction(item.getData());
			if (inst instanceof DataGenCPInstruction) {
				DataGenCPInstruction rand = (DataGenCPInstruction) inst;
				HashMap<String, Hop> params = new HashMap<>();
				if( rand.getOpcode().equals("rand") ) {
					if( rand.output.getDataType() == DataType.TENSOR)
						params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
					else {
						params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
						params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
					}
					params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
					params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
					params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
					params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
					params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
					params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
				}
				else if( rand.getOpcode().equals("seq") ) {
					params.put(Statement.SEQ_FROM, new LiteralOp(rand.getFrom()));
					params.put(Statement.SEQ_TO, new LiteralOp(rand.getTo()));
					params.put(Statement.SEQ_INCR, new LiteralOp(rand.getIncr()));
				}
				//locale-independent upper-casing, otherwise the enum lookup fails
				//under locales with special casing rules (e.g., Turkish dotted I)
				Hop datagen = new DataGenOp(
					OpOpDG.valueOf(rand.getOpcode().toUpperCase(java.util.Locale.ROOT)),
					new DataIdentifier("tmp"), params);
				datagen.setBlocksize(rand.getBlocksize());
				operands.put(item.getId(), datagen);
			} else if (inst instanceof VariableCPInstruction
					&& ((VariableCPInstruction) inst).isCreateVariable()) {
				//rebuild a persistent read from the parts of the createvar instruction
				String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst.toString());
				DataType dt = DataType.valueOf(parts[4]);
				ValueType vt = dt == DataType.MATRIX ? ValueType.FP64 : ValueType.STRING;
				HashMap<String, Hop> params = new HashMap<>();
				params.put(DataExpression.IO_FILENAME, new LiteralOp(parts[2]));
				params.put(DataExpression.READROWPARAM, new LiteralOp(Long.parseLong(parts[6])));
				params.put(DataExpression.READCOLPARAM, new LiteralOp(Long.parseLong(parts[7])));
				params.put(DataExpression.READNNZPARAM, new LiteralOp(Long.parseLong(parts[8])));
				params.put(DataExpression.FORMAT_TYPE, new LiteralOp(parts[5]));
				DataOp pread = new DataOp(parts[1].substring(5), dt, vt, OpOpData.PERSISTENTREAD, params);
				pread.setFileName(parts[2]);
				operands.put(item.getId(), pread);
			}
			else if (inst instanceof RandSPInstruction) {
				RandSPInstruction rand = (RandSPInstruction) inst;
				HashMap<String, Hop> params = new HashMap<>();
				if (rand.output.getDataType() == DataType.TENSOR)
					params.put(DataExpression.RAND_DIMS, new LiteralOp(rand.getDims()));
				else {
					params.put(DataExpression.RAND_ROWS, new LiteralOp(rand.getRows()));
					params.put(DataExpression.RAND_COLS, new LiteralOp(rand.getCols()));
				}
				params.put(DataExpression.RAND_MIN, new LiteralOp(rand.getMinValue()));
				params.put(DataExpression.RAND_MAX, new LiteralOp(rand.getMaxValue()));
				params.put(DataExpression.RAND_PDF, new LiteralOp(rand.getPdf()));
				params.put(DataExpression.RAND_LAMBDA, new LiteralOp(rand.getPdfParams()));
				params.put(DataExpression.RAND_SPARSITY, new LiteralOp(rand.getSparsity()));
				params.put(DataExpression.RAND_SEED, new LiteralOp(rand.getSeed()));
				Hop datagen = new DataGenOp(OpOpDG.RAND, new DataIdentifier("tmp"), params);
				datagen.setBlocksize(rand.getBlocksize());
				operands.put(item.getId(), datagen);
			}
			break;
		}
		case Dedup: {
			// Create function call for each dedup entry
			String[] parts = item.getOpcode().split(LineageDedupUtils.DEDUP_DELIM); //e.g. dedup_R_SB13_0
			String name = parts[2] + parts[1] + parts[3]; //loopId + outVar + pathId
			List<Hop> finputs = Arrays.stream(item.getInputs())
				.map(inp -> operands.get(inp.getId())).collect(Collectors.toList());
			String[] inputNames = new String[item.getInputs().length];
			for (int i=0; i<item.getInputs().length; i++)
				inputNames[i] = LPLACEHOLDER + i; //e.g. IN#0, IN#1
			Hop funcOp = new FunctionOp(FunctionType.DML, DMLProgram.DEFAULT_NAMESPACE,
				name, inputNames, finputs, new String[] {parts[1]}, false);
			// Cut the Hop dag after function calls
			partDagRoots.put(parts[1], funcOp);
			// Compile the dag and save
			constructBasicBlock(partDagRoots, parts[1], prog);
			// Construct a Hop dag for the function body from the dedup patch, and compile
			Hop output = constructHopsDedupPatch(parts, inputNames, finputs, prog);
			// Create a TRead on the function o/p as a leaf for the next Hop dag
			// Use the function body root/return hop to propagate right data type
			operands.put(item.getId(), HopRewriteUtils.createTransientRead(parts[1], output));
			break;
		}
		case Instruction: {
			//the CP instruction type takes precedence over the SP instruction type
			InstructionType ctype = InstructionUtils.getCPTypeByOpcode(item.getOpcode());
			InstructionType stype = InstructionUtils.getSPTypeByOpcode(item.getOpcode());
			if (ctype != null) {
				switch (ctype) {
					case AggregateUnary: {
						Hop input = operands.get(item.getInputs()[0].getId());
						Hop aggunary = InstructionUtils.isUnaryMetadata(item.getOpcode()) ?
							HopRewriteUtils.createUnary(input, OpOp1.valueOfByOpcode(item.getOpcode())) :
							HopRewriteUtils.createAggUnaryOp(input, item.getOpcode());
						operands.put(item.getId(), aggunary);
						break;
					}
					case AggregateBinary: {
						Hop input1 = operands.get(item.getInputs()[0].getId());
						Hop input2 = operands.get(item.getInputs()[1].getId());
						Hop aggbinary = HopRewriteUtils.createMatrixMultiply(input1, input2);
						operands.put(item.getId(), aggbinary);
						break;
					}
					case AggregateTernary: {
						//expressed as sum over the element-wise product of all three inputs
						Hop input1 = operands.get(item.getInputs()[0].getId());
						Hop input2 = operands.get(item.getInputs()[1].getId());
						Hop input3 = operands.get(item.getInputs()[2].getId());
						Hop aggternary = HopRewriteUtils.createSum(
							HopRewriteUtils.createBinary(
								HopRewriteUtils.createBinary(input1, input2, OpOp2.MULT),
								input3, OpOp2.MULT));
						operands.put(item.getId(), aggternary);
						break;
					}
					case Unary:
					case Builtin: {
						Hop input = operands.get(item.getInputs()[0].getId());
						Hop unary = HopRewriteUtils.createUnary(input, item.getOpcode());
						operands.put(item.getId(), unary);
						break;
					}
					case Reorg: {
						operands.put(item.getId(), HopRewriteUtils.createReorg(
							operands.get(item.getInputs()[0].getId()), item.getOpcode()));
						break;
					}
					case Reshape: {
						//the first five inputs of the lineage item are passed to the reshape
						ArrayList<Hop> inputs = new ArrayList<>();
						for(int i=0; i<5; i++)
							inputs.add(operands.get(item.getInputs()[i].getId()));
						operands.put(item.getId(), HopRewriteUtils.createReorg(inputs, ReOrgOp.RESHAPE));
						break;
					}
					case Binary: {
						//handle special cases of binary operations
						//(for POW2 and MULT2 only the first opcode character is used)
						String opcode = (Opcodes.POW2.toString().equals(item.getOpcode())
							|| Opcodes.MULT2.toString().equals(item.getOpcode())) ?
							item.getOpcode().substring(0, 1) : item.getOpcode();
						Hop input1 = operands.get(item.getInputs()[0].getId());
						Hop input2 = operands.get(item.getInputs()[1].getId());
						Hop binary = HopRewriteUtils.createBinary(input1, input2, opcode);
						operands.put(item.getId(), binary);
						break;
					}
					case Ternary: {
						operands.put(item.getId(), HopRewriteUtils.createTernary(
							operands.get(item.getInputs()[0].getId()),
							operands.get(item.getInputs()[1].getId()),
							operands.get(item.getInputs()[2].getId()), item.getOpcode()));
						break;
					}
					case Ctable: { //e.g., ctable
						int numInputs = item.getInputs().length;
						if( numInputs==3 )
							operands.put(item.getId(), HopRewriteUtils.createTernary(
								operands.get(item.getInputs()[0].getId()),
								operands.get(item.getInputs()[1].getId()),
								operands.get(item.getInputs()[2].getId()), OpOp3.CTABLE));
						else if( numInputs==5 )
							operands.put(item.getId(), HopRewriteUtils.createTernary(
								operands.get(item.getInputs()[0].getId()),
								operands.get(item.getInputs()[1].getId()),
								operands.get(item.getInputs()[2].getId()),
								operands.get(item.getInputs()[3].getId()),
								operands.get(item.getInputs()[4].getId()), OpOp3.CTABLE));
						else //fail early instead of leaving this item without a hop
							throw new DMLRuntimeException("Unsupported number of ctable inputs: "
								+ numInputs + " (" + item.getOpcode() + ").");
						break;
					}
					case BuiltinNary: {
						String opcode = item.getOpcode().equals(Opcodes.NP.toString()) ? "plus" : item.getOpcode();
						//locale-independent upper-casing for the enum lookup (see datagen above)
						operands.put(item.getId(), HopRewriteUtils.createNary(
							OpOpN.valueOf(opcode.toUpperCase(java.util.Locale.ROOT)),
							createNaryInputs(item, operands)));
						break;
					}
					case ParameterizedBuiltin: {
						operands.put(item.getId(), constructParameterizedBuiltinOp(item, operands));
						break;
					}
					case MatrixIndexing: {
						operands.put(item.getId(), constructIndexingOp(item, operands));
						break;
					}
					case MMTSJ: {
						//TODO handling of tsmm type left and right -> placement transpose
						Hop input = operands.get(item.getInputs()[0].getId());
						Hop aggunary = HopRewriteUtils.createMatrixMultiply(
							HopRewriteUtils.createTranspose(input), input);
						operands.put(item.getId(), aggunary);
						break;
					}
					case Variable: {
						if( item.getOpcode().startsWith("cast") )
							operands.put(item.getId(), HopRewriteUtils.createUnary(
								operands.get(item.getInputs()[0].getId()),
								OpOp1.valueOfByOpcode(item.getOpcode())));
						else //cpvar, write
							operands.put(item.getId(), operands.get(item.getInputs()[0].getId()));
						break;
					}
					default:
						throw new DMLRuntimeException("Unsupported instruction "
							+ "type: " + ctype.name() + " (" + item.getOpcode() + ").");
				}
			}
			else if( stype != null ) {
				switch(stype) {
					case Reblock: {
						//no separate hop: flag the input hop itself as requiring a reblock
						Hop input = operands.get(item.getInputs()[0].getId());
						input.setBlocksize(ConfigurationManager.getBlocksize());
						input.setRequiresReblock(true);
						operands.put(item.getId(), input);
						break;
					}
					case Checkpoint: {
						//no separate hop: the item maps to the hop of its input
						Hop input = operands.get(item.getInputs()[0].getId());
						operands.put(item.getId(), input);
						break;
					}
					case MatrixIndexing: {
						operands.put(item.getId(), constructIndexingOp(item, operands));
						break;
					}
					case GAppend: {
						operands.put(item.getId(), HopRewriteUtils.createBinary(
							operands.get(item.getInputs()[0].getId()),
							operands.get(item.getInputs()[1].getId()), OpOp2.CBIND));
						break;
					}
					default:
						throw new DMLRuntimeException("Unsupported instruction "
							+ "type: " + stype.name() + " (" + item.getOpcode() + ").");
				}
			}
			else
				throw new DMLRuntimeException("Unsupported instruction: " + item.getOpcode());
			break;
		}
		case Literal: {
			CPOperand op = new CPOperand(item.getData());
			operands.put(item.getId(), ScalarObjectFactory
				.createLiteralOp(op.getValueType(), op.getName()));
			break;
		}
	}

	item.setVisited();
}