in libraries/value/src/loopnests/LoopNestVisitor.cpp [727:837]
// Decides whether `kernel` may be emitted at the current point of `schedule`,
// given the runtime state of every loop index recorded in `runtimeLoopIndices`.
//
// - If the kernel has no placement predicate, or a body placement predicate, every
//   loop index that the kernel's own indices depend on must be present in
//   `runtimeLoopIndices` and not yet `done`. A kernel with no placement predicate is
//   then valid outright.
// - Otherwise the placement predicate is evaluated recursively:
//     * always-true predicates hold;
//     * fragment-type predicates throw utilities::InputException (invalidArgument);
//     * placement predicates are checked against the next loop in the schedule;
//     * index-defined predicates hold when the index is present and not `done`;
//     * conjunctions / disjunctions combine ALL of their terms (no short-circuit);
//     * any other predicate kind evaluates to false.
bool LoopNestVisitor::IsPlacementValid(const ScheduledKernel& kernel, const LoopIndexSymbolTable& runtimeLoopIndices, const LoopVisitSchedule& schedule) const
{
    const auto& loopDomain = schedule.GetDomain();

    // Returns a pointer to the runtime entry for `idx`, or nullptr when the index is unknown.
    auto lookup = [&runtimeLoopIndices](const auto& idx) {
        return runtimeLoopIndices.count(idx) != 0 ? &runtimeLoopIndices.at(idx) : nullptr;
    };

    const bool placementIsEmpty = kernel.placement.IsEmpty();
    if (placementIsEmpty || IsBodyPlacementPredicate(kernel.placement))
    {
        // TODO: put this in a function that preprocesses the kernel predicates when adding the kernels to the schedule
        // Each loop index the kernel's indices depend on must be defined and still open.
        for (const auto& kernelIndex : kernel.kernel.GetIndices())
        {
            for (const auto& loopIndex : loopDomain.GetDependentLoopIndices(kernelIndex, true))
            {
                const auto entry = lookup(loopIndex);
                if (entry == nullptr || entry->state == LoopIndexState::done)
                {
                    return false;
                }
            }
        }

        if (placementIsEmpty)
        {
            return true;
        }
    }

    // Recursive evaluator; receives itself as the first argument so it can recurse into
    // conjunction / disjunction terms.
    auto evaluate = [&](const auto& self, const KernelPredicate& pred) -> bool {
        if (pred.IsAlwaysTrue())
        {
            return true;
        }

        if (pred.Is<FragmentTypePredicate>())
        {
            throw utilities::InputException(utilities::InputExceptionErrors::invalidArgument, "Fragment predicates not valid for placement");
        }

        if (const auto placementPred = pred.As<PlacementPredicate>())
        {
            // There is no "next" loop at the innermost level: only index-free placements hold.
            if (schedule.IsInnermostLoop())
            {
                return !placementPred->HasIndex();
            }

            const auto nextIndex = schedule.Next().CurrentLoopIndex();
            const auto where = placementPred->GetPlacement();

            if (placementPred->HasIndex())
            {
                const auto dependents = loopDomain.GetDependentLoopIndices(placementPred->GetIndex(), true);

                // Reject when we are already inside any loop the tested index depends on.
                for (const auto& dependent : dependents)
                {
                    const auto entry = lookup(dependent);
                    if (entry != nullptr && entry->state == LoopIndexState::inProgress)
                    {
                        return false;
                    }
                }

                // The next loop must at least partially define the tested index.
                if (std::find(dependents.begin(), dependents.end(), nextIndex) == dependents.end())
                {
                    return false;
                }
            }
            // With no index, the next loop is by definition the loop in question.

            // Position check relative to the next loop.
            const auto nextEntry = lookup(nextIndex);
            if (where == Placement::before)
            {
                return nextEntry == nullptr || nextEntry->state == LoopIndexState::notVisited;
            }
            return nextEntry != nullptr && nextEntry->state == LoopIndexState::done;
        }

        if (const auto definedPred = pred.As<IndexDefinedPredicate>())
        {
            const auto entry = lookup(definedPred->GetIndex());
            return entry != nullptr && entry->state != LoopIndexState::done;
        }

        if (const auto allOf = pred.As<KernelPredicateConjunction>())
        {
            // Every term is evaluated (no short-circuit), so an invalid term such as a
            // fragment predicate throws even when an earlier term is already false.
            bool allHold = true;
            for (const auto& term : allOf->GetTerms())
            {
                const bool termHolds = self(self, *term);
                allHold = allHold && termHolds;
            }
            return allHold;
        }

        if (const auto anyOf = pred.As<KernelPredicateDisjunction>())
        {
            // Likewise, every term is evaluated even after one has held.
            bool anyHolds = false;
            for (const auto& term : anyOf->GetTerms())
            {
                const bool termHolds = self(self, *term);
                anyHolds = anyHolds || termHolds;
            }
            return anyHolds;
        }

        // Unrecognized predicate kinds never permit placement.
        return false;
    };

    return evaluate(evaluate, kernel.placement);
}