in src/main/java/org/apache/sysds/parser/DMLTranslator.java [957:1316]
/**
 * Constructs the HOPs for the given statement block and attaches them to the block.
 * <p>
 * While, if, for (incl. parfor) and function statement blocks are delegated to their
 * dedicated {@code constructHopsFor...ControlBlock} methods. For any other block the
 * statements are translated in order:
 * <ul>
 * <li>a transient read is created for every live-in variable that is also contained in
 * the block's updated or gen set,</li>
 * <li>output, print/assert/stop/printf, assignment and multi-assignment statements are
 * translated into hops, and</li>
 * <li>a transient write is created for every live-out variable at the last statement of
 * the block that assigns it via a regular (non function call) assignment.</li>
 * </ul>
 * Finally, the live-out variables of the block are updated with the written targets and
 * the collected root hops are set on the block.
 *
 * @param sb statement block to translate; modified in place (hops, live-out variables
 *           and, for stop and time(), the split-dag flag)
 * @throws LanguageException on indexed write targets, undefined functions, function calls
 *         on left-indexing targets, unrecognized file formats, unsupported sources of
 *         multi-assignments, or if hop construction of a print statement fails
 */
public void constructHops(StatementBlock sb) {
	// control-flow blocks are handled by their dedicated construction methods
	if (sb instanceof WhileStatementBlock) {
		constructHopsForWhileControlBlock((WhileStatementBlock) sb);
		return;
	}
	if (sb instanceof IfStatementBlock) {
		constructHopsForIfControlBlock((IfStatementBlock) sb);
		return;
	}
	if (sb instanceof ForStatementBlock) { //incl ParForStatementBlock
		constructHopsForForControlBlock((ForStatementBlock) sb);
		return;
	}
	if (sb instanceof FunctionStatementBlock) {
		constructHopsForFunctionControlBlock((FunctionStatementBlock) sb);
		return;
	}

	// ids maps a variable name to the hop that currently represents its value;
	// output collects the root hops of this block
	HashMap<String, Hop> ids = new HashMap<>();
	ArrayList<Hop> output = new ArrayList<>();

	VariableSet liveIn = sb.liveIn();
	VariableSet liveOut = sb.liveOut();
	VariableSet updated = sb._updated;
	VariableSet gen = sb._gen;
	VariableSet updatedLiveOut = new VariableSet();

	// handle liveout variables that are updated --> target identifiers for Assignment
	// (later statements overwrite earlier entries, so each variable maps to the index
	// of the LAST statement in this block that assigns it)
	HashMap<String, Integer> liveOutToTemp = new HashMap<>();
	for (int i = 0; i < sb.getNumStatements(); i++) {
		Statement current = sb.getStatement(i);
		if (current instanceof AssignmentStatement) {
			DataIdentifier target = ((AssignmentStatement) current).getTarget();
			if (target != null && liveOut.containsVariable(target.getName()))
				liveOutToTemp.put(target.getName(), i);
		}
		if (current instanceof MultiAssignmentStatement) {
			for (DataIdentifier target : ((MultiAssignmentStatement) current).getTargetList()) {
				if (liveOut.containsVariable(target.getName()))
					liveOutToTemp.put(target.getName(), i);
			}
		}
	}

	// only create transient read operations for variables either updated or read-before-update
	// (i.e., from LV analysis, updated and gen sets)
	for (String varName : liveIn.getVariables().keySet()) {
		if (updated.containsVariable(varName) || gen.containsVariable(varName)) {
			DataIdentifier liveVar = liveIn.getVariables().get(varName);
			// indexed identifiers contribute their original dimensions instead of getDim1/getDim2
			long actualDim1 = (liveVar instanceof IndexedIdentifier) ?
				((IndexedIdentifier) liveVar).getOrigDim1() : liveVar.getDim1();
			long actualDim2 = (liveVar instanceof IndexedIdentifier) ?
				((IndexedIdentifier) liveVar).getOrigDim2() : liveVar.getDim2();
			DataOp read = new DataOp(liveVar.getName(), liveVar.getDataType(), liveVar.getValueType(),
				OpOpData.TRANSIENTREAD, null, actualDim1, actualDim2, liveVar.getNnz(), liveVar.getBlocksize());
			read.setParseInfo(liveVar);
			ids.put(varName, read);
		}
	}

	for (int i = 0; i < sb.getNumStatements(); i++) {
		Statement current = sb.getStatement(i);

		if (current instanceof OutputStatement) {
			OutputStatement os = (OutputStatement) current;
			DataExpression source = os.getSource();
			DataIdentifier target = os.getIdentifier();

			//error handling unsupported indexing expression in write statement
			if (target instanceof IndexedIdentifier) {
				throw new LanguageException(source.printErrorLocation()+": Unsupported indexing expression in write statement. " +
					"Please, assign the right indexing result to a variable and write this variable.");
			}

			DataOp ae = (DataOp) processExpression(source, target, ids);
			// a format given as anything but a string identifier is treated as UNKNOWN
			Expression fmtExpr = os.getExprParam(DataExpression.FORMAT_TYPE);
			ae.setFileFormat((fmtExpr instanceof StringIdentifier) ?
				Expression.convertFormatType(fmtExpr.toString()) : FileFormat.UNKNOWN);

			if (ae.getDataType() == DataType.SCALAR) {
				ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1);
			}
			else {
				switch (ae.getFileFormat()) {
					case TEXT:
					case MM:
					case CSV:
					case LIBSVM:
					case HDF5:
						// write output in textcell format
						ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1);
						break;
					case BINARY:
					case COMPRESSED:
					case UNKNOWN:
						// write output in binary block format
						ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getBlocksize());
						break;
					case FEDERATED:
						// nnz and blocksize are both passed as -1 for federated outputs
						ae.setOutputParams(ae.getDim1(), ae.getDim2(), -1, ae.getUpdateType(), -1);
						break;
					default:
						throw new LanguageException("Unrecognized file format: " + ae.getFileFormat());
				}
			}
			output.add(ae);
		}

		if (current instanceof PrintStatement) {
			// all print-like statements produce a scalar string target
			DataIdentifier target = createTarget();
			target.setDataType(DataType.SCALAR);
			target.setValueType(ValueType.STRING);
			target.setParseInfo(current);

			PrintStatement ps = (PrintStatement) current;
			PRINTTYPE ptype = ps.getType();
			try {
				if (ptype == PRINTTYPE.PRINT || ptype == PRINTTYPE.ASSERT || ptype == PRINTTYPE.STOP) {
					// print, assert and stop are unary ops over their first expression
					// and differ only in the op code (and the dag split for stop)
					OpOp1 op = (ptype == PRINTTYPE.PRINT) ? OpOp1.PRINT :
						(ptype == PRINTTYPE.ASSERT) ? OpOp1.ASSERT : OpOp1.STOP;
					Expression source = ps.getExpressions().get(0);
					Hop ae = processExpression(source, target, ids);
					Hop unaryHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
					unaryHop.setParseInfo(current);
					output.add(unaryHop);
					if (ptype == PRINTTYPE.STOP)
						sb.setSplitDag(true); //avoid merge
				}
				else if (ptype == PRINTTYPE.PRINTF) {
					List<Expression> expressions = ps.getExpressions();
					Hop[] inHops = new Hop[expressions.size()];
					// process the expressions (function parameters) that
					// make up the printf-styled print statement
					// into Hops so that these can be passed to the printf
					// Hop (ie, MultipleOp) as input Hops
					for (int j = 0; j < expressions.size(); j++)
						inHops[j] = processExpression(expressions.get(j), target, ids);
					// NOTE(review): the value type is set again here although it was set above;
					// presumably processExpression may modify the passed target - confirm before removing
					target.setValueType(ValueType.STRING);
					Hop printfHop = new NaryOp(target.getName(), target.getDataType(),
						target.getValueType(), OpOpN.PRINTF, inHops);
					// attach the statement's parse info, consistent with print/assert/stop
					printfHop.setParseInfo(current);
					output.add(printfHop);
				}
			}
			catch (HopsException e) {
				throw new LanguageException(e);
			}
		}

		if (current instanceof AssignmentStatement) {
			AssignmentStatement as = (AssignmentStatement) current;
			DataIdentifier target = as.getTarget();
			Expression source = as.getSource();

			// CASE: regular assignment statement -- source is DML expression that is NOT user-defined or external function
			if (!(source instanceof FunctionCallIdentifier)) {
				// CASE: target is regular data identifier
				if (!(target instanceof IndexedIdentifier)) {
					//process right hand side and accumulation
					Hop ae = processExpression(source, target, ids);
					if (as.isAccumulator()) {
						// accumulation: add the right-hand side to the hop currently bound to the target name
						DataIdentifier accum = getAccumulatorData(liveIn, target.getName());
						ae = HopRewriteUtils.createBinary(ids.get(target.getName()), ae, OpOp2.PLUS);
						target.setProperties(accum.getOutput());
					}
					else
						target.setProperties(source.getOutput());

					// assignments from the time() builtin force a dag split
					if (source instanceof BuiltinFunctionExpression
						&& ((BuiltinFunctionExpression) source).getOpCode() == Builtins.TIME)
						sb.setSplitDag(true);

					ids.put(target.getName(), ae);

					//add transient write if needed
					// (only at the last statement of the block that assigns this live-out variable)
					Integer statementId = liveOutToTemp.get(target.getName());
					if (statementId != null && statementId == i) {
						DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, OpOpData.TRANSIENTWRITE, null);
						transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getBlocksize());
						transientwrite.setParseInfo(target);
						updatedLiveOut.addVariable(target.getName(), target);
						output.add(transientwrite);
					}
				}
				// CASE: target is indexed identifier (left-hand side indexed expression)
				else {
					Hop ae = processLeftIndexedExpression(source, (IndexedIdentifier) target, ids);
					if (as.isAccumulator()) {
						// accumulation: replace input 1 of the left-indexing hop by
						// (right-indexed current value of the target + right-hand side)
						DataIdentifier accum = getAccumulatorData(liveIn, target.getName());
						Hop rix = processIndexingExpression((IndexedIdentifier) target, null, ids);
						Hop rhs = processExpression(source, null, ids);
						Hop binary = HopRewriteUtils.createBinary(rix, rhs, OpOp2.PLUS);
						HopRewriteUtils.replaceChildReference(ae, ae.getInput(1), binary);
						target.setProperties(accum.getOutput());
					}
					ids.put(target.getName(), ae);

					// obtain origDim values BEFORE they are potentially updated during setProperties call
					// (this is incorrect for LHS Indexing)
					long origDim1 = ((IndexedIdentifier) target).getOrigDim1();
					long origDim2 = ((IndexedIdentifier) target).getOrigDim2();
					target.setProperties(source.getOutput());
					((IndexedIdentifier) target).setOriginalDimensions(origDim1, origDim2);

					// preserve data type matrix of any index identifier
					// (required for scalar input to left indexing)
					if (target.getDataType() != DataType.MATRIX) {
						target.setDataType(DataType.MATRIX);
						target.setValueType(ValueType.FP64);
						target.setBlocksize(ConfigurationManager.getBlocksize());
					}

					// transient write uses the original dimensions captured above
					Integer statementId = liveOutToTemp.get(target.getName());
					if (statementId != null && statementId == i) {
						DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, OpOpData.TRANSIENTWRITE, null);
						transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateType(), ae.getBlocksize());
						transientwrite.setParseInfo(target);
						updatedLiveOut.addVariable(target.getName(), target);
						output.add(transientwrite);
					}
				}
			}
			else {
				//assignment, function call
				FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
				FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(), fci.getName());

				//error handling missing function
				if (fsb == null) {
					throw new LanguageException(source.printErrorLocation() + "function "
						+ fci.getName() + " is undefined in namespace " + fci.getNamespace());
				}
				FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
				String fkey = DMLProgram.constructFunctionKey(fci.getNamespace(), fci.getName());

				//error handling unsupported function call in indexing expression
				if (target instanceof IndexedIdentifier) {
					throw new LanguageException("Unsupported function call to '"+fkey+"' in left indexing "
						+ "expression. Please, assign the function output to a variable.");
				}

				//prepare function input names and inputs
				List<String> inputNames = new ArrayList<>(fci.getParamExprs().stream()
					.map(e -> e.getName()).collect(Collectors.toList()));
				List<Hop> finputs = new ArrayList<>(fci.getParamExprs().stream()
					.map(e -> processExpression(e.getExpr(), null, ids)).collect(Collectors.toList()));

				//append default expression for missing arguments
				// NOTE(review): here defaults are appended BEFORE the names are taken from the
				// signature, whereas the multi-assignment path below does it in the opposite
				// order - confirm which order is intended
				appendDefaultArguments(fstmt, inputNames, finputs, ids);

				//use function signature to obtain names for unnamed args
				//(note: consistent parameters already checked for functions in general)
				if (inputNames.stream().allMatch(n -> n == null))
					inputNames = fstmt._inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());

				//create function op (without outputs if the call has no assignment target)
				String[] inputNames2 = inputNames.toArray(new String[0]);
				String[] outputNames = (target == null) ? new String[]{} : new String[]{target.getName()};
				FunctionType ftype = fsb.getFunctionOpType();
				FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(),
					inputNames2, finputs, outputNames, false);
				fcall.setParseInfo(fci);
				output.add(fcall);
			}
		}
		else if (current instanceof MultiAssignmentStatement) {
			//multi-assignment, by definition a function call
			MultiAssignmentStatement mas = (MultiAssignmentStatement) current;
			Expression source = mas.getSource();

			if (source instanceof FunctionCallIdentifier) {
				FunctionCallIdentifier fci = (FunctionCallIdentifier) source;
				FunctionStatementBlock fsb = this._dmlProg.getFunctionStatementBlock(fci.getNamespace(), fci.getName());
				if (fsb == null) {
					throw new LanguageException(source.printErrorLocation() + "function "
						+ fci.getName() + " is undefined in namespace " + fci.getNamespace());
				}
				FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);

				//prepare function input names and inputs
				List<String> inputNames = new ArrayList<>(fci.getParamExprs().stream()
					.map(e -> e.getName()).collect(Collectors.toList()));
				List<Hop> finputs = new ArrayList<>(fci.getParamExprs().stream()
					.map(e -> processExpression(e.getExpr(), null, ids)).collect(Collectors.toList()));

				//use function signature to obtain names for unnamed args
				//(note: consistent parameters already checked for functions in general)
				if (inputNames.stream().allMatch(n -> n == null))
					inputNames = fstmt._inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());

				//append default expression for missing arguments
				appendDefaultArguments(fstmt, inputNames, finputs, ids);

				//create function op with one output per assignment target
				String[] foutputs = mas.getTargetList().stream()
					.map(d -> d.getName()).toArray(String[]::new);
				FunctionType ftype = fsb.getFunctionOpType();
				FunctionOp fcall = new FunctionOp(ftype, fci.getNamespace(), fci.getName(),
					inputNames.toArray(new String[0]), finputs, foutputs, false);
				fcall.setParseInfo(fci);
				output.add(fcall);
			}
			else if (source instanceof BuiltinFunctionExpression
				&& ((BuiltinFunctionExpression) source).multipleReturns()) {
				// construct input hops
				Hop fcall = processMultipleReturnBuiltinFunctionExpression(
					(BuiltinFunctionExpression) source, mas.getTargetList(), ids);
				output.add(fcall);
			}
			else if (source instanceof ParameterizedBuiltinFunctionExpression
				&& ((ParameterizedBuiltinFunctionExpression) source).multipleReturns()) {
				// construct input hops
				Hop fcall = processMultipleReturnParameterizedBuiltinFunctionExpression(
					(ParameterizedBuiltinFunctionExpression) source, mas.getTargetList(), ids);
				output.add(fcall);
			}
			else
				throw new LanguageException("Class \"" + source.getClass() + "\" is not supported in Multiple Assignment statements");
		}
	}

	// publish the written live-out targets and the root hops on the statement block
	sb.updateLiveVariablesOut(updatedLiveOut);
	sb.setHops(output);
}