in src/main/java/org/apache/sysds/parser/IfStatementBlock.java [40:283]
/**
 * Validates this if statement block and computes the variables and constants
 * that are live after it.
 *
 * The predicate is validated first (with constant propagation into a plain
 * variable predicate when not conditional). The if and else bodies are then
 * validated on separate copies of the variable and constant sets. Conditional
 * data type changes of updated variables are rejected. Finally, both branches
 * are reconciled: constants survive only if untouched or equal in both
 * branches, and dimensions survive only if they agree.
 *
 * @param dmlProg     the enclosing DML program
 * @param ids         variables defined before this block
 * @param constVars   constant variables known before this block
 * @param conditional whether this block itself executes conditionally; forwarded
 *                    to predicate validation and {@code raiseValidateError}
 * @return the reconciled variable set after the if/else block
 */
public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String,ConstIdentifier> constVars, boolean conditional)
{
	if (_statements.size() > 1){
		raiseValidateError("IfStatementBlock should only have 1 statement (IfStatement)", conditional);
	}
	IfStatement ifstmt = (IfStatement) _statements.get(0);

	//validate conditional predicate (incl constant propagation)
	Expression pred = ifstmt.getConditionalPredicate().getPredicate();
	pred.validateExpression(ids.getVariables(), constVars, conditional);
	if( pred instanceof DataIdentifier && constVars.containsKey( ((DataIdentifier)pred).getName()) && !conditional ) {
		ifstmt.getConditionalPredicate().setPredicate(constVars.get(((DataIdentifier)pred).getName()));
	}

	// each branch is validated on its own copy so that neither sees the other's updates
	HashMap<String,ConstIdentifier> constVarsIfCopy = new HashMap<>(constVars);
	HashMap<String,ConstIdentifier> constVarsElseCopy = new HashMap<>(constVars);
	VariableSet idsIfCopy = new VariableSet(ids);
	VariableSet idsElseCopy = new VariableSet(ids);
	VariableSet idsOrigCopy = new VariableSet(ids);

	// handle if stmt body (conditional exec)
	_dmlProg = dmlProg;
	ArrayList<StatementBlock> ifBody = ifstmt.getIfBody();
	for(StatementBlock sb : ifBody){
		idsIfCopy = sb.validate(dmlProg, idsIfCopy, constVarsIfCopy, true);
		constVarsIfCopy = sb.getConstOut();
	}

	// handle else stmt body (conditional exec)
	ArrayList<StatementBlock> elseBody = ifstmt.getElseBody();
	for(StatementBlock sb : elseBody){
		idsElseCopy = sb.validate(dmlProg, idsElseCopy, constVarsElseCopy, true);
		constVarsElseCopy = sb.getConstOut();
	}

	/////////////////////////////////////////////////////////////////////////////////
	// check data type and value type are same for updated variables in both
	// if statement and else statement
	// (reject conditional data type change)
	/////////////////////////////////////////////////////////////////////////////////
	for (String updatedVar : _updated.getVariableNames()){
		DataIdentifier origVersion = idsOrigCopy.getVariable(updatedVar);
		DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
		DataIdentifier elseVersion = idsElseCopy.getVariable(updatedVar);

		//data type handling: reject conditional data type change
		if( ifVersion != null && elseVersion != null ) //both branches exist
		{
			if (!ifVersion.getOutput().getDataType().equals(elseVersion.getOutput().getDataType())){
				raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in if/else branch.", conditional);
			}
		}
		else if( origVersion != null && ifVersion != null ) //only if branch exists
		{
			if (!ifVersion.getOutput().getDataType().equals(origVersion.getOutput().getDataType())){
				raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in if branch.", conditional);
			}
		}
		else if( origVersion != null && elseVersion != null ) //only else branch exists
		{
			// previously dereferenced the (null) ifVersion here and failed with an NPE
			if (!elseVersion.getOutput().getDataType().equals(origVersion.getOutput().getDataType())){
				raiseValidateError("IfStatementBlock has unsupported conditional data type change of variable '"+updatedVar+"' in else branch.", conditional);
			}
		}

		//value type handling (only warn, value type differences are tolerated)
		if (ifVersion != null && elseVersion != null && !ifVersion.getOutput().getValueType().equals(elseVersion.getOutput().getValueType())){
			LOG.warn(elseVersion.printWarningLocation() + "Variable " + elseVersion.getName() + " defined with different value type in if and else clause.");
		}
	}

	//////////////////////////////////////////////////////////////////////////////////
	// handle constant variables
	// 1) (IF UNION ELSE) MINUS updated const vars
	// 2) reconcile updated const vars
	//    a) IF an updated const variable has the same data type, value type, and
	//       value in both if / else branch, THEN propagate that common constant
	//    b) ELSE leave it out of the reconciled set
	/////////////////////////////////////////////////////////////////////////////////
	HashMap<String,ConstIdentifier> recConstVars = new HashMap<>();

	// STEP 1: (IF UNION ELSE) MINUS updated vars
	for (Entry<String,ConstIdentifier> e : constVarsIfCopy.entrySet() ){
		if (!_updated.containsVariable(e.getKey()))
			recConstVars.put(e.getKey(), e.getValue());
	}
	for (Entry<String,ConstIdentifier> e : constVarsElseCopy.entrySet() ){
		if (!_updated.containsVariable(e.getKey()))
			recConstVars.put(e.getKey(), e.getValue());
	}

	// STEP 2: keep updated constants only if both branches agree on
	//   a) data type (SCALAR), b) value type, and c) value
	for (String updatedVar : _updated.getVariableNames()){
		DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
		DataIdentifier elseVersion = idsElseCopy.getVariable(updatedVar);
		if (ifVersion != null && elseVersion != null
			&& ifVersion.getOutput().getDataType().equals(DataType.SCALAR)
			&& elseVersion.getOutput().getDataType().equals(DataType.SCALAR)
			&& ifVersion.getOutput().getValueType().equals(elseVersion.getOutput().getValueType()))
		{
			ConstIdentifier ifConstVersion = constVarsIfCopy.get(updatedVar);
			ConstIdentifier elseConstVersion = constVarsElseCopy.get(updatedVar);
			if( isSameConstant(ifConstVersion, elseConstVersion) )
				recConstVars.put(updatedVar, ifConstVersion);
		}
	}

	//////////////////////////////////////////////////////////////////////////////////
	// handle DataIdentifier variables
	// 1) (IF UNION ELSE) MINUS updated vars
	// 2) reconcile size updated variables
	//    a) IF updated variables have same size in both if / else branch, THEN keep that size
	//    b) ELSE set size to unknown (-1,-1)
	// 3) add updated vars to reconciled set
	/////////////////////////////////////////////////////////////////////////////////

	// STEP 1: (IF UNION ELSE) MINUS updated vars
	VariableSet recVars = new VariableSet();
	for (String varName : idsIfCopy.getVariableNames()){
		if (!_updated.containsVariable(varName))
			recVars.addVariable(varName, idsIfCopy.getVariable(varName));
	}
	for (String varName : idsElseCopy.getVariableNames()){
		if (!_updated.containsVariable(varName))
			recVars.addVariable(varName, idsElseCopy.getVariable(varName));
	}

	// STEP 2: reconcile size of updated variables
	for (String updatedVar : _updated.getVariableNames()){
		DataIdentifier ifVersion = idsIfCopy.getVariable(updatedVar);
		DataIdentifier elseVersion = idsElseCopy.getVariable(updatedVar);
		DataIdentifier origVersion = idsOrigCopy.getVariable(updatedVar);

		if (ifVersion != null && elseVersion != null) {
			long ifVersionDim1 = getUnindexedDim1(ifVersion);
			long elseVersionDim1 = getUnindexedDim1(elseVersion);
			long ifVersionDim2 = getUnindexedDim2(ifVersion);
			long elseVersionDim2 = getUnindexedDim2(elseVersion);
			long updatedDim1 = (ifVersionDim1 == elseVersionDim1) ? ifVersionDim1 : -1;
			long updatedDim2 = (ifVersionDim2 == elseVersionDim2) ? ifVersionDim2 : -1;

			// add reconciled version (deep copy of ifVersion, cast as DataIdentifier)
			DataIdentifier recVersion = new DataIdentifier(ifVersion);
			recVersion.setDimensions(updatedDim1, updatedDim2);
			//NOTE: nnz not propagated via validate, and hence, we conservatively assume that nnz have been changed.
			recVersion.setNnz(-1);
			recVars.addVariable(updatedVar, recVersion);
		}
		else {
			DataIdentifier recVersion;
			if (ifVersion != null){
				// CASE: defined only in if branch (deep copy of ifVersion)
				recVersion = new DataIdentifier(ifVersion);
			}
			else if (elseVersion != null){
				// CASE: defined only in else branch (deep copy of elseVersion)
				recVersion = new DataIdentifier(elseVersion);
			}
			else {
				// CASE: updated, but in neither the if nor the else branch output
				// (deep copy of the variable as recorded in the updated set)
				recVersion = new DataIdentifier(_updated.getVariable(updatedVar));
			}
			recVars.addVariable(updatedVar, recVersion);

			// keep a dimension only if the single defining branch agrees with the original
			long updatedDim1 = -1, updatedDim2 = -1;
			if( origVersion != null ) {
				long origVersionDim1 = getUnindexedDim1(origVersion);
				long origVersionDim2 = getUnindexedDim2(origVersion);
				if( origVersionDim1 == recVersion.getDim1() ) //recVersion is always a plain DataIdentifier
					updatedDim1 = origVersionDim1;
				if( origVersionDim2 == recVersion.getDim2() )
					updatedDim2 = origVersionDim2;
			}
			recVersion.setDimensions(updatedDim1, updatedDim2);
			//NOTE: nnz not propagated via validate, and hence, we conservatively assume that nnz have been changed.
			recVersion.setNnz(-1);
		}
	}

	// propagate updated variables
	VariableSet allIdVars = new VariableSet(recVars);
	_constVarsIn.putAll(constVars);
	_constVarsOut.putAll(recConstVars);
	return allIdVars;
}

/**
 * Returns true if both constants are non-null, of the same literal kind
 * (int, double, boolean, string) and hold the same value.
 * Doubles are compared with {@link Double#compare} so that 0.0 and -0.0 are
 * treated as different and NaN equals NaN.
 */
private static boolean isSameConstant(ConstIdentifier c1, ConstIdentifier c2) {
	if( c1 == null || c2 == null )
		return false;
	if( c1 instanceof IntIdentifier && c2 instanceof IntIdentifier )
		return ((IntIdentifier)c1).getValue() == ((IntIdentifier)c2).getValue();
	if( c1 instanceof DoubleIdentifier && c2 instanceof DoubleIdentifier )
		return Double.compare(((DoubleIdentifier)c1).getValue(), ((DoubleIdentifier)c2).getValue()) == 0;
	if( c1 instanceof BooleanIdentifier && c2 instanceof BooleanIdentifier )
		return ((BooleanIdentifier)c1).getValue() == ((BooleanIdentifier)c2).getValue();
	if( c1 instanceof StringIdentifier && c2 instanceof StringIdentifier )
		return ((StringIdentifier)c1).getValue().equals(((StringIdentifier)c2).getValue());
	return false;
}

/** Returns the row dimension, using the original (unindexed) dimension for indexed identifiers. */
private static long getUnindexedDim1(DataIdentifier di) {
	return (di instanceof IndexedIdentifier) ? ((IndexedIdentifier)di).getOrigDim1() : di.getDim1();
}

/** Returns the column dimension, using the original (unindexed) dimension for indexed identifiers. */
private static long getUnindexedDim2(DataIdentifier di) {
	return (di instanceof IndexedIdentifier) ? ((IndexedIdentifier)di).getOrigDim2() : di.getDim2();
}