in libraries/value/src/loopnests/LoopNestPrinter.cpp [202:342]
std::string LoopNestPrinter::GetPredicateString(const KernelPredicate& predicate, const LoopIndexSymbolTable& runtimeIndexVariables, const LoopVisitSchedule& schedule) const
{
if (predicate.IsAlwaysTrue())
{
return "true";
}
else if (predicate.IsAlwaysFalse())
{
return "false";
}
else if (auto fragmentPred = predicate.As<FragmentTypePredicate>(); fragmentPred != nullptr)
{
auto condition = fragmentPred->GetCondition();
if (condition == Fragment::all)
{
return "true";
}
auto index = fragmentPred->GetIndex();
const auto& domain = schedule.GetLoopNest().GetDomain();
const auto range = domain.GetDimensionRange(index);
auto loopIndices = range.GetDependentLoopIndices(index);
if (loopIndices.empty())
{
loopIndices = { index };
}
bool first = true;
std::string result = "";
for (auto loopIndex : loopIndices)
{
auto range = GetLoopRange(loopIndex, runtimeIndexVariables, schedule);
int testVal = 0;
bool valid = true;
switch (condition)
{
case Fragment::first:
testVal = range.Begin();
break;
case Fragment::last:
testVal = range.End() - (range.Size() % range.Increment());
if (testVal == range.End()) // not a boundary
{
testVal = range.End() - range.Increment();
}
break;
case Fragment::endBoundary:
testVal = range.End() - (range.Size() % range.Increment());
if (testVal == range.End()) // not a boundary
{
valid = false;
}
break;
default:
valid = false;
// throw?
break;
}
if (valid)
{
if (first)
{
result = "(";
}
else
{
result += " && ";
}
first = false;
result += "(" + GetIndexString(loopIndex, runtimeIndexVariables) + " == " + std::to_string(testVal) + ")";
}
}
return result.empty() ? "" : result + ")";
}
else if (predicate.Is<IndexDefinedPredicate>())
{
throw utilities::LogicException(utilities::LogicExceptionErrors::notImplemented, "IsDefined predicate not implemented");
}
else if (auto conjunction = predicate.As<KernelPredicateConjunction>(); conjunction != nullptr)
{
const auto& terms = conjunction->GetTerms();
if (terms.size() == 0)
{
return "true";
}
else if (terms.size() == 1)
{
return GetPredicateString(*terms[0], runtimeIndexVariables, schedule);
}
else
{
std::string result = "(";
bool first = true;
for (const auto& t : terms)
{
if (!first)
{
result += " && ";
}
first = false;
result += GetPredicateString(*t, runtimeIndexVariables, schedule);
}
result += ")";
return result;
}
}
else if (auto disjunction = predicate.As<KernelPredicateDisjunction>(); disjunction != nullptr)
{
const auto& terms = disjunction->GetTerms();
if (terms.size() == 0)
{
return "true";
}
else if (terms.size() == 1)
{
return GetPredicateString(*terms[0], runtimeIndexVariables, schedule);
}
else
{
std::string result = "(";
bool first = true;
for (const auto& t : terms)
{
result += GetPredicateString(*t, runtimeIndexVariables, schedule);
if (!first)
{
result += " || ";
}
first = false;
}
result += ")";
return result;
}
}
else
{
throw utilities::LogicException(utilities::LogicExceptionErrors::illegalState, "Unknown predicate type");
}
}