in meta/src/main/java/org/arend/lib/meta/linear/LinearSolver.java [418:599]
/**
 * Tries to prove {@code expectedType} from the linear (in)equalities found among the context bindings.
 *
 * <p>Two modes:
 * <ul>
 *   <li>If {@code expectedType} (after WHNF normalization) is a data call to {@code ext.Empty}, the goal is a
 *   contradiction. Hypotheses are grouped by instance, converting groups between ring kinds where possible.
 *   The first group for which {@code solveEquations} returns a solution is turned into a call to
 *   {@code solveContrProblem}.</li>
 *   <li>Otherwise the goal is converted with {@code typeToEquation}. Hypotheses over other instances are converted
 *   to the goal's instance (or dropped if conversion fails), and the negated goal is added to the hypotheses.
 *   {@code solveEquations} must succeed on every resulting problem: one problem for {@code <} and {@code <=}
 *   goals, two for {@code =} goals. The proof is a call to {@code solveLessProblem}, {@code solveLeqProblem} or
 *   {@code solveEqProblem}.</li>
 * </ul>
 *
 * @param expectedType the goal type; normalized to WHNF before inspection
 * @param hint passed to {@code ContextHelper} to choose which bindings are collected as hypotheses
 * @return the result of typechecking the constructed proof term. Returns {@code null} without reporting when the
 *     goal is not recognized by {@code typeToEquation}. Returns {@code null} after reporting a
 *     {@code LinearSolverError} when no solution was found.
 */
public TypedExpression solve(CoreExpression expectedType, ConcreteExpression hint) {
expectedType = expectedType.normalize(NormalizationMode.WHNF);
// A null resultEquation encodes "the goal is Empty", i.e. a contradiction must be derived (else-branch below).
Equation<CoreExpression> resultEquation;
if (expectedType instanceof CoreDataCallExpression && ((CoreDataCallExpression) expectedType).getDefinition() == ext.Empty) {
resultEquation = null;
} else {
resultEquation = typeToEquation(expectedType, null, true);
// Goal is not a recognized (in)equality: give up silently (no error is reported here).
if (resultEquation == null) return null;
}
// Collect every context binding that can be read as a hypothesis.
List<Hypothesis<CoreExpression>> rules = new ArrayList<>();
ContextHelper helper = new ContextHelper(hint);
for (CoreBinding binding : helper.getAllBindings(typechecker)) {
Hypothesis<CoreExpression> hypothesis = bindingToHypothesis(binding);
if (hypothesis != null) rules.add(hypothesis);
}
if (resultEquation != null) {
TypedExpression instance = resultEquation.instance.computeTyped();
CoreClassCallExpression classCall = Utils.getClassCall(instance.getType());
TermCompiler compiler = makeTermCompiler(instance, classCall);
if (compiler != null) {
// Keep hypotheses over the goal's instance as they are.
// Try to convert the others into the goal's instance/kind; drop them if convertHypothesis returns null.
List<Hypothesis<CoreExpression>> newRules = new ArrayList<>();
for (Hypothesis<CoreExpression> rule : rules) {
if (rule.instance.compare(resultEquation.instance, CMP.EQ)) {
newRules.add(rule);
} else {
Hypothesis<CoreExpression> newRule = convertHypothesis(rule, resultEquation.instance, BaseTermCompiler.getTermCompilerKind(rule.instance, ext.equationMeta), compiler.getKind());
if (newRule != null) {
newRules.add(newRule);
}
}
}
rules = newRules;
CoreFunctionDefinition function;
List<Hypothesis<CompiledTerm>> compiledRules = new ArrayList<>();
compileHypotheses(compiler, rules, compiledRules);
// Each entry of rulesSet is one problem handed to solveEquations.
// Every problem starts with the "0 < 1" equation at index 0 and the negated goal at index 1.
// The compiled hypotheses follow from index 2 on.
List<List<Equation<CompiledTerm>>> rulesSet = new ArrayList<>(2);
List<Equation<CompiledTerm>> compiledRules1 = new ArrayList<>();
compiledRules1.add(makeZeroLessOne(instance.getExpression()));
rulesSet.add(compiledRules1);
// Compiled after the hypotheses, so the goal's variables are registered in the same compiler.
CompiledTerms compiledResults = compiler.compileTerms(resultEquation.lhsTerm, resultEquation.rhsTerm);
// NOTE(review): presumably adds "0 < var"-style facts for variables known to be non-negative.
// Confirm against makeZeroLessVar.
// The contradiction branch below uses a different guard (isNat() || isInt()).
// Confirm the difference is intentional.
if (compiler.isNat() || !compiler.positiveVars.isEmpty()) {
makeZeroLessVar(instance.getExpression(), compiler, compiledRules);
}
// Add the negation of the goal (lhs = term1, rhs = term2).
// Assumption: Hypothesis(_, _, op, a, b, _) denotes "a op b".
switch (resultEquation.operation) {
case LESS -> {
// goal lhs < rhs; negation rhs <= lhs
compiledRules1.add(new Hypothesis<>(null, resultEquation.instance, Equation.Operation.LESS_OR_EQUALS, compiledResults.term2, compiledResults.term1, compiledResults.lcm));
function = ext.linearSolverMeta.solveLessProblem;
}
case LESS_OR_EQUALS -> {
// goal lhs <= rhs; negation rhs < lhs
compiledRules1.add(new Hypothesis<>(null, resultEquation.instance, Equation.Operation.LESS, compiledResults.term2, compiledResults.term1, compiledResults.lcm));
function = ext.linearSolverMeta.solveLeqProblem;
}
case EQUALS -> {
// goal lhs = rhs; negation splits into two problems: lhs < rhs and rhs < lhs.
// compiledRules2 is copied before anything else is added to compiledRules1, so it shares only "0 < 1".
List<Equation<CompiledTerm>> compiledRules2 = new ArrayList<>(compiledRules1);
compiledRules1.add(new Hypothesis<>(null, resultEquation.instance, Equation.Operation.LESS, compiledResults.term1, compiledResults.term2, compiledResults.lcm));
compiledRules2.add(new Hypothesis<>(null, resultEquation.instance, Equation.Operation.LESS, compiledResults.term2, compiledResults.term1, compiledResults.lcm));
compiledRules2.addAll(compiledRules);
rulesSet.add(compiledRules2);
function = ext.linearSolverMeta.solveEqProblem;
}
default -> throw new IllegalStateException();
}
// Appended after the switch so the negated goal stays at index 1.
// These are the same Hypothesis objects as in compiledRules, shared by reference.
compiledRules1.addAll(compiledRules);
List<List<BigInteger>> solutions = new ArrayList<>(rulesSet.size());
for (List<Equation<CompiledTerm>> equations : rulesSet) {
List<BigInteger> solution = solveEquations(equations, compiler.getNumberOfVariables());
if (solution != null) solutions.add(solution);
}
// The goal is proved only if every problem has a solution.
if (solutions.size() == rulesSet.size()) {
// combinedSolutions.get(i) is the maximum, over all solutions, of the coefficient of compiledRules.get(i).
// Solution entries are offset by 2, skipping "0 < 1" and the negated goal.
List<BigInteger> combinedSolutions = new ArrayList<>();
for (List<BigInteger> solution : solutions) {
for (int i = combinedSolutions.size() + 2; i < solution.size(); i++) {
combinedSolutions.add(BigInteger.ZERO);
}
for (int i = 0; i < solution.size() - 2; i++) {
combinedSolutions.set(i, combinedSolutions.get(i).max(solution.get(i + 2)));
}
}
// Trim compiledRules to the hypotheses used by at least one solution.
// Presumably the entries whose combined coefficient is zero are removed; confirm in dropUnusedHypotheses.
dropUnusedHypotheses(combinedSolutions, compiledRules);
List<CoreExpression> values = compiler.getValues().getValues();
// The goal terms are included (as a dummy hypothesis) so that their variables are not considered unused.
List<Hypothesis<CompiledTerm>> newCompiledRules = new ArrayList<>(compiledRules);
newCompiledRules.add(new Hypothesis<>(null, null, null, compiledResults.term1, compiledResults.term2, BigInteger.ONE));
removeUnusedVariables(newCompiledRules, values);
ConcreteAppBuilder builder = factory.appBuilder(factory.ref(function.getRef()))
.app(makeData(classCall, factory.core(instance), compiler.getKind(), values), false)
.app(equationsToConcrete(compiledRules))
.app(compiledResults.term1.concrete())
.app(compiledResults.term2.concrete());
// Trim each problem and its solution in the same way, through subList views.
// The first two entries ("0 < 1" and the negated goal) are left intact.
// One certificate argument is added per problem.
for (int i = 0; i < solutions.size(); i++) {
dropUnusedHypotheses(combinedSolutions, solutions.get(i).subList(2, solutions.get(i).size()));
dropUnusedHypotheses(combinedSolutions, rulesSet.get(i).subList(2, rulesSet.get(i).size()));
builder.app(certificateToConcrete(solutions.get(i), rulesSet.get(i)));
}
return typechecker.typecheck(builder.app(witnessesToConcrete(compiledRules)).build(), null);
}
}
} else {
// Contradiction mode.
// finalRules is an effectively-final alias for the lambda.
// solutionFound tells "a solution was found" apart from "the lambda returned null".
List<Hypothesis<CoreExpression>> finalRules = rules;
boolean[] solutionFound = new boolean[] { false };
TypedExpression result = typechecker.withCurrentState(tc -> {
// Partition hypotheses into groups; within a group all hypotheses share the instance of the group's first element.
List<List<Hypothesis<CoreExpression>>> rulesSet = new ArrayList<>();
for (Hypothesis<CoreExpression> rule : finalRules) {
RingKind kind = BaseTermCompiler.getTermCompilerKind(rule.instance, ext.equationMeta);
boolean found = false;
for (int i = 0; i < rulesSet.size(); i++) {
List<Hypothesis<CoreExpression>> newRules = rulesSet.get(i);
// The comparison runs via Utils.tryWithSavedState; a null result is treated as "not equal".
Boolean compareResult = Utils.tryWithSavedState(tc, tc2 -> tc2.compare(rule.instance, newRules.get(0).instance, CMP.EQ, marker, false, true, false));
if (compareResult != null && compareResult) {
newRules.add(rule);
found = true;
break;
} else if (kind != RingKind.NAT) {
RingKind newKind = BaseTermCompiler.getTermCompilerKind(newRules.get(0).instance, ext.equationMeta);
// Decides whether the existing group (kind newKind) should be converted into the rule's kind.
// NOTE(review): this depends on the declaration order of RingKind (ordinal()).
if (kind == RingKind.NONE && newKind != RingKind.RAT || newKind == RingKind.NONE || kind.ordinal() > newKind.ordinal() && !(newKind == RingKind.RAT && kind == RingKind.NONE)) {
found = true;
// Build a converted copy of the group over rule.instance and append it.
// The original group is removed only if every member converted successfully.
List<Hypothesis<CoreExpression>> newRules2 = new ArrayList<>(newRules.size() + 1);
boolean remove = true;
for (Hypothesis<CoreExpression> newRule : newRules) {
Hypothesis<CoreExpression> newRule2 = convertHypothesis(newRule, rule.instance, newKind, kind);
if (newRule2 != null) {
newRules2.add(newRule2);
} else {
remove = false;
}
}
newRules2.add(rule);
rulesSet.add(newRules2);
// remove(int) by index: the converted group was appended at the end, so index i is still the old group.
if (remove) {
rulesSet.remove(i);
}
break;
}
}
}
if (!found) {
List<Hypothesis<CoreExpression>> newRules = new ArrayList<>();
newRules.add(rule);
rulesSet.add(newRules);
}
}
// Try each group in turn; the first group whose equations have a solution is used for the proof.
for (List<Hypothesis<CoreExpression>> equations : rulesSet) {
TypedExpression instance = equations.get(0).instance.computeTyped();
CoreClassCallExpression classCall = Utils.getClassCall(instance.getType());
TermCompiler compiler = makeTermCompiler(instance, classCall);
if (compiler == null) continue;
List<Hypothesis<CompiledTerm>> compiledEquations = new ArrayList<>();
compileHypotheses(compiler, equations, compiledEquations);
if (compiler.isNat() || compiler.isInt()) {
makeZeroLessVar(instance.getExpression(), compiler, compiledEquations);
}
// Problem layout: "0 < 1" at index 0, followed by the compiled hypotheses.
List<Equation<CompiledTerm>> compiledEquations1 = new ArrayList<>(compiledEquations.size() + 1);
compiledEquations1.add(makeZeroLessOne(instance.getExpression()));
compiledEquations1.addAll(compiledEquations);
List<BigInteger> solution = solveEquations(compiledEquations1, compiler.getNumberOfVariables());
if (solution != null) {
// subList skips the coefficient of "0 < 1", so subList.get(j) lines up with compiledEquations.get(j).
// These three calls all mutate through views/shared lists; the solution itself is trimmed last.
// NOTE(review): the last call passes the same list as coefficients and target.
// That is only correct if dropUnusedHypotheses tolerates aliasing (e.g. iterates from the end); confirm.
List<BigInteger> subList = solution.subList(1, solution.size());
dropUnusedHypotheses(subList, compiledEquations);
dropUnusedHypotheses(subList, compiledEquations1.subList(1, compiledEquations1.size()));
dropUnusedHypotheses(subList, subList);
solutionFound[0] = true;
List<CoreExpression> values = compiler.getValues().getValues();
removeUnusedVariables(compiledEquations, values);
return tc.typecheck(factory.appBuilder(factory.ref(ext.linearSolverMeta.solveContrProblem.getRef()))
.app(makeData(classCall, factory.core(instance), compiler.getKind(), values), false)
.app(equationsToConcrete(compiledEquations))
.app(certificateToConcrete(solution, compiledEquations1))
.app(witnessesToConcrete(compiledEquations))
.build(), null);
}
}
return null;
});
// If a solution was found, its typecheck result is returned even when it is null (no LinearSolverError then).
if (solutionFound[0]) {
return result;
}
}
// Reached when no term compiler could be built for the goal, some problem had no solution,
// or no group yielded a contradiction.
// In goal mode, `rules` holds the converted hypotheses.
errorReporter.report(new LinearSolverError(typechecker.getExpressionPrettifier(), resultEquation, rules, marker));
return null;
}