in libraries/value/src/loopnests/CodeGenerator.cpp [190:345]
/// Emits runtime code that evaluates a kernel predicate against the current loop indices.
///
/// The generated code writes into an int scalar named "predResult". It starts at 1 ("true"), and any
/// term of the predicate that does not hold at runtime clears it to 0 ("false").
///
/// @param predicate The predicate to evaluate. Handled forms: always-true / always-false predicates,
///        FragmentTypePredicate, and (nested) KernelPredicateConjunction / KernelPredicateDisjunction.
/// @param runtimeIndexVariables Runtime values of the loop indices available at the emission point.
///        An index with no entry is treated as 0 (see note below).
/// @param schedule Schedule whose loop nest supplies the iteration domain and the loop ranges.
/// @return An int scalar that holds 1 if the predicate is satisfied at runtime, and 0 otherwise.
/// @throws utilities::LogicException (notImplemented) if an IndexDefinedPredicate is encountered,
///         and (illegalState) for any other unrecognized predicate type.
Scalar CodeGenerator::EmitKernelPredicate(const KernelPredicate& predicate, const LoopIndexSymbolTable& runtimeIndexVariables, const LoopVisitSchedule& schedule) const
{
    const auto& domain = schedule.GetLoopNest().GetDomain();
    auto predResult = MakeScalar<int>("predResult");
    predResult = 1; // "true"

    // Recursively emits code for `p` and folds its value into `result`.
    // `defaultIsTrue` selects how `result` is accumulated:
    //   true  -> `result` starts true and a term may only clear it (AND / conjunction semantics);
    //   false -> `result` starts false and a term may only set it (OR / disjunction semantics).
    // The lambda receives itself as its first argument so that it can recurse.
    auto emitPredicate = [&domain, &runtimeIndexVariables, &schedule](const auto& emitPredicate, const KernelPredicate& p, Scalar& result, bool defaultIsTrue) -> void {
        if (p.IsAlwaysTrue())
        {
            // A constant-true term can only change an OR accumulator; an AND accumulator is unaffected.
            if (!defaultIsTrue)
            {
                result = Scalar(1); // "true"
            }
        }
        // This must be an `else if`: otherwise an always-true predicate would also fall into the
        // type dispatch below, emitting redundant code or, if it is none of the handled types,
        // hitting the "Unknown predicate type" throw.
        else if (p.IsAlwaysFalse())
        {
            // A constant-false term can only change an AND accumulator; an OR accumulator is unaffected.
            if (defaultIsTrue)
            {
                result = Scalar(0); // "false"
            }
        }
        else if (auto simplePredicate = p.As<FragmentTypePredicate>(); simplePredicate != nullptr)
        {
            const auto condition = simplePredicate->GetCondition();
            if (condition == Fragment::all)
            {
                return; // do nothing for 'all' predicates
            }

            // The predicate is stated on `index`. If the domain reports loop indices that `index`
            // depends on, each of those is tested; otherwise `index` itself is tested.
            auto index = simplePredicate->GetIndex();
            const auto dimensionRange = domain.GetDimensionRange(index);
            auto loopIndices = dimensionRange.GetDependentLoopIndices(index);
            if (loopIndices.empty())
            {
                loopIndices = { index };
            }

            for (const auto& loopIndex : loopIndices)
            {
                auto loopRange = GetLoopRange(loopIndex, runtimeIndexVariables, schedule);
                int testVal = 0;
                bool valid = true;
                switch (condition)
                {
                case Fragment::first:
                    testVal = loopRange.Begin();
                    break;
                case Fragment::last:
                    // When Size() is not a multiple of Increment(), this is the start of the
                    // trailing partial step...
                    testVal = loopRange.End() - (loopRange.Size() % loopRange.Increment());
                    if (testVal == loopRange.End()) // not a boundary
                    {
                        // ...otherwise it is the start of the last full step.
                        testVal = loopRange.End() - loopRange.Increment();
                    }
                    break;
                case Fragment::endBoundary:
                    // Only meaningful when the range ends with a partial step; otherwise emit no test.
                    testVal = loopRange.End() - (loopRange.Size() % loopRange.Increment());
                    if (testVal == loopRange.End())
                    {
                        valid = false;
                    }
                    break;
                default:
                    // NOTE(review): any other condition is silently ignored (no test is emitted);
                    // consider throwing instead.
                    valid = false;
                    break;
                }

                if (valid)
                {
                    // If the loop index has no runtime variable here, assume 0.
                    // NOTE(review): relies on MakeScalar<int> yielding a zero-initialized value — confirm.
                    Scalar indexVal = MakeScalar<int>("predIndexVal");
                    if (auto entry = runtimeIndexVariables.find(loopIndex); entry != runtimeIndexVariables.end())
                    {
                        indexVal = entry->second.value;
                    }

                    if (defaultIsTrue)
                    {
                        If(indexVal != testVal, [&] {
                            result = Scalar(0); // "false"
                        });
                    }
                    else
                    {
                        If(indexVal == testVal, [&] {
                            result = Scalar(1); // "true"
                        });
                    }
                }
            }
        }
        else if (p.Is<IndexDefinedPredicate>())
        {
            throw utilities::LogicException(utilities::LogicExceptionErrors::notImplemented, "IsDefined predicate not implemented");
        }
        else if (auto conjunction = p.As<KernelPredicateConjunction>(); conjunction != nullptr)
        {
            // Evaluate all terms into a fresh AND accumulator, then fold it into `result`.
            auto conjResult = MakeScalar<int>("conj");
            conjResult = Scalar(1); // "true"
            for (const auto& t : conjunction->GetTerms())
            {
                emitPredicate(emitPredicate, *t, conjResult, true);
            }
            if (defaultIsTrue)
            {
                If(conjResult == 0, [&result] {
                    result = Scalar(0); // "false"
                });
            }
            else
            {
                If(conjResult != 0, [&result] {
                    result = Scalar(1); // "true"
                });
            }
        }
        else if (auto disjunction = p.As<KernelPredicateDisjunction>(); disjunction != nullptr)
        {
            // Evaluate all terms into a fresh OR accumulator, then fold it into `result`.
            auto disjResult = MakeScalar<int>("disj");
            disjResult = Scalar(0); // "false"
            for (const auto& t : disjunction->GetTerms())
            {
                emitPredicate(emitPredicate, *t, disjResult, false);
            }
            if (defaultIsTrue)
            {
                If(disjResult == 0, [&result] {
                    result = Scalar(0); // "false"
                });
            }
            else
            {
                If(disjResult != 0, [&result] {
                    result = Scalar(1); // "true"
                });
            }
        }
        else
        {
            throw utilities::LogicException(utilities::LogicExceptionErrors::illegalState, "Unknown predicate type");
        }
    };

    emitPredicate(emitPredicate, predicate, predResult, true);
    return predResult;
}