in lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java [142:400]
/**
 * Builds a query-time join over numeric fields.
 *
 * <p>{@code fromQuery} is executed against {@code fromSearcher}; for every hit the numeric doc
 * values of {@code fromField} are collected (and, unless {@code scoreMode} is {@link
 * ScoreMode#None}, the hit's score is aggregated per collected value). The distinct values are then
 * sorted, encoded as points of {@code numericType} and handed to a point-set query on {@code
 * toField}.
 *
 * @param fromField doc-values field read for every document matching {@code fromQuery}
 * @param multipleValuesPerDocument when {@code true} the values are read through {@link
 *     DocValues#getSortedNumeric}, otherwise through {@link DocValues#getNumeric}; in the latter
 *     case a hit that has no value for {@code fromField} is collected with the join value {@code 0}
 * @param toField point field the returned query matches against
 * @param numericType one of {@code Integer}, {@code Long}, {@code Float} or {@code Double}; selects
 *     the point encoding. For {@code Float} and {@code Double} the collected doc value is
 *     reinterpreted with {@link Float#intBitsToFloat(int)} / {@link Double#longBitsToDouble(long)}
 * @param fromQuery query selecting the documents to join from
 * @param fromSearcher searcher used to execute {@code fromQuery}
 * @param scoreMode how the scores of several from-documents sharing one join value are combined
 *     ({@code Max}, {@code Min}, {@code Total}, {@code Avg}) or {@code None} to skip scoring
 * @return a {@link PointInSetIncludingScoreQuery} when scores are requested, otherwise a plain
 *     {@link PointInSetQuery}
 * @throws IllegalArgumentException if {@code numericType} is not one of the four supported types;
 *     this is checked before {@code fromQuery} is executed
 * @throws IOException if executing {@code fromQuery} or reading doc values fails
 */
public static Query createJoinQuery(
    String fromField,
    boolean multipleValuesPerDocument,
    String toField,
    Class<? extends Number> numericType,
    Query fromQuery,
    IndexSearcher fromSearcher,
    ScoreMode scoreMode)
    throws IOException {
  // Resolve the point encoding first so that an unsupported numericType is rejected before the
  // (potentially expensive) from-side search is executed.
  final BytesRef encoded = new BytesRef();
  final int bytesPerDim;
  // Writes the point encoding of a collected join value into encoded.bytes.
  final java.util.function.LongConsumer encoder;
  if (Integer.class.equals(numericType)) {
    bytesPerDim = Integer.BYTES;
    encoder = joinValue -> IntPoint.encodeDimension((int) joinValue, encoded.bytes, 0);
  } else if (Long.class.equals(numericType)) {
    bytesPerDim = Long.BYTES;
    encoder = joinValue -> LongPoint.encodeDimension(joinValue, encoded.bytes, 0);
  } else if (Float.class.equals(numericType)) {
    bytesPerDim = Float.BYTES;
    encoder =
        joinValue ->
            FloatPoint.encodeDimension(Float.intBitsToFloat((int) joinValue), encoded.bytes, 0);
  } else if (Double.class.equals(numericType)) {
    bytesPerDim = Double.BYTES;
    encoder =
        joinValue ->
            DoublePoint.encodeDimension(Double.longBitsToDouble(joinValue), encoded.bytes, 0);
  } else {
    throw new IllegalArgumentException(
        "unsupported numeric type, only Integer, Long, Float and Double are supported");
  }
  // A single scratch buffer is reused for every value emitted by the stream below.
  encoded.bytes = new byte[bytesPerDim];
  encoded.length = bytesPerDim;

  LongHashSet joinValues = new LongHashSet();
  LongFloatHashMap aggregatedScores = new LongFloatHashMap();
  LongIntHashMap occurrences = new LongIntHashMap();
  boolean needsScore = scoreMode != ScoreMode.None;

  // Folds the score of one from-document into the running aggregate of its join value.
  LongFloatProcedure scoreAggregator;
  if (scoreMode == ScoreMode.Max) {
    scoreAggregator =
        (key, score) -> {
          int index = aggregatedScores.indexOf(key);
          if (index < 0) {
            aggregatedScores.indexInsert(index, key, score);
          } else {
            float currentScore = aggregatedScores.indexGet(index);
            aggregatedScores.indexReplace(index, Math.max(currentScore, score));
          }
        };
  } else if (scoreMode == ScoreMode.Min) {
    scoreAggregator =
        (key, score) -> {
          int index = aggregatedScores.indexOf(key);
          if (index < 0) {
            aggregatedScores.indexInsert(index, key, score);
          } else {
            float currentScore = aggregatedScores.indexGet(index);
            aggregatedScores.indexReplace(index, Math.min(currentScore, score));
          }
        };
  } else if (scoreMode == ScoreMode.Total) {
    scoreAggregator = aggregatedScores::addTo;
  } else if (scoreMode == ScoreMode.Avg) {
    // Avg keeps the running sum plus a hit count; the division happens in joinScorer.
    scoreAggregator =
        (key, score) -> {
          aggregatedScores.addTo(key, score);
          occurrences.addTo(key, 1);
        };
  } else {
    // ScoreMode.None: never invoked because every call site is guarded by needsScore.
    scoreAggregator =
        (_, _) -> {
          throw new UnsupportedOperationException();
        };
  }

  // Maps a join value to the final score exposed through the stream.
  LongFloatFunction joinScorer;
  if (scoreMode == ScoreMode.Avg) {
    joinScorer =
        (joinValue) -> {
          float aggregatedScore = aggregatedScores.get(joinValue);
          int occurrence = occurrences.get(joinValue);
          return aggregatedScore / occurrence;
        };
  } else {
    joinScorer = aggregatedScores::get;
  }

  // Both collectors report the same score mode, so compute it once.
  final org.apache.lucene.search.ScoreMode collectorScoreMode =
      needsScore
          ? org.apache.lucene.search.ScoreMode.COMPLETE
          : org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;

  Collector collector;
  if (multipleValuesPerDocument) {
    collector =
        new SimpleCollector() {
          SortedNumericDocValues sortedNumericDocValues;
          Scorable scorer;

          @Override
          public void collect(int doc) throws IOException {
            if (sortedNumericDocValues.advanceExact(doc)) {
              // The score belongs to the document, not to the individual value, so read it once
              // instead of once per value.
              float docScore = needsScore ? scorer.score() : 0f;
              for (int i = 0, count = sortedNumericDocValues.docValueCount(); i < count; i++) {
                long value = sortedNumericDocValues.nextValue();
                joinValues.add(value);
                if (needsScore) {
                  scoreAggregator.apply(value, docScore);
                }
              }
            }
          }

          @Override
          protected void doSetNextReader(LeafReaderContext context) throws IOException {
            sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField);
          }

          @Override
          public void setScorer(Scorable scorer) throws IOException {
            this.scorer = scorer;
          }

          @Override
          public org.apache.lucene.search.ScoreMode scoreMode() {
            return collectorScoreMode;
          }
        };
  } else {
    collector =
        new SimpleCollector() {
          NumericDocValues numericDocValues;
          Scorable scorer;
          private int lastDocID = -1;

          // Assertion helper: advanceExact below relies on doc ids arriving in increasing order
          // within a segment.
          private boolean docsInOrder(int docID) {
            if (docID < lastDocID) {
              throw new AssertionError(
                  "docs out of order: lastDocID=" + lastDocID + " vs docID=" + docID);
            }
            lastDocID = docID;
            return true;
          }

          @Override
          public void collect(int doc) throws IOException {
            assert docsInOrder(doc);
            // A hit without a value for fromField is collected with the join value 0.
            long value = 0;
            if (numericDocValues.advanceExact(doc)) {
              value = numericDocValues.longValue();
            }
            joinValues.add(value);
            if (needsScore) {
              scoreAggregator.apply(value, scorer.score());
            }
          }

          @Override
          protected void doSetNextReader(LeafReaderContext context) throws IOException {
            numericDocValues = DocValues.getNumeric(context.reader(), fromField);
            // Doc ids restart for every segment.
            lastDocID = -1;
          }

          @Override
          public void setScorer(Scorable scorer) throws IOException {
            this.scorer = scorer;
          }

          @Override
          public org.apache.lucene.search.ScoreMode scoreMode() {
            return collectorScoreMode;
          }
        };
  }
  fromSearcher.search(fromQuery, collector);

  // Copy the distinct join values into a list and sort them (by their long value) before they
  // are streamed into the point-set query.
  LongArrayList joinValuesList = new LongArrayList(joinValues.size());
  joinValuesList.addAll(joinValues);
  Arrays.sort(joinValuesList.buffer, 0, joinValuesList.size());
  Iterator<LongCursor> iterator = joinValuesList.iterator();

  // One stream for all numeric types: only the encoding step differs, and that was selected
  // above. Each call overwrites the shared 'encoded' buffer and, when scoring, publishes the
  // aggregated score of the value through the stream's 'score' field.
  final PointInSetIncludingScoreQuery.Stream stream =
      new PointInSetIncludingScoreQuery.Stream() {
        @Override
        public BytesRef next() {
          if (!iterator.hasNext()) {
            return null;
          }
          long joinValue = iterator.next().value;
          encoder.accept(joinValue);
          if (needsScore) {
            score = joinScorer.apply(joinValue);
          }
          return encoded;
        }
      };

  if (needsScore) {
    return new PointInSetIncludingScoreQuery(
        scoreMode, fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) {
      @Override
      protected String toString(byte[] value) {
        return toString.apply(value, numericType);
      }
    };
  } else {
    return new PointInSetQuery(toField, 1, bytesPerDim, stream) {
      @Override
      protected String toString(byte[] value) {
        return PointInSetIncludingScoreQuery.toString.apply(value, numericType);
      }
    };
  }
}