Scalar CodeGenerator::EmitKernelPredicate()

in libraries/value/src/loopnests/CodeGenerator.cpp [190:345]


        // Builds an integer Scalar holding the value of `predicate`: 1 means "true", 0 means "false".
        //
        // predicate:             the kernel predicate to evaluate
        // runtimeIndexVariables: table consulted for the current value of each loop index; an index
        //                        that has no entry keeps the freshly-made "predIndexVal" scalar
        // schedule:              supplies the loop nest domain and is forwarded to GetLoopRange
        //
        // Returns the "predResult" Scalar.
        // Throws utilities::LogicException for IndexDefinedPredicate (not implemented) and for
        // predicate types that none of the branches below recognize.
        Scalar CodeGenerator::EmitKernelPredicate(const KernelPredicate& predicate, const LoopIndexSymbolTable& runtimeIndexVariables, const LoopVisitSchedule& schedule) const
        {
            const auto& domain = schedule.GetLoopNest().GetDomain();
            auto predResult = MakeScalar<int>("predResult");
            predResult = 1; // "true"

            // Recursive helper (receives itself as its first argument so it can recurse).
            //
            // `result` accumulates the value of the enclosing expression and `defaultIsTrue` tells
            // which way it accumulates:
            //   - defaultIsTrue == true:  `result` starts out "true" and may only be cleared to 0
            //                             (used for the top level and for conjunction terms)
            //   - defaultIsTrue == false: `result` starts out "false" and may only be set to 1
            //                             (used for disjunction terms)
            auto emitPredicate = [&domain, &runtimeIndexVariables, &schedule](const auto& emitPredicate, const KernelPredicate& p, Scalar& result, bool defaultIsTrue) -> void {
                if (p.IsAlwaysTrue())
                {
                    // A "true" term leaves an and-accumulator untouched and forces an or-accumulator to 1
                    if (!defaultIsTrue)
                    {
                        result = Scalar(1); // "true"
                    }
                }
                // NOTE: this must be an `else if`. As a separate `if`, an always-true predicate would
                // also run through the type tests below: compound predicates would be emitted a second
                // time, and a predicate matching none of those tests would reach the final `else` and
                // throw "Unknown predicate type".
                else if (p.IsAlwaysFalse())
                {
                    // A "false" term forces an and-accumulator to 0 and leaves an or-accumulator untouched
                    if (defaultIsTrue)
                    {
                        result = Scalar(0); // "false"
                    }
                }
                else if (auto simplePredicate = p.As<FragmentTypePredicate>(); simplePredicate != nullptr)
                {
                    auto condition = simplePredicate->GetCondition();
                    if (condition == Fragment::all)
                    {
                        return; // do nothing for 'all' predicates
                    }

                    auto index = simplePredicate->GetIndex();
                    const auto range = domain.GetDimensionRange(index);

                    // Test every loop index that depends on `index`; if there are none, test `index` itself
                    auto loopIndices = range.GetDependentLoopIndices(index);
                    if (loopIndices.empty())
                    {
                        loopIndices = { index };
                    }
                    for (auto loopIndex : loopIndices)
                    {
                        // Named `loopRange` so it doesn't shadow the dimension `range` above
                        auto loopRange = GetLoopRange(loopIndex, runtimeIndexVariables, schedule);

                        // `testVal` is the loop index value at which the fragment condition holds;
                        // `valid` is cleared when this loop has no such value (no test is emitted then)
                        int testVal = 0;
                        bool valid = true;
                        switch (condition)
                        {
                        case Fragment::first:
                            testVal = loopRange.Begin();
                            break;
                        case Fragment::last:
                            // Start of the trailing partial iteration, if there is one...
                            testVal = loopRange.End() - (loopRange.Size() % loopRange.Increment());
                            if (testVal == loopRange.End()) // not a boundary
                            {
                                // ...otherwise the start of the final full iteration
                                testVal = loopRange.End() - loopRange.Increment();
                            }
                            break;
                        case Fragment::endBoundary:
                            testVal = loopRange.End() - (loopRange.Size() % loopRange.Increment());
                            if (testVal == loopRange.End())
                            {
                                // Size is a multiple of the increment: there is no boundary iteration
                                valid = false;
                            }
                            break;
                        default:
                            // throw?
                            valid = false;
                            break;
                        }

                        if (valid)
                        {
                            // if loop index not present, assume 0
                            Scalar indexVal = MakeScalar<int>("predIndexVal");
                            if (runtimeIndexVariables.count(loopIndex) != 0)
                            {
                                indexVal = runtimeIndexVariables.at(loopIndex).value;
                            }

                            if (defaultIsTrue)
                            {
                                If(indexVal != testVal, [&] {
                                    result = Scalar(0); // "false"
                                });
                            }
                            else
                            {
                                If(indexVal == testVal, [&] {
                                    result = Scalar(1); // "true"
                                });
                            }
                        }
                    }
                }
                else if (p.Is<IndexDefinedPredicate>())
                {
                    throw utilities::LogicException(utilities::LogicExceptionErrors::notImplemented, "IsDefined predicate not implemented");
                }
                else if (auto conjunction = p.As<KernelPredicateConjunction>(); conjunction != nullptr)
                {
                    // Evaluate the terms into a local and-accumulator, then fold it into `result`
                    auto conjResult = MakeScalar<int>("conj");
                    conjResult = Scalar(1); // "true"
                    for (const auto& t : conjunction->GetTerms())
                    {
                        emitPredicate(emitPredicate, *t, conjResult, true);
                    }

                    if (defaultIsTrue)
                    {
                        If(conjResult == 0, [&result] {
                            result = Scalar(0); // "false"
                        });
                    }
                    else
                    {
                        If(conjResult != 0, [&result] {
                            result = Scalar(1); // "true"
                        });
                    }
                }
                else if (auto disjunction = p.As<KernelPredicateDisjunction>(); disjunction != nullptr)
                {
                    // Evaluate the terms into a local or-accumulator, then fold it into `result`
                    auto disjResult = MakeScalar<int>("disj");
                    disjResult = Scalar(0); // "false"
                    for (const auto& t : disjunction->GetTerms())
                    {
                        emitPredicate(emitPredicate, *t, disjResult, false);
                    }

                    if (defaultIsTrue)
                    {
                        If(disjResult == 0, [&result] {
                            result = Scalar(0); // "false"
                        });
                    }
                    else
                    {
                        If(disjResult != 0, [&result] {
                            result = Scalar(1); // "true"
                        });
                    }
                }
                else
                {
                    throw utilities::LogicException(utilities::LogicExceptionErrors::illegalState, "Unknown predicate type");
                }
            };

            emitPredicate(emitPredicate, predicate, predResult, true);
            return predResult;
        }