bool LoopNestVisitor::IsPlacementValid()

in libraries/value/src/loopnests/LoopNestVisitor.cpp [727:837]


        bool LoopNestVisitor::IsPlacementValid(const ScheduledKernel& kernel, const LoopIndexSymbolTable& runtimeLoopIndices, const LoopVisitSchedule& schedule) const
        {
            const auto& domain = schedule.GetDomain();
            if (kernel.placement.IsEmpty() || IsBodyPlacementPredicate(kernel.placement))
            {
                // TODO: put this in a function that preprocesses the kernel predicates when adding the kernels to the schedule
                for (const auto& kernelIndex : kernel.kernel.GetIndices())
                {
                    for (const auto& loopIndex : domain.GetDependentLoopIndices(kernelIndex, true))
                    {
                        // if not defined(loopIndex) return false;
                        if (runtimeLoopIndices.count(loopIndex) == 0 || runtimeLoopIndices.at(loopIndex).state == LoopIndexState::done)
                        {
                            return false;
                        }
                    }
                }

                if (kernel.placement.IsEmpty())
                {
                    return true;
                }
            }

            auto evalPlacement = [&](const auto& evalPlacement, const KernelPredicate& p) -> bool {
                if (p.IsAlwaysTrue())
                {
                    return true;
                }
                else if (p.Is<FragmentTypePredicate>())
                {
                    throw utilities::InputException(utilities::InputExceptionErrors::invalidArgument, "Fragment predicates not valid for placement");
                }
                else if (auto placementPred = p.As<PlacementPredicate>(); placementPred != nullptr)
                {
                    if (schedule.IsInnermostLoop())
                    {
                        return !placementPred->HasIndex();
                    }

                    auto nextLoopIndex = schedule.Next().CurrentLoopIndex();
                    auto where = placementPred->GetPlacement();

                    std::vector<Index> dependentLoopIndices;
                    if (placementPred->HasIndex())
                    {
                        auto testIndex = placementPred->GetIndex();

                        // get list of dependent indices
                        dependentLoopIndices = domain.GetDependentLoopIndices(testIndex, true);

                        // First check that we're not already inside any dependent loops
                        for (const auto& i : dependentLoopIndices)
                        {
                            if (runtimeLoopIndices.count(i) != 0 && runtimeLoopIndices.at(i).state == LoopIndexState::inProgress)
                            {
                                return false;
                            }
                        }
                    }
                    else
                    {
                        dependentLoopIndices = { nextLoopIndex };
                    }

                    // Now check that the next loop at least partially defines the index in question
                    if (std::find(dependentLoopIndices.begin(), dependentLoopIndices.end(), nextLoopIndex) != dependentLoopIndices.end())
                    {
                        // Finally, check that we're in the correct position (before vs. after)
                        if (where == Placement::before)
                        {
                            return (runtimeLoopIndices.count(nextLoopIndex) == 0 || runtimeLoopIndices.at(nextLoopIndex).state == LoopIndexState::notVisited);
                        }
                        else // (where == Placement::after)
                        {
                            return (runtimeLoopIndices.count(nextLoopIndex) != 0 && runtimeLoopIndices.at(nextLoopIndex).state == LoopIndexState::done);
                        }
                    }
                    return false;
                }
                else if (auto definedPred = p.As<IndexDefinedPredicate>(); definedPred != nullptr)
                {
                    auto definedIndex = definedPred->GetIndex();
                    return (runtimeLoopIndices.count(definedIndex) > 0) && (runtimeLoopIndices.at(definedIndex).state != LoopIndexState::done);
                }
                else if (auto conjunction = p.As<KernelPredicateConjunction>(); conjunction != nullptr)
                {
                    bool result = true;
                    for (const auto& t : conjunction->GetTerms())
                    {
                        result &= evalPlacement(evalPlacement, *t);
                    }
                    return result;
                }
                else if (auto disjunction = p.As<KernelPredicateDisjunction>(); disjunction != nullptr)
                {
                    bool result = false;
                    for (const auto& t : disjunction->GetTerms())
                    {
                        result |= evalPlacement(evalPlacement, *t);
                    }
                    return result;
                }
                else
                {
                    return false;
                }
            };

            return evalPlacement(evalPlacement, kernel.placement);
        }