std::string LoopNestPrinter::GetPredicateString()

in libraries/value/src/loopnests/LoopNestPrinter.cpp [202:342]


        std::string LoopNestPrinter::GetPredicateString(const KernelPredicate& predicate, const LoopIndexSymbolTable& runtimeIndexVariables, const LoopVisitSchedule& schedule) const
        {
            if (predicate.IsAlwaysTrue())
            {
                return "true";
            }
            else if (predicate.IsAlwaysFalse())
            {
                return "false";
            }
            else if (auto fragmentPred = predicate.As<FragmentTypePredicate>(); fragmentPred != nullptr)
            {
                auto condition = fragmentPred->GetCondition();
                if (condition == Fragment::all)
                {
                    return "true";
                }

                auto index = fragmentPred->GetIndex();
                const auto& domain = schedule.GetLoopNest().GetDomain();
                const auto range = domain.GetDimensionRange(index);

                auto loopIndices = range.GetDependentLoopIndices(index);
                if (loopIndices.empty())
                {
                    loopIndices = { index };
                }
                bool first = true;
                std::string result = "";
                for (auto loopIndex : loopIndices)
                {
                    auto range = GetLoopRange(loopIndex, runtimeIndexVariables, schedule);

                    int testVal = 0;
                    bool valid = true;
                    switch (condition)
                    {
                    case Fragment::first:
                        testVal = range.Begin();
                        break;
                    case Fragment::last:
                        testVal = range.End() - (range.Size() % range.Increment());
                        if (testVal == range.End()) // not a boundary
                        {
                            testVal = range.End() - range.Increment();
                        }
                        break;
                    case Fragment::endBoundary:
                        testVal = range.End() - (range.Size() % range.Increment());
                        if (testVal == range.End()) // not a boundary
                        {
                            valid = false;
                        }
                        break;
                    default:
                        valid = false;
                        // throw?
                        break;
                    }

                    if (valid)
                    {
                        if (first)
                        {
                            result = "(";
                        }
                        else
                        {
                            result += " && ";
                        }
                        first = false;
                        result += "(" + GetIndexString(loopIndex, runtimeIndexVariables) + " == " + std::to_string(testVal) + ")";
                    }
                }
                return result.empty() ? "" : result + ")";
            }
            else if (predicate.Is<IndexDefinedPredicate>())
            {
                throw utilities::LogicException(utilities::LogicExceptionErrors::notImplemented, "IsDefined predicate not implemented");
            }
            else if (auto conjunction = predicate.As<KernelPredicateConjunction>(); conjunction != nullptr)
            {
                const auto& terms = conjunction->GetTerms();
                if (terms.size() == 0)
                {
                    return "true";
                }
                else if (terms.size() == 1)
                {
                    return GetPredicateString(*terms[0], runtimeIndexVariables, schedule);
                }
                else
                {
                    std::string result = "(";
                    bool first = true;
                    for (const auto& t : terms)
                    {
                        if (!first)
                        {
                            result += " && ";
                        }
                        first = false;
                        result += GetPredicateString(*t, runtimeIndexVariables, schedule);
                    }
                    result += ")";
                    return result;
                }
            }
            else if (auto disjunction = predicate.As<KernelPredicateDisjunction>(); disjunction != nullptr)
            {
                const auto& terms = disjunction->GetTerms();
                if (terms.size() == 0)
                {
                    return "true";
                }
                else if (terms.size() == 1)
                {
                    return GetPredicateString(*terms[0], runtimeIndexVariables, schedule);
                }
                else
                {
                    std::string result = "(";
                    bool first = true;
                    for (const auto& t : terms)
                    {
                        result += GetPredicateString(*t, runtimeIndexVariables, schedule);
                        if (!first)
                        {
                            result += " || ";
                        }
                        first = false;
                    }
                    result += ")";
                    return result;
                }
            }
            else
            {
                throw utilities::LogicException(utilities::LogicExceptionErrors::illegalState, "Unknown predicate type");
            }
        }