public static Query createJoinQuery()

in lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java [142:400]


  /**
   * Builds a query-time join over numeric fields: runs {@code fromQuery} on {@code fromSearcher},
   * gathers the numeric doc values of {@code fromField} from every matching document, and returns
   * a point-in-set query over {@code toField} that is fed with the gathered values.
   *
   * <p>Doc values are read as raw {@code long}s. For {@link Integer} the low 32 bits are used, for
   * {@link Float} the low 32 bits are interpreted with {@link Float#intBitsToFloat(int)}, and for
   * {@link Double} the value is interpreted with {@link Double#longBitsToDouble(long)}.
   *
   * @param fromField numeric doc-values field read from the documents matching {@code fromQuery}
   * @param multipleValuesPerDocument if {@code true} the field is read as sorted-numeric doc
   *     values and every value of a document is collected; otherwise it is read as single-valued
   *     numeric doc values
   * @param toField field handed to the returned point-in-set query
   * @param numericType one of {@code Integer.class}, {@code Long.class}, {@code Float.class} or
   *     {@code Double.class}; selects the point encoding and its width in bytes
   * @param fromQuery query selecting the documents whose values are joined on
   * @param fromSearcher searcher used to execute {@code fromQuery}
   * @param scoreMode how the scores of from-documents sharing a join value are combined; with
   *     {@code ScoreMode.None} no scores are collected
   * @return a {@code PointInSetIncludingScoreQuery} when scores are needed, otherwise a {@code
   *     PointInSetQuery}
   * @throws IllegalArgumentException if {@code numericType} is not one of the supported classes;
   *     this is checked before the from-side search is executed
   * @throws IOException propagated from executing {@code fromQuery} or reading doc values
   */
  public static Query createJoinQuery(
      String fromField,
      boolean multipleValuesPerDocument,
      String toField,
      Class<? extends Number> numericType,
      Query fromQuery,
      IndexSearcher fromSearcher,
      ScoreMode scoreMode)
      throws IOException {
    // Resolve the numeric type first so that an unsupported type is rejected before the
    // from-side search does any work.
    final int bytesPerDim;
    final boolean floatingPoint;
    final java.util.function.ObjLongConsumer<byte[]> encoder;
    if (Integer.class.equals(numericType)) {
      bytesPerDim = Integer.BYTES;
      floatingPoint = false;
      encoder = (dest, value) -> IntPoint.encodeDimension((int) value, dest, 0);
    } else if (Long.class.equals(numericType)) {
      bytesPerDim = Long.BYTES;
      floatingPoint = false;
      encoder = (dest, value) -> LongPoint.encodeDimension(value, dest, 0);
    } else if (Float.class.equals(numericType)) {
      bytesPerDim = Float.BYTES;
      floatingPoint = true;
      encoder =
          (dest, value) ->
              FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), dest, 0);
    } else if (Double.class.equals(numericType)) {
      bytesPerDim = Double.BYTES;
      floatingPoint = true;
      encoder =
          (dest, value) -> DoublePoint.encodeDimension(Double.longBitsToDouble(value), dest, 0);
    } else {
      throw new IllegalArgumentException(
          "unsupported numeric type, only Integer, Long, Float and Double are supported");
    }

    LongHashSet joinValues = new LongHashSet();
    LongFloatHashMap aggregatedScores = new LongFloatHashMap();
    LongIntHashMap occurrences = new LongIntHashMap();
    boolean needsScore = scoreMode != ScoreMode.None;

    // Folds the score of one from-document into the running aggregate of its join value.
    LongFloatProcedure scoreAggregator;
    if (scoreMode == ScoreMode.Max) {
      scoreAggregator =
          (key, score) -> {
            int index = aggregatedScores.indexOf(key);
            if (index < 0) {
              // First score seen for this join value.
              aggregatedScores.indexInsert(index, key, score);
            } else {
              float currentScore = aggregatedScores.indexGet(index);
              aggregatedScores.indexReplace(index, Math.max(currentScore, score));
            }
          };
    } else if (scoreMode == ScoreMode.Min) {
      scoreAggregator =
          (key, score) -> {
            int index = aggregatedScores.indexOf(key);
            if (index < 0) {
              // First score seen for this join value.
              aggregatedScores.indexInsert(index, key, score);
            } else {
              float currentScore = aggregatedScores.indexGet(index);
              aggregatedScores.indexReplace(index, Math.min(currentScore, score));
            }
          };
    } else if (scoreMode == ScoreMode.Total) {
      scoreAggregator = aggregatedScores::addTo;
    } else if (scoreMode == ScoreMode.Avg) {
      // Keep the sum and the number of contributions; the division happens in joinScorer.
      scoreAggregator =
          (key, score) -> {
            aggregatedScores.addTo(key, score);
            occurrences.addTo(key, 1);
          };
    } else {
      // With ScoreMode.None the collectors never call the aggregator (needsScore is false).
      scoreAggregator =
          (_, _) -> {
            throw new UnsupportedOperationException();
          };
    }

    // Maps a join value to the final score attached to it in the resulting query.
    LongFloatFunction joinScorer;
    if (scoreMode == ScoreMode.Avg) {
      joinScorer =
          (joinValue) -> {
            float aggregatedScore = aggregatedScores.get(joinValue);
            int occurrence = occurrences.get(joinValue);
            return aggregatedScore / occurrence;
          };
    } else {
      joinScorer = aggregatedScores::get;
    }

    Collector collector;
    if (multipleValuesPerDocument) {
      collector =
          new SimpleCollector() {

            SortedNumericDocValues sortedNumericDocValues;
            Scorable scorer;

            @Override
            public void collect(int doc) throws IOException {
              // Documents without any value for fromField contribute nothing.
              if (sortedNumericDocValues.advanceExact(doc)) {
                for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) {
                  long value = sortedNumericDocValues.nextValue();
                  joinValues.add(value);
                  if (needsScore) {
                    scoreAggregator.apply(value, scorer.score());
                  }
                }
              }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
              sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
              this.scorer = scorer;
            }

            @Override
            public org.apache.lucene.search.ScoreMode scoreMode() {
              return needsScore
                  ? org.apache.lucene.search.ScoreMode.COMPLETE
                  : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
            }
          };
    } else {
      collector =
          new SimpleCollector() {

            NumericDocValues numericDocValues;
            Scorable scorer;
            private int lastDocID = -1;

            // Only evaluated when assertions are enabled (it is called from an assert).
            private boolean docsInOrder(int docID) {
              if (docID < lastDocID) {
                throw new AssertionError(
                    "docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
              }
              lastDocID = docID;
              return true;
            }

            @Override
            public void collect(int doc) throws IOException {
              assert docsInOrder(doc);
              // A document without a value for fromField is collected with join value 0.
              long value = 0;
              if (numericDocValues.advanceExact(doc)) {
                value = numericDocValues.longValue();
              }
              joinValues.add(value);
              if (needsScore) {
                scoreAggregator.apply(value, scorer.score());
              }
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
              numericDocValues = DocValues.getNumeric(context.reader(), fromField);
              // Doc ids restart for every segment.
              lastDocID = -1;
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
              this.scorer = scorer;
            }

            @Override
            public org.apache.lucene.search.ScoreMode scoreMode() {
              return needsScore
                  ? org.apache.lucene.search.ScoreMode.COMPLETE
                  : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;
            }
          };
    }
    fromSearcher.search(fromQuery, collector);

    LongArrayList joinValuesList = new LongArrayList(joinValues.size());
    joinValuesList.addAll(joinValues);
    final int numJoinValues = joinValuesList.size();
    // The stream below has to hand out values in ascending order of their encoded bytes. For
    // Integer and Long that is plain signed order. For Float and Double the collected longs are
    // raw IEEE-754 bits, whose signed order is reversed among negative numbers, so negative
    // entries are temporarily mapped to a numerically ordered form around the sort. The mapping
    // is its own inverse, so the values (and therefore the score-map keys) are unchanged
    // afterwards.
    if (floatingPoint) {
      flipNegativeMagnitudeBits(joinValuesList.buffer, numJoinValues);
    }
    Arrays.sort(joinValuesList.buffer, 0, numJoinValues);
    if (floatingPoint) {
      flipNegativeMagnitudeBits(joinValuesList.buffer, numJoinValues);
    }
    Iterator<LongCursor> iterator = joinValuesList.iterator();

    // A single scratch buffer is re-filled and returned for every value of the stream.
    final BytesRef encoded = new BytesRef();
    encoded.bytes = new byte[bytesPerDim];
    encoded.length = bytesPerDim;

    final PointInSetIncludingScoreQuery.Stream stream =
        new PointInSetIncludingScoreQuery.Stream() {
          @Override
          public BytesRef next() {
            if (iterator.hasNext() == false) {
              return null;
            }
            long joinValue = iterator.next().value;
            encoder.accept(encoded.bytes, joinValue);
            if (needsScore) {
              score = joinScorer.apply(joinValue);
            }
            return encoded;
          }
        };

    if (needsScore) {
      return new PointInSetIncludingScoreQuery(
          scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {

        @Override
        protected String toString(byte[] value) {
          return toString.apply(value, numericType);
        }
      };
    } else {
      return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
        @Override
        protected String toString(byte[] value) {
          return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
        }
      };
    }
  }

  /**
   * Flips the 63 non-sign bits of every negative entry in {@code values[0..length)}; non-negative
   * entries are left untouched. Applied to raw IEEE-754 bits (a double's 64 bits, or a float's 32
   * bits sign-extended to a long) this makes signed {@code long} order agree with numeric order.
   * Applying it twice restores the original contents.
   */
  private static void flipNegativeMagnitudeBits(long[] values, int length) {
    for (int i = 0; i < length; i++) {
      final long bits = values[i];
      values[i] = bits ^ ((bits >> 63) & Long.MAX_VALUE);
    }
  }