in base/src/main/java/org/arend/typechecking/patternmatching/ElimTypechecking.java [780:1135]
/**
 * Recursively compiles a list of (extended) clauses into an elimination tree.
 * <p>
 * Algorithm, as implemented below:
 * <ol>
 *   <li>Find the leftmost pattern column {@code index} where some clause has a non-variable pattern.</li>
 *   <li>If there is no such column, the first clause matches unconditionally and a {@link LeafElimTree} is returned.</li>
 *   <li>If some clause has an empty pattern in that column, an empty {@link BranchElimTree} is returned.</li>
 *   <li>Otherwise the set of branch keys for the column is computed. A key is a data constructor, an array
 *       constructor, or a single tuple/record/idp constructor. The clauses are grouped by key.</li>
 *   <li>In each group the matched column is replaced by its sub-patterns. A variable pattern is replaced by
 *       fresh bindings for the constructor parameters, plus a substitution of the variable by the constructor
 *       applied to those bindings. The method then recurses on every group.</li>
 * </ol>
 * <p>
 * Side effects visible in this method:
 * <ul>
 *   <li>removes used clause indices from {@code myUnusedClauses};</li>
 *   <li>may report errors through {@code myErrorReporter} (when non-null);</li>
 *   <li>may set {@code myOK = false};</li>
 *   <li>may record missing clauses via {@code addMissingClause} when {@code myMode.checkCoverage()} holds;</li>
 *   <li>pushes and pops entries on {@code myContext}. The enclosing {@code Utils.ContextSaver} presumably
 *       restores the context on exit — confirm.</li>
 * </ul>
 *
 * @param clauses           the clauses to compile. Must be non-empty, because {@code clauses.get(0)} is read
 *                          unconditionally. All pattern lists are presumably of equal length — TODO confirm with callers.
 * @param argsStackSize     offset added to column positions when they are recorded as argument indices of the
 *                          resulting clauses. It grows by {@code index} (plus one when variables are present) on
 *                          each recursive step.
 * @param numberOfIntervals number of interval eliminations accumulated along the current path. It is only
 *                          consulted when {@code myLevel != null}, to drop branch keys whose total would exceed
 *                          {@code myLevel}.
 * @return the elimination tree, or {@code null} if an error (impossible elimination or matching on the
 *         interval) prevents building it
 */
private ElimTree clausesToElimTree(List<ExtElimClause> clauses, int argsStackSize, int numberOfIntervals) {
try (Utils.ContextSaver ignored = new Utils.ContextSaver(myContext)) {
// Step 1: find the first column in which some clause has a non-variable pattern.
// Special case: if the first clause has a variable in this column and some clause matches on
// left/right there, every clause with a non-variable pattern in this column is dropped, and the
// search continues with the next column.
int index = 0;
loop:
for (; index < clauses.get(0).getPatterns().size(); index++) {
for (ExtElimClause clause : clauses) {
if (!(clause.getPatterns().get(index) instanceof BindingPattern)) {
if (clauses.get(0).getPatterns().get(index) instanceof BindingPattern && clause.getPatterns().get(index) instanceof ConstructorPattern) {
Definition definition = clause.getPatterns().get(index).getDefinition();
if (definition == Prelude.LEFT || definition == Prelude.RIGHT) {
final int finalIndex = index;
// Reassigning `clauses` abandons the current inner iteration; `continue loop` moves on to the next column.
// clauses.get(0) always survives the filter, because it has a variable in this column.
clauses = clauses.stream().filter(clauseData1 -> clauseData1.getPatterns().get(finalIndex) instanceof BindingPattern).collect(Collectors.toList());
continue loop;
}
}
break loop;
}
}
}
// If all patterns are variables
// Step 2: the first remaining clause matches. Mark it as used and build a leaf. Columns at or beyond the
// clause's fake vars are appended as argument indices, offset by argsStackSize.
if (index == clauses.get(0).getPatterns().size()) {
ExtElimClause clause = clauses.get(0);
myUnusedClauses.remove(clause.index);
List<Integer> indices = clause.argIndices;
if (index > clause.numberOfFakeVars) {
// Copy before appending so that clause.argIndices itself is not mutated.
indices = new ArrayList<>(indices);
for (int i = clause.numberOfFakeVars; i < index; i++) {
indices.add(argsStackSize + i);
}
}
// The index list is omitted (null) when isConsequent(indices) holds.
// Presumably that means the indices form a trivial consecutive mapping — confirm against isConsequent.
return new LeafElimTree(index, isConsequent(indices) ? null : indices, clause.index);
}
// Push the (variable) patterns of the skipped columns so that missing-clause reports show them.
for (int i = 0; i < index; i++) {
myContext.push(new Util.PatternClauseElem(clauses.get(0).getPatterns().get(i)));
}
// Step 3: an empty pattern in this column makes the whole branch absurd.
// Otherwise remember the first clause that matches on a constructor here.
ExtElimClause conClause = null;
for (ExtElimClause clause : clauses) {
Pattern pattern = clause.getPatterns().get(index);
if (pattern instanceof EmptyPattern) {
myUnusedClauses.remove(clause.index);
return new BranchElimTree(index, false);
}
if (conClause == null && pattern instanceof ConstructorPattern) {
conClause = clause;
}
}
// The search in step 1 stopped at a column containing a non-variable pattern. Empty patterns returned above,
// so a constructor pattern is assumed to exist here.
assert conClause != null;
ConstructorExpressionPattern someConPattern = (ConstructorExpressionPattern) conClause.getPatterns().get(index);
Expression arrayElementsType = someConPattern.getArrayElementsType();
Expression arrayLength = someConPattern.getArrayLength();
// conCalls is non-null only for data types with indexed constructors (or Path). It then holds the constructor
// calls that can match the concrete data type call.
List<ConCallExpression> conCalls = null;
List<BranchKey> branchKeys;
// dataType stays null for array, tuple, record and idp patterns.
DataDefinition dataType;
// Step 4: determine the branch keys for this column.
if (someConPattern.getDefinition() instanceof Constructor constructor) {
dataType = constructor.getDataType();
if (dataType.hasIndexedConstructors() || dataType == Prelude.PATH) {
// Only constructors that can match the (substituted) data type call are eligible.
// For Fin patterns the data type is rebuilt as Fin(suc n) from the first data type argument.
DataCallExpression dataCall;
if (constructor == Prelude.FIN_ZERO || constructor == Prelude.FIN_SUC) {
dataCall = Fin(Suc(((ConCallExpression) someConPattern.getDataExpression()).getDataTypeArguments().get(0).subst(conClause.substitution)));
} else {
dataCall = (DataCallExpression) someConPattern.getDataExpression().subst(conClause.substitution).getType();
}
conCalls = dataCall.getMatchedConstructors();
if (conCalls == null) {
// The matching constructors cannot be determined: elimination is impossible here.
if (myErrorReporter != null) myErrorReporter.report(new ImpossibleEliminationError(dataCall, getClause(conClause.index, someConPattern), null, null, null, null, null));
myOK = false;
return null;
}
branchKeys = new ArrayList<>(conCalls.size());
for (ConCallExpression conCall : conCalls) {
branchKeys.add(conCall.getDefinition());
}
} else {
branchKeys = new ArrayList<>(dataType.getConstructors());
}
} else if (someConPattern.getDefinition() == Prelude.EMPTY_ARRAY || someConPattern.getDefinition() == Prelude.ARRAY_CONS) {
// Arrays: offer the empty and/or cons key, depending on whether emptiness is already known (null = unknown).
Boolean empty = someConPattern.isArrayEmpty();
branchKeys = new ArrayList<>(2);
if (empty == null || empty.equals(Boolean.TRUE)) {
branchKeys.add(new ArrayConstructor(true, arrayElementsType != null, arrayLength != null));
}
if (empty == null || empty.equals(Boolean.FALSE)) {
branchKeys.add(new ArrayConstructor(false, arrayElementsType != null, arrayLength != null));
}
dataType = null;
} else {
// Single-constructor types: record (class call), tuple (sigma) or idp.
if (someConPattern.getDataExpression() instanceof ClassCallExpression classCall) {
branchKeys = Collections.singletonList(new ClassConstructor(classCall.getDefinition(), classCall.getLevels(), classCall.getImplementedHere().keySet()));
} else if (someConPattern.getDataExpression() instanceof SigmaExpression) {
// Collect the positions of property parameters of the sigma type.
// The set is allocated lazily, so the common case keeps Collections.emptySet().
Set<Integer> propertyIndices = Collections.emptySet();
int i = 0;
for (DependentLink link = ((SigmaExpression) someConPattern.getDataExpression()).getParameters(); link.hasNext(); link = link.getNext(), i++) {
if (link.isProperty()) {
if (propertyIndices.isEmpty()) propertyIndices = new HashSet<>();
propertyIndices.add(i);
}
}
branchKeys = Collections.singletonList(new TupleConstructor(someConPattern.getLength(), propertyIndices));
} else {
assert someConPattern.getDefinition() == Prelude.IDP;
branchKeys = Collections.singletonList(new IdpConstructor());
}
dataType = null;
}
// Pattern matching on the interval is rejected in this context.
if (dataType == Prelude.INTERVAL) {
if (myErrorReporter != null) myErrorReporter.report(new TypecheckingError("Pattern matching on the interval is not allowed here", getClause(conClause.index, someConPattern)));
myOK = false;
return null;
}
// Squashed and truncated data types restrict the level of the type being eliminated into.
// These checks run only when an error reporter is available.
if (dataType != null && dataType.isSquashed() && myErrorReporter != null) {
if (myActualLevel != null && !Level.compare(myActualLevel, dataType.getSort().getHLevel().add(myActualLevelSub), CMP.LE, myEquations, getClause(conClause.index, someConPattern))) {
myErrorReporter.report(new SquashedDataError(dataType, myActualLevel, myActualLevelSub, getClause(conClause.index, someConPattern)));
}
// A truncated type is accepted if myLevel is known and at most truncatedLevel + 1.
// Otherwise the expected type's universe h-level must be <= the data type's h-level.
boolean ok = !dataType.isTruncated() || myLevel != null && myLevel <= dataType.getTruncatedLevel() + 1;
if (!ok) {
Expression type = myExpectedType.getType();
if (type != null) {
type = type.normalize(NormalizationMode.WHNF);
UniverseExpression universe = type.cast(UniverseExpression.class);
if (universe != null) {
ok = Level.compare(universe.getSort().getHLevel(), dataType.getSort().getHLevel(), CMP.LE, myEquations, getClause(conClause.index, someConPattern));
} else {
// The type is not syntactically a universe: compare it against \Type(?p, hLevel) with a fresh p-level inference variable.
InferenceLevelVariable pl = new InferenceLevelVariable(LevelVariable.LvlType.PLVL, false, getClause(conClause.index, someConPattern));
myEquations.addVariable(pl);
ok = type.isLessOrEquals(new UniverseExpression(new Sort(new Level(pl), dataType.getSort().getHLevel())), myEquations, getClause(conClause.index, someConPattern));
}
}
}
if (!ok) {
myErrorReporter.report(new TruncatedDataError(dataType, myExpectedType, getClause(conClause.index, someConPattern)));
myOK = false;
}
}
// When myLevel is known, drop constructor keys whose accumulated interval eliminations would exceed it.
// Single-constructor keys live in an immutable singleton list and are never pruned.
if (myLevel != null && !branchKeys.isEmpty() && !(branchKeys.get(0) instanceof SingleConstructor)) {
//noinspection ConstantConditions
branchKeys.removeIf(key -> numberOfIntervals + (key.getBody() instanceof IntervalElim ? ((IntervalElim) key.getBody()).getNumberOfTotalElim() : 0) > myLevel);
}
// Step 5: group clauses by branch key. A variable pattern belongs to every key. A constructor pattern
// belongs to its own key, or to the single key for tuple/record/idp patterns. Clauses for which no key can
// be determined are not added to any group. LinkedHashMap keeps clause order within each group.
boolean hasVars = false;
Map<BranchKey, List<ExtElimClause>> branchKeyMap = new LinkedHashMap<>();
for (ExtElimClause clause : clauses) {
if (clause.getPatterns().get(index) instanceof BindingPattern) {
hasVars = true;
for (BranchKey key : branchKeys) {
branchKeyMap.computeIfAbsent(key, k -> new ArrayList<>()).add(clause);
}
} else {
Definition def = clause.getPatterns().get(index).getDefinition();
BranchKey key = def instanceof Constructor ? (Constructor) def : def == Prelude.EMPTY_ARRAY || def == Prelude.ARRAY_CONS ? new ArrayConstructor((DConstructor) def, arrayElementsType != null, arrayLength != null) : null;
if (key == null && !branchKeys.isEmpty() && branchKeys.get(0) instanceof SingleConstructor) {
key = branchKeys.get(0);
}
if (key != null) {
branchKeyMap.computeIfAbsent(key, k -> new ArrayList<>()).add(clause);
}
}
}
// Coverage: with no variable in this column, every key without clauses is a missing case.
// The missing clause is reported as the current context followed by the constructor applied to fresh variables.
if (myMode.checkCoverage() && !hasVars) {
for (BranchKey key : branchKeys) {
if (!branchKeyMap.containsKey(key)) {
List<Util.ClauseElem> context = new ArrayList<>(myContext);
context.add(Util.makeDataClauseElem(key, someConPattern));
for (DependentLink link = key.getParameters(someConPattern); link.hasNext(); link = link.getNext()) {
context.add(new Util.PatternClauseElem(new BindingPattern(link)));
}
addMissingClause(context, false);
}
}
}
// Step 6: for each key, rewrite its clauses and recurse.
BranchElimTree branchElimTree = new BranchElimTree(index, hasVars);
for (BranchKey branchKey : branchKeys) {
List<ExtElimClause> conClauseList = branchKeyMap.get(branchKey);
if (conClauseList == null) {
continue;
}
myContext.push(Util.makeDataClauseElem(branchKey, someConPattern));
for (int i = 0; i < conClauseList.size(); i++) {
ExtElimClause clause = conClauseList.get(i);
// Columns before `index` that lie beyond the clause's fake vars become real argument indices.
List<Integer> indices = new ArrayList<>(clause.argIndices);
for (int j = clause.numberOfFakeVars; j < index; j++) {
indices.add(argsStackSize + j);
}
int numberOfFakeVars = Math.max(clause.numberOfFakeVars - index, 0);
List<ExpressionPattern> patterns = new ArrayList<>();
List<ExpressionPattern> oldPatterns = clause.getPatterns();
ExprSubstitution newSubstitution;
if (oldPatterns.get(index) instanceof ConstructorExpressionPattern) {
// Constructor pattern: splice its sub-patterns into the column; the substitution is unchanged.
patterns.addAll(oldPatterns.get(index).getSubPatterns());
newSubstitution = conClauseList.get(i).substitution;
} else {
// Variable pattern: build `substExpr`, the constructor applied to fresh bindings, and the parameter list `conParameters`.
// NOTE(review): `arguments` is handed to ConCallExpression.make / TupleExpression / FunCallExpression.make
// before it is filled in by the loop further below. This presumably relies on those expressions keeping a
// reference to the list rather than copying it — confirm.
Expression substExpr;
DependentLink conParameters;
List<Expression> arguments = new ArrayList<>();
if (conCalls != null) {
// Indexed data type: use the matched constructor call and its data type arguments.
ConCallExpression conCall = null;
for (ConCallExpression conCall1 : conCalls) {
if (conCall1.getDefinition() == branchKey) {
conCall = conCall1;
break;
}
}
assert conCall != null;
List<Expression> dataTypesArgs = conCall.getDataTypeArguments();
substExpr = ConCallExpression.make(conCall.getDefinition(), conCall.getLevels(), dataTypesArgs, arguments);
conParameters = DependentLink.Helper.subst(branchKey.getParameters(someConPattern), DependentLink.Helper.toSubstitution(conCall.getDefinition().getDataTypeParameters(), dataTypesArgs));
} else if (branchKey instanceof SingleConstructor) {
conParameters = someConPattern.getParameters();
Expression someExpr = someConPattern.getDataExpression();
if (someExpr instanceof ClassCallExpression classCall) {
// Record: implement every field not implemented in the class call with the next fresh binding, in order.
Map<ClassField, Expression> implementations = new LinkedHashMap<>();
DependentLink link = conParameters;
for (ClassField field : classCall.getDefinition().getNotImplementedFields()) {
if (!classCall.isImplemented(field)) {
implementations.put(field, new ReferenceExpression(link));
link = link.getNext();
}
}
substExpr = new NewExpression(null, new ClassCallExpression(classCall.getDefinition(), classCall.getLevels(), implementations, Sort.PROP, UniverseKind.NO_UNIVERSES));
} else if (someExpr instanceof SigmaExpression) {
// Tuple: the parameters are copied, presumably to get fresh bindings distinct from the sigma's own — confirm.
substExpr = new TupleExpression(arguments, (SigmaExpression) someExpr);
conParameters = DependentLink.Helper.copy(conParameters);
} else if (someExpr instanceof FunCallExpression) {
substExpr = someExpr;
} else {
throw new IllegalStateException();
}
} else if (branchKey instanceof Constructor constructor) {
List<Expression> dataTypesArgs = new ArrayList<>();
for (Expression dataTypeArg : someConPattern.getDataTypeArguments()) {
dataTypesArgs.add(dataTypeArg.subst(conClause.substitution));
}
substExpr = ConCallExpression.make(constructor, someConPattern.getLevels(), dataTypesArgs, arguments);
// NOTE(review): the parameters are substituted with the unsubstituted someConPattern.getDataTypeArguments(),
// whereas substExpr above uses the substituted `dataTypesArgs` — confirm this asymmetry is intended.
conParameters = DependentLink.Helper.subst(constructor.getParameters(), DependentLink.Helper.toSubstitution(constructor.getDataTypeParameters(), someConPattern.getDataTypeArguments()));
} else if (branchKey instanceof ArrayConstructor) {
// Known length / element type come first, then the fresh bindings are appended below.
if (arrayLength != null) {
arguments.add(arrayLength);
}
if (arrayElementsType != null) {
arguments.add(arrayElementsType);
}
substExpr = FunCallExpression.make(((ArrayConstructor) branchKey).getConstructor(), someConPattern.getLevels(), arguments);
conParameters = branchKey.getParameters(someConPattern);
} else {
throw new IllegalStateException();
}
// The matched variable itself: record it as a real argument, or consume one fake var.
if (numberOfFakeVars == 0) {
indices.add(argsStackSize + index);
} else {
numberOfFakeVars--;
}
// Each constructor parameter becomes a fresh variable pattern and an argument of substExpr.
// These fresh variables are counted as fake vars.
for (DependentLink link = conParameters; link.hasNext(); link = link.getNext()) {
patterns.add(new BindingPattern(link));
arguments.add(new ReferenceExpression(link));
numberOfFakeVars++;
}
newSubstitution = new ExprSubstitution(conClauseList.get(i).substitution);
Binding patternBinding = ((BindingPattern) oldPatterns.get(index)).getBinding();
newSubstitution.addSubst(patternBinding, substExpr);
// Special case: when the variable is split into `zero` and the pattern two columns later is a variable
// of type DArray whose length is this variable, that array variable is substituted by the empty array.
// The offset of 2 presumably reflects the caller's parameter layout — confirm.
if (branchKey == Prelude.ZERO && index + 2 < oldPatterns.size()) {
ExpressionPattern pattern = oldPatterns.get(index + 2);
if (pattern instanceof BindingPattern) {
Binding binding = pattern.getBinding();
Expression type = binding.getTypeExpr().normalize(NormalizationMode.WHNF);
if (type instanceof ClassCallExpression classCall && classCall.getDefinition() == Prelude.DEP_ARRAY) {
Expression length = classCall.getImplementationHere(Prelude.ARRAY_LENGTH, new ReferenceExpression(binding));
if (length != null) {
length = length.normalize(NormalizationMode.WHNF);
if (length instanceof ReferenceExpression && ((ReferenceExpression) length).getBinding() == patternBinding) {
Expression elementsType = classCall.getImplementationHere(Prelude.ARRAY_ELEMENTS_TYPE, new ReferenceExpression(binding));
newSubstitution.addSubst(binding, new FunCallExpression(Prelude.EMPTY_ARRAY, classCall.getLevels(), null, elementsType != null ? elementsType : FieldCallExpression.make(Prelude.ARRAY_ELEMENTS_TYPE, new ReferenceExpression(classCall.getThisBinding()))));
}
}
}
}
}
}
// The remaining columns after `index` are kept as-is.
patterns.addAll(oldPatterns.subList(index + 1, oldPatterns.size()));
conClauseList.set(i, new ExtElimClause(patterns, clause.getExpression(), clause.index, indices, numberOfFakeVars, newSubstitution));
}
// Recurse on the rewritten group. The interval count is only tracked when myLevel is known.
ElimTree elimTree = clausesToElimTree(conClauseList, argsStackSize + index + (hasVars ? 1 : 0), myLevel == null ? 0 : numberOfIntervals + (branchKey.getBody() instanceof IntervalElim ? ((IntervalElim) branchKey.getBody()).getNumberOfTotalElim() : 0));
if (elimTree == null) {
myOK = false;
} else {
// Fin constructors are stored in the tree under the corresponding Nat constructors.
branchElimTree.addChild(branchKey == Prelude.FIN_ZERO ? Prelude.ZERO : branchKey == Prelude.FIN_SUC ? Prelude.SUC : branchKey, elimTree);
}
myContext.pop();
// If we match on a variable and the constructor has conditions,
// we need to check that it isn't mapped to a clause with a variable
// unless constructors to which the current one evaluates is also mapped to the same clause.
// We need this because condition checker doesn't check clauses with variables.
// NOTE(review): elimTree may be null here (failed recursive call). Confirm that collectClauseIndices tolerates null.
if (hasVars && dataType != null && branchKey instanceof Constructor && branchKey.getBody() != null && myErrorReporter != null) {
// Clauses reached through this branch that have a variable (not a constructor) in this column.
Set<Integer> indices = new HashSet<>();
collectClauseIndices(elimTree, indices);
for (ExtElimClause clause : clauses) {
if (!(clause.getPatterns().get(index) instanceof BindingPattern)) {
indices.remove(clause.index);
}
}
if (!indices.isEmpty()) {
// Clauses reached through the constructors that occur in this constructor's body.
// A missing child for any of them counts as failure.
Set<Integer> depIndices = new HashSet<>();
Set<Constructor> depConstructors = new HashSet<>();
collectConstructors(dataType, branchKey.getBody(), depConstructors);
boolean ok = true;
for (Constructor depConstructor : depConstructors) {
ElimTree depElimTree = branchElimTree.getChild(depConstructor);
if (depElimTree == null) {
ok = false;
} else {
collectClauseIndices(depElimTree, depIndices);
}
}
// When both sides reach clauses, it is OK only if both reach exactly the same single clause.
if (ok && !depIndices.isEmpty()) {
if (indices.size() > 1) {
ok = false;
} else {
ok = depIndices.size() == 1 && depIndices.iterator().next().equals(indices.iterator().next());
}
}
if (!ok) {
// Report at the offending pattern of the first such clause when concrete clauses are available.
Concrete.SourceNode sourceNode;
if (myClauses != null) {
Concrete.FunctionClause functionClause = myClauses.get(indices.iterator().next());
sourceNode = index < functionClause.getPatterns().size() ? functionClause.getPatterns().get(index) : functionClause;
} else {
sourceNode = mySourceNode;
}
myErrorReporter.report(new HigherConstructorMatchingError((Constructor) branchKey, sourceNode));
}
}
}
}
return branchElimTree;
}
}