private ElimTree clausesToElimTree()

in base/src/main/java/org/arend/typechecking/patternmatching/ElimTypechecking.java [780:1135]


  /**
   * Compiles a list of elimination clauses into an {@link ElimTree}.
   *
   * <p>Clauses are scanned column by column. The method finds the first pattern position at which some clause
   * does not have a {@link BindingPattern}.
   * <ul>
   *   <li>If there is no such position, the first clause wins and a {@link LeafElimTree} is returned for it.</li>
   *   <li>Otherwise the set of branch keys for that position is computed. Keys are data-type constructors, array
   *   constructors, or a single class/tuple/idp constructor. The clauses are distributed among the keys, each
   *   clause's patterns are rewritten for the chosen branch, and the method recurses. The children are collected
   *   in a {@link BranchElimTree}.</li>
   * </ul>
   *
   * <p>Side effects: errors are reported through {@code myErrorReporter} when it is non-null. Missing clauses are
   * recorded via {@code addMissingClause} when {@code myMode.checkCoverage()} holds. Clauses that reach a leaf or
   * an empty-pattern branch are removed from {@code myUnusedClauses}. {@code myOK} is set to {@code false} on
   * failure.
   *
   * @param clauses           the clauses to compile; must be non-empty. Every clause is indexed up to the pattern
   *                          count of {@code clauses.get(0)} (NOTE(review): equal pattern counts are assumed,
   *                          not checked here — confirm callers guarantee it)
   * @param argsStackSize     base offset added to pattern positions when recording argument indices
   * @param numberOfIntervals interval eliminations accumulated from enclosing branches; compared against
   *                          {@code myLevel} only when {@code myLevel} is non-null
   * @return the compiled tree, or {@code null} when matching at the current position is impossible (no matched
   *         constructors, or matching on the interval); {@code myOK} is set to {@code false} in those cases
   */
  private ElimTree clausesToElimTree(List<ExtElimClause> clauses, int argsStackSize, int numberOfIntervals) {
    // Entries pushed onto myContext below are pushed inside this scope.
    // NOTE(review): ContextSaver is presumably what restores myContext on exit, including early returns — confirm.
    try (Utils.ContextSaver ignored = new Utils.ContextSaver(myContext)) {
      // Phase 1: find the first column where some clause has a non-binding pattern.
      int index = 0;
      loop:
      for (; index < clauses.get(0).getPatterns().size(); index++) {
        for (ExtElimClause clause : clauses) {
          if (!(clause.getPatterns().get(index) instanceof BindingPattern)) {
            // Special case: the first clause binds a variable here while a later clause matches on left/right.
            // The clauses with non-binding patterns in this column are dropped and scanning moves on.
            // The first clause is always kept by the filter, so the list never becomes empty.
            if (clauses.get(0).getPatterns().get(index) instanceof BindingPattern && clause.getPatterns().get(index) instanceof ConstructorPattern) {
              Definition definition = clause.getPatterns().get(index).getDefinition();
              if (definition == Prelude.LEFT || definition == Prelude.RIGHT) {
                final int finalIndex = index;
                clauses = clauses.stream().filter(clauseData1 -> clauseData1.getPatterns().get(finalIndex) instanceof BindingPattern).collect(Collectors.toList());
                continue loop;
              }
            }
            break loop;
          }
        }
      }

      // If all patterns are variables
      // The first clause covers every remaining case, so it becomes a leaf and is marked as used.
      if (index == clauses.get(0).getPatterns().size()) {
        ExtElimClause clause = clauses.get(0);
        myUnusedClauses.remove(clause.index);
        List<Integer> indices = clause.argIndices;
        // Columns at or beyond numberOfFakeVars are recorded as argsStackSize + column (copying before mutation).
        if (index > clause.numberOfFakeVars) {
          indices = new ArrayList<>(indices);
          for (int i = clause.numberOfFakeVars; i < index; i++) {
            indices.add(argsStackSize + i);
          }
        }
        // The indices are omitted (null) when isConsequent(indices) holds.
        // NOTE(review): presumably that means the trivial consecutive mapping — confirm against isConsequent.
        return new LeafElimTree(index, isConsequent(indices) ? null : indices, clause.index);
      }

      // The leading binding patterns of the first clause form the context used for missing-clause reports.
      for (int i = 0; i < index; i++) {
        myContext.push(new Util.PatternClauseElem(clauses.get(0).getPatterns().get(i)));
      }

      // An empty pattern in any clause at this column yields a childless branch.
      // Otherwise, remember the first clause with a constructor pattern as the representative.
      ExtElimClause conClause = null;
      for (ExtElimClause clause : clauses) {
        Pattern pattern = clause.getPatterns().get(index);
        if (pattern instanceof EmptyPattern) {
          myUnusedClauses.remove(clause.index);
          return new BranchElimTree(index, false);
        }
        if (conClause == null && pattern instanceof ConstructorPattern) {
          conClause = clause;
        }
      }

      // Phase 2: determine the branch keys from the representative constructor pattern.
      assert conClause != null;
      ConstructorExpressionPattern someConPattern = (ConstructorExpressionPattern) conClause.getPatterns().get(index);
      Expression arrayElementsType = someConPattern.getArrayElementsType();
      Expression arrayLength = someConPattern.getArrayLength();
      // Non-null only for indexed data types and Path: the constructors that can match the concrete data call.
      List<ConCallExpression> conCalls = null;
      List<BranchKey> branchKeys;
      // Non-null only when matching on an ordinary data type; used for the interval/squash checks
      // and the higher-constructor check.
      DataDefinition dataType;
      if (someConPattern.getDefinition() instanceof Constructor constructor) {
        dataType = constructor.getDataType();
        if (dataType.hasIndexedConstructors() || dataType == Prelude.PATH) {
          DataCallExpression dataCall;
          // Fin constructors: the data call is rebuilt as Fin(suc n) from the pattern's first data-type argument.
          if (constructor == Prelude.FIN_ZERO || constructor == Prelude.FIN_SUC) {
            dataCall = Fin(Suc(((ConCallExpression) someConPattern.getDataExpression()).getDataTypeArguments().get(0).subst(conClause.substitution)));
          } else {
            dataCall = (DataCallExpression) someConPattern.getDataExpression().subst(conClause.substitution).getType();
          }
          conCalls = dataCall.getMatchedConstructors();
          // A null result means the matching constructors could not be determined, which is reported as an error.
          if (conCalls == null) {
            if (myErrorReporter != null) myErrorReporter.report(new ImpossibleEliminationError(dataCall, getClause(conClause.index, someConPattern), null, null, null, null, null));
            myOK = false;
            return null;
          }
          branchKeys = new ArrayList<>(conCalls.size());
          for (ConCallExpression conCall : conCalls) {
            branchKeys.add(conCall.getDefinition());
          }
        } else {
          branchKeys = new ArrayList<>(dataType.getConstructors());
        }
      } else if (someConPattern.getDefinition() == Prelude.EMPTY_ARRAY || someConPattern.getDefinition() == Prelude.ARRAY_CONS) {
        // Arrays: isArrayEmpty() may already decide which of the empty/cons branches is possible.
        // null means both are possible.
        Boolean empty = someConPattern.isArrayEmpty();
        branchKeys = new ArrayList<>(2);
        if (empty == null || empty.equals(Boolean.TRUE)) {
          branchKeys.add(new ArrayConstructor(true, arrayElementsType != null, arrayLength != null));
        }
        if (empty == null || empty.equals(Boolean.FALSE)) {
          branchKeys.add(new ArrayConstructor(false, arrayElementsType != null, arrayLength != null));
        }
        dataType = null;
      } else {
        // Single-constructor cases: class instances, tuples of a Sigma type, and idp.
        if (someConPattern.getDataExpression() instanceof ClassCallExpression classCall) {
          branchKeys = Collections.singletonList(new ClassConstructor(classCall.getDefinition(), classCall.getLevels(), classCall.getImplementedHere().keySet()));
        } else if (someConPattern.getDataExpression() instanceof SigmaExpression) {
          // Collect the positions of property parameters of the Sigma type
          // (allocated lazily; emptySet when there are none).
          Set<Integer> propertyIndices = Collections.emptySet();
          int i = 0;
          for (DependentLink link = ((SigmaExpression) someConPattern.getDataExpression()).getParameters(); link.hasNext(); link = link.getNext(), i++) {
            if (link.isProperty()) {
              if (propertyIndices.isEmpty()) propertyIndices = new HashSet<>();
              propertyIndices.add(i);
            }
          }
          branchKeys = Collections.singletonList(new TupleConstructor(someConPattern.getLength(), propertyIndices));
        } else {
          assert someConPattern.getDefinition() == Prelude.IDP;
          branchKeys = Collections.singletonList(new IdpConstructor());
        }
        dataType = null;
      }

      if (dataType == Prelude.INTERVAL) {
        if (myErrorReporter != null) myErrorReporter.report(new TypecheckingError("Pattern matching on the interval is not allowed here", getClause(conClause.index, someConPattern)));
        myOK = false;
        return null;
      }

      // Squashed/truncated data types: check the h-level constraints. Only done when an error reporter is present.
      if (dataType != null && dataType.isSquashed() && myErrorReporter != null) {
        if (myActualLevel != null && !Level.compare(myActualLevel, dataType.getSort().getHLevel().add(myActualLevelSub), CMP.LE, myEquations, getClause(conClause.index, someConPattern))) {
          myErrorReporter.report(new SquashedDataError(dataType, myActualLevel, myActualLevelSub, getClause(conClause.index, someConPattern)));
        }

        // For truncated data, matching is allowed if myLevel is small enough.
        // Otherwise the expected type must be a universe whose h-level is at most the data type's h-level.
        boolean ok = !dataType.isTruncated() || myLevel != null && myLevel <= dataType.getTruncatedLevel() + 1;
        if (!ok) {
          Expression type = myExpectedType.getType();
          if (type != null) {
            type = type.normalize(NormalizationMode.WHNF);
            UniverseExpression universe = type.cast(UniverseExpression.class);
            if (universe != null) {
              ok = Level.compare(universe.getSort().getHLevel(), dataType.getSort().getHLevel(), CMP.LE, myEquations, getClause(conClause.index, someConPattern));
            } else {
              // The type is not syntactically a universe.
              // Compare it against a universe with a fresh inference p-level and the data type's h-level.
              InferenceLevelVariable pl = new InferenceLevelVariable(LevelVariable.LvlType.PLVL, false, getClause(conClause.index, someConPattern));
              myEquations.addVariable(pl);
              ok = type.isLessOrEquals(new UniverseExpression(new Sort(new Level(pl), dataType.getSort().getHLevel())), myEquations, getClause(conClause.index, someConPattern));
            }
          }
        }
        if (!ok) {
          myErrorReporter.report(new TruncatedDataError(dataType, myExpectedType, getClause(conClause.index, someConPattern)));
          myOK = false;
        }
      }

      // When myLevel is set, drop constructors whose interval eliminations
      // (plus those already accumulated) would exceed it.
      if (myLevel != null && !branchKeys.isEmpty() && !(branchKeys.get(0) instanceof SingleConstructor)) {
        //noinspection ConstantConditions
        branchKeys.removeIf(key -> numberOfIntervals + (key.getBody() instanceof IntervalElim ? ((IntervalElim) key.getBody()).getNumberOfTotalElim() : 0) > myLevel);
      }

      // Phase 3: assign clauses to branch keys.
      // A clause with a variable in this column goes to every branch.
      // A constructor clause goes to its own key. Keys not present in branchKeys are stored in the map
      // but never visited by the branch loop below.
      boolean hasVars = false;
      Map<BranchKey, List<ExtElimClause>> branchKeyMap = new LinkedHashMap<>();
      for (ExtElimClause clause : clauses) {
        if (clause.getPatterns().get(index) instanceof BindingPattern) {
          hasVars = true;
          for (BranchKey key : branchKeys) {
            branchKeyMap.computeIfAbsent(key, k -> new ArrayList<>()).add(clause);
          }
        } else {
          Definition def = clause.getPatterns().get(index).getDefinition();
          BranchKey key = def instanceof Constructor ? (Constructor) def : def == Prelude.EMPTY_ARRAY || def == Prelude.ARRAY_CONS ? new ArrayConstructor((DConstructor) def, arrayElementsType != null, arrayLength != null) : null;
          if (key == null && !branchKeys.isEmpty() && branchKeys.get(0) instanceof SingleConstructor) {
            key = branchKeys.get(0);
          }
          if (key != null) {
            branchKeyMap.computeIfAbsent(key, k -> new ArrayList<>()).add(clause);
          }
        }
      }

      // Coverage: without a variable clause, every branch key not hit by some clause is reported as missing.
      // The report includes a context of fresh binding patterns for the key's parameters.
      if (myMode.checkCoverage() && !hasVars) {
        for (BranchKey key : branchKeys) {
          if (!branchKeyMap.containsKey(key)) {
            List<Util.ClauseElem> context = new ArrayList<>(myContext);
            context.add(Util.makeDataClauseElem(key, someConPattern));
            for (DependentLink link = key.getParameters(someConPattern); link.hasNext(); link = link.getNext()) {
              context.add(new Util.PatternClauseElem(new BindingPattern(link)));
            }
            addMissingClause(context, false);
          }
        }
      }

      // Phase 4: for each branch key with clauses, specialize the clauses to that key and recurse.
      BranchElimTree branchElimTree = new BranchElimTree(index, hasVars);
      for (BranchKey branchKey : branchKeys) {
        List<ExtElimClause> conClauseList = branchKeyMap.get(branchKey);
        if (conClauseList == null) {
          continue;
        }
        myContext.push(Util.makeDataClauseElem(branchKey, someConPattern));

        // Each clause in the list is replaced in place by its specialization for branchKey.
        for (int i = 0; i < conClauseList.size(); i++) {
          ExtElimClause clause = conClauseList.get(i);
          // Columns before `index` at or past numberOfFakeVars become argsStackSize + column.
          List<Integer> indices = new ArrayList<>(clause.argIndices);
          for (int j = clause.numberOfFakeVars; j < index; j++) {
            indices.add(argsStackSize + j);
          }
          int numberOfFakeVars = Math.max(clause.numberOfFakeVars - index, 0);

          List<ExpressionPattern> patterns = new ArrayList<>();
          List<ExpressionPattern> oldPatterns = clause.getPatterns();
          ExprSubstitution newSubstitution;
          if (oldPatterns.get(index) instanceof ConstructorExpressionPattern) {
            // Constructor pattern: splice its sub-patterns into the column.
            patterns.addAll(oldPatterns.get(index).getSubPatterns());
            newSubstitution = conClauseList.get(i).substitution;
          } else {
            // Variable pattern: replace it with fresh binding patterns for the constructor's parameters.
            // Substitute the variable with the corresponding constructor expression (substExpr).
            // NOTE(review): `arguments` is filled only after substExpr is built (loop over conParameters below).
            // This appears to rely on the expressions keeping a reference to this list instead of copying it
            // — confirm.
            Expression substExpr;
            DependentLink conParameters;
            List<Expression> arguments = new ArrayList<>();
            if (conCalls != null) {
              // Indexed data / Path: use the matched con-call's data-type arguments.
              ConCallExpression conCall = null;
              for (ConCallExpression conCall1 : conCalls) {
                if (conCall1.getDefinition() == branchKey) {
                  conCall = conCall1;
                  break;
                }
              }
              assert conCall != null;
              List<Expression> dataTypesArgs = conCall.getDataTypeArguments();
              substExpr = ConCallExpression.make(conCall.getDefinition(), conCall.getLevels(), dataTypesArgs, arguments);
              conParameters = DependentLink.Helper.subst(branchKey.getParameters(someConPattern), DependentLink.Helper.toSubstitution(conCall.getDefinition().getDataTypeParameters(), dataTypesArgs));
            } else if (branchKey instanceof SingleConstructor) {
              conParameters = someConPattern.getParameters();
              Expression someExpr = someConPattern.getDataExpression();
              if (someExpr instanceof ClassCallExpression classCall) {
                // Build a `new` expression that implements each not-yet-implemented field
                // with the next parameter binding, in field order.
                Map<ClassField, Expression> implementations = new LinkedHashMap<>();
                DependentLink link = conParameters;
                for (ClassField field : classCall.getDefinition().getNotImplementedFields()) {
                  if (!classCall.isImplemented(field)) {
                    implementations.put(field, new ReferenceExpression(link));
                    link = link.getNext();
                  }
                }
                substExpr = new NewExpression(null, new ClassCallExpression(classCall.getDefinition(), classCall.getLevels(), implementations, Sort.PROP, UniverseKind.NO_UNIVERSES));
              } else if (someExpr instanceof SigmaExpression) {
                substExpr = new TupleExpression(arguments, (SigmaExpression) someExpr);
                conParameters = DependentLink.Helper.copy(conParameters);
              } else if (someExpr instanceof FunCallExpression) {
                substExpr = someExpr;
              } else {
                throw new IllegalStateException();
              }
            } else if (branchKey instanceof Constructor constructor) {
              // Ordinary data type.
              // NOTE(review): dataTypesArgs is substituted with conClause.substitution, but the parameter
              // substitution below uses the raw someConPattern.getDataTypeArguments() — confirm this asymmetry
              // is intended.
              List<Expression> dataTypesArgs = new ArrayList<>();
              for (Expression dataTypeArg : someConPattern.getDataTypeArguments()) {
                dataTypesArgs.add(dataTypeArg.subst(conClause.substitution));
              }
              substExpr = ConCallExpression.make(constructor, someConPattern.getLevels(), dataTypesArgs, arguments);
              conParameters = DependentLink.Helper.subst(constructor.getParameters(), DependentLink.Helper.toSubstitution(constructor.getDataTypeParameters(), someConPattern.getDataTypeArguments()));
            } else if (branchKey instanceof ArrayConstructor) {
              // Array constructors receive the known length / element type (if any) before the bound parameters.
              if (arrayLength != null) {
                arguments.add(arrayLength);
              }
              if (arrayElementsType != null) {
                arguments.add(arrayElementsType);
              }
              substExpr = FunCallExpression.make(((ArrayConstructor) branchKey).getConstructor(), someConPattern.getLevels(), arguments);
              conParameters = branchKey.getParameters(someConPattern);
            } else {
              throw new IllegalStateException();
            }

            // Index bookkeeping for the matched variable:
            // - with no fake variables left, it is recorded as argsStackSize + index;
            // - otherwise one fake variable is consumed;
            // - each constructor parameter then adds a new fake variable.
            if (numberOfFakeVars == 0) {
              indices.add(argsStackSize + index);
            } else {
              numberOfFakeVars--;
            }
            for (DependentLink link = conParameters; link.hasNext(); link = link.getNext()) {
              patterns.add(new BindingPattern(link));
              arguments.add(new ReferenceExpression(link));
              numberOfFakeVars++;
            }

            newSubstitution = new ExprSubstitution(conClauseList.get(i).substitution);
            Binding patternBinding = ((BindingPattern) oldPatterns.get(index)).getBinding();
            newSubstitution.addSubst(patternBinding, substExpr);
            // Special case for the zero branch. If the pattern two columns further is a variable of a DArray
            // type whose length normalizes to the variable being matched, it is substituted with an empty array.
            // The element type comes from the class call, or falls back to a field call on its `this` binding.
            if (branchKey == Prelude.ZERO && index + 2 < oldPatterns.size()) {
              ExpressionPattern pattern = oldPatterns.get(index + 2);
              if (pattern instanceof BindingPattern) {
                Binding binding = pattern.getBinding();
                Expression type = binding.getTypeExpr().normalize(NormalizationMode.WHNF);
                if (type instanceof ClassCallExpression classCall && classCall.getDefinition() == Prelude.DEP_ARRAY) {
                  Expression length = classCall.getImplementationHere(Prelude.ARRAY_LENGTH, new ReferenceExpression(binding));
                  if (length != null) {
                    length = length.normalize(NormalizationMode.WHNF);
                    if (length instanceof ReferenceExpression && ((ReferenceExpression) length).getBinding() == patternBinding) {
                      Expression elementsType = classCall.getImplementationHere(Prelude.ARRAY_ELEMENTS_TYPE, new ReferenceExpression(binding));
                      newSubstitution.addSubst(binding, new FunCallExpression(Prelude.EMPTY_ARRAY, classCall.getLevels(), null, elementsType != null ? elementsType : FieldCallExpression.make(Prelude.ARRAY_ELEMENTS_TYPE, new ReferenceExpression(classCall.getThisBinding()))));
                    }
                  }
                }
              }
            }
          }

          // The remaining columns after `index` are kept unchanged.
          patterns.addAll(oldPatterns.subList(index + 1, oldPatterns.size()));
          conClauseList.set(i, new ExtElimClause(patterns, clause.getExpression(), clause.index, indices, numberOfFakeVars, newSubstitution));
        }

        // The stack offset grows by `index`, plus one when this column had variable clauses.
        // The interval count is only tracked when myLevel is set.
        ElimTree elimTree = clausesToElimTree(conClauseList, argsStackSize + index + (hasVars ? 1 : 0), myLevel == null ? 0 : numberOfIntervals + (branchKey.getBody() instanceof IntervalElim ? ((IntervalElim) branchKey.getBody()).getNumberOfTotalElim() : 0));
        if (elimTree == null) {
          myOK = false;
        } else {
          // Fin constructors are stored in the tree under the corresponding Nat constructors.
          branchElimTree.addChild(branchKey == Prelude.FIN_ZERO ? Prelude.ZERO : branchKey == Prelude.FIN_SUC ? Prelude.SUC : branchKey, elimTree);
        }

        myContext.pop();

        // If we match on a variable and the constructor has conditions,
        // we need to check that it isn't mapped to a clause with a variable
        // unless constructors to which the current one evaluates is also mapped to the same clause.
        // We need this because condition checker doesn't check clauses with variables.
        // NOTE(review): elimTree may be null here (failed recursion above).
        // Confirm collectClauseIndices tolerates a null tree.
        if (hasVars && dataType != null && branchKey instanceof Constructor && branchKey.getBody() != null && myErrorReporter != null) {
          // Clause indices reached in this branch that came from variable patterns in this column.
          Set<Integer> indices = new HashSet<>();
          collectClauseIndices(elimTree, indices);
          for (ExtElimClause clause : clauses) {
            if (!(clause.getPatterns().get(index) instanceof BindingPattern)) {
              indices.remove(clause.index);
            }
          }

          if (!indices.isEmpty()) {
            // Clause indices reached through the constructors that this constructor's body refers to.
            // Only children already added (earlier in branchKeys) are found.
            Set<Integer> depIndices = new HashSet<>();
            Set<Constructor> depConstructors = new HashSet<>();
            collectConstructors(dataType, branchKey.getBody(), depConstructors);
            boolean ok = true;
            for (Constructor depConstructor : depConstructors) {
              ElimTree depElimTree = branchElimTree.getChild(depConstructor);
              if (depElimTree == null) {
                ok = false;
              } else {
                collectClauseIndices(depElimTree, depIndices);
              }
            }

            // Accept only if both sides reduce to the same single clause (or the dependents reach no clause).
            if (ok && !depIndices.isEmpty()) {
              if (indices.size() > 1) {
                ok = false;
              } else {
                ok = depIndices.size() == 1 && depIndices.iterator().next().equals(indices.iterator().next());
              }
            }

            if (!ok) {
              // Point at the pattern in the offending clause when concrete clauses are available.
              Concrete.SourceNode sourceNode;
              if (myClauses != null) {
                Concrete.FunctionClause functionClause = myClauses.get(indices.iterator().next());
                sourceNode = index < functionClause.getPatterns().size() ? functionClause.getPatterns().get(index) : functionClause;
              } else {
                sourceNode = mySourceNode;
              }
              myErrorReporter.report(new HigherConstructorMatchingError((Constructor) branchKey, sourceNode));
            }
          }
        }
      }

      return branchElimTree;
    }
  }