private static Expression propagate()

in x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java [196:358]


    /**
     * Simplifies a disjunction by removing {@code Equals} terms that a sibling term over the same left-hand
     * expression already covers, widening that sibling where needed.
     * <p>
     * Only terms with a foldable right-hand side (and ranges) take part; every other term is carried over unchanged.
     * Each foldable {@code Equals} is checked, in this order, against the {@code NotEquals} terms, the ranges and the
     * remaining binary comparisons:
     * <ul>
     *   <li>{@code a = 2 OR a != 2} makes the whole disjunction {@code TRUE};</li>
     *   <li>{@code a = 2 OR a != 5} keeps only {@code a != 5};</li>
     *   <li>{@code a = 2 OR 2 < a < 5} becomes {@code 2 <= a < 5} (same for the upper bound);</li>
     *   <li>{@code a = 2 OR a > 2} becomes {@code a >= 2}, {@code a = 3 OR a > 2} keeps only {@code a > 2}
     *       (mirrored for {@code <} / {@code <=}).</li>
     * </ul>
     *
     * @param or  the disjunction to simplify
     * @param ctx optimizer context; its fold context is used to fold the right-hand sides
     * @return {@code TRUE}, a rebuilt disjunction, or {@code or} itself when nothing could be simplified
     */
    private static Expression propagate(Or or, LogicalOptimizerContext ctx) {
        List<Expression> untouched = new ArrayList<>(); // terms this rule does not reason about
        List<Equals> foldableEquals = new ArrayList<>(); // Equals with a foldable right term
        List<NotEquals> foldableNotEquals = new ArrayList<>(); // NotEquals with a foldable right term
        List<Range> rangeTerms = new ArrayList<>();
        List<BinaryComparison> boundTerms = new ArrayList<>(); // other BinaryComparison with a foldable right term (the limit)

        // Bucket the OR-ed terms by kind. Equals/NotEquals with a non-foldable right side fail every guarded test
        // below and therefore end up in the untouched bucket.
        for (Expression term : Predicates.splitOr(or)) {
            if (term instanceof Equals eq && eq.right().foldable()) {
                foldableEquals.add(eq);
            } else if (term instanceof NotEquals neq && neq.right().foldable()) {
                foldableNotEquals.add(neq);
            } else if (term instanceof Range range) {
                rangeTerms.add(range);
            } else if (term instanceof BinaryComparison comparison && comparison.right().foldable()) {
                boundTerms.add(comparison);
            } else {
                untouched.add(term);
            }
        }

        boolean changed = false; // did any simplification take place?

        // Check every Equals against the other buckets; the first sibling that covers it wins.
        Iterator<Equals> pending = foldableEquals.iterator();
        while (pending.hasNext()) {
            Equals eq = pending.next();
            Object eqValue = eq.right().fold(ctx.foldCtx());
            boolean subsumed = false; // is this Equals made redundant by a sibling term?

            // Equals OR NotEquals
            for (NotEquals neq : foldableNotEquals) {
                if (eq.left().semanticEquals(neq.left()) == false) {
                    continue;
                }
                Integer cmp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx()));
                if (cmp == null) {
                    continue; // no comparison result, nothing to conclude from this pair
                }
                if (cmp == 0) {
                    // a = 2 OR a != 2 -> TRUE
                    // NOTE(review): if `a` can be null, the original disjunction presumably evaluates to null rather
                    // than true - confirm this fold is acceptable under the language's null semantics.
                    return TRUE;
                }
                // a = 2 OR a != 5 -> a != 5
                subsumed = true;
                break;
            }

            // Equals OR Range (skipped once the Equals is already known to be redundant)
            // NB (kept from the original author): this loop is probably dead code. There is no syntax for ranges, so
            // the parser never produces them, and this rule only rewrites ranges that already exist, so the list
            // of ranges should always be empty here.
            for (int i = 0; subsumed == false && i < rangeTerms.size(); i++) { // index loop: entries may be replaced
                Range range = rangeTerms.get(i);
                if (eq.left().semanticEquals(range.value()) == false) {
                    continue;
                }

                Integer lowerCmp = null;
                if (range.lower().foldable()) {
                    lowerCmp = BinaryComparison.compare(eqValue, range.lower().fold(ctx.foldCtx()));
                }
                Integer upperCmp = null;
                if (range.upper().foldable()) {
                    upperCmp = BinaryComparison.compare(eqValue, range.upper().fold(ctx.foldCtx()));
                }

                if (lowerCmp != null && lowerCmp == 0) {
                    // a = 2 OR 2 < a < ? -> 2 <= a < ? ; with an inclusive lower bound the Equals is simply superfluous
                    if (range.includeLower() == false) {
                        Range closedLower = new Range(
                            range.source(),
                            range.value(),
                            range.lower(),
                            true,
                            range.upper(),
                            range.includeUpper(),
                            range.zoneId()
                        );
                        rangeTerms.set(i, closedLower);
                    }
                    subsumed = true;
                } else if (upperCmp != null && upperCmp == 0) {
                    // a = 2 OR ? < a < 2 -> ? < a <= 2 ; with an inclusive upper bound the Equals is simply superfluous
                    if (range.includeUpper() == false) {
                        Range closedUpper = new Range(
                            range.source(),
                            range.value(),
                            range.lower(),
                            range.includeLower(),
                            range.upper(),
                            true,
                            range.zoneId()
                        );
                        rangeTerms.set(i, closedUpper);
                    }
                    subsumed = true;
                } else if (lowerCmp != null && upperCmp != null && lowerCmp > 0 && upperCmp < 0) {
                    // a = 2 OR 1 < a < 3 -> the value lies strictly inside the range
                    subsumed = true;
                }
            }

            // Equals OR Inequality (skipped once the Equals is already known to be redundant)
            for (int i = 0; subsumed == false && i < boundTerms.size(); i++) { // index loop: entries may be replaced
                BinaryComparison bound = boundTerms.get(i);
                if (eq.left().semanticEquals(bound.left()) == false) {
                    continue;
                }
                Integer cmp = BinaryComparison.compare(eqValue, bound.right().fold(ctx.foldCtx()));
                if (cmp == null) {
                    continue;
                }

                if (bound instanceof GreaterThan || bound instanceof GreaterThanOrEqual) {
                    // cmp < 0: a = 1 OR a > 2 -> nothing to do for this pair
                    if (cmp >= 0) {
                        if (cmp == 0 && bound instanceof GreaterThan) { // a = 2 OR a > 2 -> a >= 2
                            boundTerms.set(i, new GreaterThanOrEqual(bound.source(), bound.left(), bound.right(), bound.zoneId()));
                        }
                        // otherwise: a = 3 OR a > 2 -> a > 2 ; a = 2 OR a >= 2 -> a >= 2
                        subsumed = true;
                    }
                } else if (bound instanceof LessThan || bound instanceof LessThanOrEqual) {
                    // cmp > 0: a = 2 OR a < 1 -> nothing to do for this pair
                    if (cmp <= 0) {
                        if (cmp == 0 && bound instanceof LessThan) { // a = 2 OR a < 2 -> a <= 2
                            boundTerms.set(i, new LessThanOrEqual(bound.source(), bound.left(), bound.right(), bound.zoneId()));
                        }
                        // otherwise: a = 2 OR a < 3 -> a < 3 ; a = 2 OR a <= 2 -> a <= 2
                        subsumed = true;
                    }
                }
            }

            if (subsumed) {
                pending.remove();
                changed = true;
            }
        }

        if (changed == false) {
            return or;
        }
        return Predicates.combineOr(CollectionUtils.combine(untouched, foldableEquals, foldableNotEquals, boundTerms, rangeTerms));
    }