in x-pack/plugin/graph/src/main/java/org/elasticsearch/xpack/graph/action/TransportGraphExploreAction.java [179:551]
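// Runs one additional "hop" of the graph exploration: the vertices found in the previous hop seed
// a sampled aggregation query over the current hop's fields, and the response handler either
// recurses for the next hop or completes the response when no further work remains.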
synchronized void expand(boolean timedOut) {
Map<String, Set<Vertex>> lastHopFindings = hopFindings.get(currentHopNumber);
if ((currentHopNumber >= (request.getHopNumbers() - 1)) || (lastHopFindings == null) || (lastHopFindings.size() == 0)) {
// Either we gathered no leads from the last hop or we have
// reached the final hop
listener.onResponse(buildResponse(timedOut));
return;
}
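// Advance the hop counter: the last hop's vertex findings seed the query below, while the
// current hop's guiding query and vertex requests define what to discover next.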
Hop lastHop = request.getHop(currentHopNumber);
currentHopNumber++;
Hop currentHop = request.getHop(currentHopNumber);
final SearchRequest searchRequest = new SearchRequest(request.indices()).indicesOptions(request.indicesOptions());
if (request.routing() != null) {
searchRequest.routing(request.routing());
}
BoolQueryBuilder rootBool = QueryBuilders.boolQuery();
// A single sample pool of docs is built at the root of the aggs tree.
// For quality's sake it might have made more sense to sample top docs
// for each of the terms from the previous hop (e.g. an initial query for "beatles"
// may have separate doc-sample pools for significant root terms "john", "paul", "yoko" etc)
// but I found this dramatically slowed down execution - each pool typically had different docs, each with
// non-overlapping sets of terms whose frequencies then had to be looked up for the significant terms calculations.
// A common sample pool reduces the specialization that can be given to each root term but
// ultimately is much faster to run because of the shared vocabulary in a single sample set.
AggregationBuilder sampleAgg = null;
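// Use a diversified sampler when the client has supplied a diversity field, otherwise a plain sampler.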
if (request.sampleDiversityField() != null) {
DiversifiedAggregationBuilder diversifiedSampleAgg = AggregationBuilders.diversifiedSampler("sample")
.shardSize(request.sampleSize());
diversifiedSampleAgg.field(request.sampleDiversityField());
diversifiedSampleAgg.maxDocsPerValue(request.maxDocsPerDiversityValue());
sampleAgg = diversifiedSampleAgg;
} else {
sampleAgg = AggregationBuilders.sampler("sample").shardSize(request.sampleSize());
}
// Add any user-supplied criteria to the root query as a must clause
rootBool.must(currentHop.guidingQuery());
// Build a MUST clause that matches either
// a) include clauses supplied by the client or
// b) vertex terms from the previous hop.
BoolQueryBuilder sourceTermsOrClause = QueryBuilders.boolQuery();
addUserDefinedIncludesToQuery(currentHop, sourceTermsOrClause);
addBigOrClause(lastHopFindings, sourceTermsOrClause);
rootBool.must(sourceTermsOrClause);
// Now build the agg tree that will channel the content:
// the base agg is a terms agg for the terms from the last wave (one per field),
// and under each is a sig_terms agg to find the next candidates (again, one per field)...
for (int fieldNum = 0; fieldNum < lastHop.getNumberVertexRequests(); fieldNum++) {
VertexRequest lastVr = lastHop.getVertexRequest(fieldNum);
Set<Vertex> lastWaveVerticesForField = lastHopFindings.get(lastVr.fieldName());
if (lastWaveVerticesForField == null) {
continue;
}
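// Collect the exact terms found for this field in the last hop; they are used as the include
// filter on the terms agg below so that only those terms are aggregated.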
SortedSet<BytesRef> terms = new TreeSet<>();
for (Vertex v : lastWaveVerticesForField) {
terms.add(new BytesRef(v.getTerm()));
}
TermsAggregationBuilder lastWaveTermsAgg = AggregationBuilders.terms("field" + fieldNum)
.includeExclude(new IncludeExclude(null, null, terms, null))
.shardMinDocCount(1)
.field(lastVr.fieldName())
.minDocCount(1)
// Map execution mode is used because the Sampler agg keeps us
// focused on smaller sets of high-quality docs and we therefore
// examine smaller volumes of terms
.executionHint("map")
.size(terms.size());
sampleAgg.subAggregation(lastWaveTermsAgg);
for (int f = 0; f < currentHop.getNumberVertexRequests(); f++) {
VertexRequest vr = currentHop.getVertexRequest(f);
int size = vr.size();
if (vr.fieldName().equals(lastVr.fieldName())) {
// We have the potential for self-loops as we are looking at the same field, so add 1 to the requested size
// because we need to eliminate fieldA:termA -> fieldA:termA links that are likely to be in the results.
size++;
}
if (request.useSignificance()) {
SignificantTermsAggregationBuilder nextWaveSigTerms = AggregationBuilders.significantTerms("field" + f)
.field(vr.fieldName())
.minDocCount(vr.minDocCount())
.shardMinDocCount(vr.shardMinDocCount())
.executionHint("map")
.size(size);
// nextWaveSigTerms.significanceHeuristic(new PercentageScore.PercentageScoreBuilder());
// Had some issues with no significant terms being returned when asking for a small
// number of final results (e.g. 1) and only one shard. Setting shard_size higher helped.
if (size < 10) {
nextWaveSigTerms.shardSize(10);
}
// Alternative choices of significance algo didn't seem to be improvements....
// nextWaveSigTerms.significanceHeuristic(new GND.GNDBuilder(true));
// nextWaveSigTerms.significanceHeuristic(new ChiSquare.ChiSquareBuilder(false, true));
if (vr.hasIncludeClauses()) {
SortedSet<BytesRef> includes = vr.includeValuesAsSortedSet();
nextWaveSigTerms.includeExclude(new IncludeExclude(null, null, includes, null));
// Originally I thought users would always want the
// same number of results as listed in the include
// clause but it may be they only want the most
// significant e.g. in the lastfm example of
// plotting a single user's tastes and how that maps
// into a network showing only the most interesting
// band connections. So the line below is commented out
// nextWaveSigTerms.size(includes.length);
} else if (vr.hasExcludeClauses()) {
nextWaveSigTerms.includeExclude(new IncludeExclude(null, null, null, vr.excludesAsSortedSet()));
}
lastWaveTermsAgg.subAggregation(nextWaveSigTerms);
} else {
TermsAggregationBuilder nextWavePopularTerms = AggregationBuilders.terms("field" + f)
.field(vr.fieldName())
.minDocCount(vr.minDocCount())
.shardMinDocCount(vr.shardMinDocCount())
// Map execution mode is used because the Sampler agg keeps us
// focused on smaller sets of high-quality docs and we therefore
// examine smaller volumes of terms
.executionHint("map")
.size(size);
if (vr.hasIncludeClauses()) {
SortedSet<BytesRef> includes = vr.includeValuesAsSortedSet();
nextWavePopularTerms.includeExclude(new IncludeExclude(null, null, includes, null));
// nextWavePopularTerms.size(includes.length);
} else if (vr.hasExcludeClauses()) {
nextWavePopularTerms.includeExclude(new IncludeExclude(null, null, null, vr.excludesAsSortedSet()));
}
lastWaveTermsAgg.subAggregation(nextWavePopularTerms);
}
}
}
// Execute the search
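// size(0) because we only need the aggregation results, not the matching documents themselves.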
SearchSourceBuilder source = new SearchSourceBuilder().query(rootBool).aggregation(sampleAgg).size(0);
if (request.timeout() != null) {
// The actual resolution of the timer is the granularity of the interval
// configured globally for updating the estimated time.
long timeRemainingMillis = startTime + request.timeout().millis() - threadPool.relativeTimeInMillis();
if (timeRemainingMillis <= 0) {
listener.onResponse(buildResponse(true));
return;
}
source.timeout(TimeValue.timeValueMillis(timeRemainingMillis));
}
searchRequest.source(source);
logger.trace("executing expansion graph search request");
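// The listener below scores any newly discovered vertices and connections and then calls
// expand() again to attempt the next hop.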
client.search(searchRequest, new DelegatingActionListener<>(listener) {
@Override
public void onResponse(SearchResponse searchResponse) {
addShardFailures(searchResponse.getShardFailures());
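// Track the vertices and connections added in this hop so the weakest can be trimmed afterwards.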
ArrayList<Connection> newConnections = new ArrayList<>();
ArrayList<Vertex> newVertices = new ArrayList<>();
SingleBucketAggregation sample = searchResponse.getAggregations().get("sample");
// We think of the total scores as the energy level pouring
// out of all the last hop's connections.
// Each new node encountered is given a score that is
// normalized between zero and one based on
// the percentage of the total score that its own score
// provides
double totalSignalOutput = getExpandTotalSignalStrength(lastHop, currentHop, sample);
// Signal output can be zero if we did not encounter any new
// terms as part of this stage
if (totalSignalOutput > 0) {
addAndScoreNewVertices(lastHop, currentHop, sample, totalSignalOutput, newConnections, newVertices);
trimNewAdditions(currentHop, newConnections, newVertices);
}
// Potentially run another round of queries to perform the next "hop" - will terminate if there are no new additions
expand(searchResponse.isTimedOut());
}
// Add new vertices and apportion share of total signal along
// connections
private void addAndScoreNewVertices(
Hop lastHop,
Hop currentHop,
SingleBucketAggregation sample,
double totalSignalOutput,
ArrayList<Connection> newConnections,
ArrayList<Vertex> newVertices
) {
// Gather all matching terms into the graph and propagate
// signals
for (int j = 0; j < lastHop.getNumberVertexRequests(); j++) {
VertexRequest lastVr = lastHop.getVertexRequest(j);
Terms lastWaveTerms = sample.getAggregations().get("field" + j);
if (lastWaveTerms == null) {
// There were no terms from the previous phase that needed pursuing
continue;
}
List<? extends Terms.Bucket> buckets = lastWaveTerms.getBuckets();
for (org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket lastWaveTerm : buckets) {
Vertex fromVertex = getVertex(lastVr.fieldName(), lastWaveTerm.getKeyAsString());
for (int k = 0; k < currentHop.getNumberVertexRequests(); k++) {
VertexRequest vr = currentHop.getVertexRequest(k);
// As we travel further out into the graph we apply a
// decay to the signals being propagated down the various channels.
double decay = 0.95d;
if (request.useSignificance()) {
SignificantTerms significantTerms = lastWaveTerm.getAggregations().get("field" + k);
if (significantTerms != null) {
for (Bucket bucket : significantTerms.getBuckets()) {
if ((vr.fieldName().equals(fromVertex.getField()))
&& (bucket.getKeyAsString().equals(fromVertex.getTerm()))) {
// Avoid self-joins
continue;
}
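// Normalize this bucket's significance score against the hop's total signal output.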
double signalStrength = bucket.getSignificanceScore() / totalSignalOutput;
// Decay the signal by the weight attached to the source vertex
signalStrength = signalStrength * Math.min(decay, fromVertex.getWeight());
Vertex toVertex = getVertex(vr.fieldName(), bucket.getKeyAsString());
if (toVertex == null) {
toVertex = addVertex(
vr.fieldName(),
bucket.getKeyAsString(),
signalStrength,
currentHopNumber,
bucket.getSupersetDf(),
bucket.getSubsetDf()
);
newVertices.add(toVertex);
} else {
toVertex.setWeight(toVertex.getWeight() + signalStrength);
// We cannot (without further querying) determine an accurate number
// for the foreground count of the toVertex term - if we sum the values
// from each fromVertex term we may actually double-count occurrences so
// the best we can do is take the maximum foreground value we have observed
toVertex.setFg(Math.max(toVertex.getFg(), bucket.getSubsetDf()));
}
newConnections.add(addConnection(fromVertex, toVertex, signalStrength, bucket.getDocCount()));
}
}
} else {
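// Popularity-based scoring: the signal is the bucket's share of the total doc count
// rather than a significance score.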
Terms terms = lastWaveTerm.getAggregations().get("field" + k);
if (terms != null) {
for (org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket bucket : terms.getBuckets()) {
double signalStrength = bucket.getDocCount() / totalSignalOutput;
// Decay the signal by the weight attached to the source vertex
signalStrength = signalStrength * Math.min(decay, fromVertex.getWeight());
Vertex toVertex = getVertex(vr.fieldName(), bucket.getKeyAsString());
if (toVertex == null) {
toVertex = addVertex(
vr.fieldName(),
bucket.getKeyAsString(),
signalStrength,
currentHopNumber,
0,
0
);
newVertices.add(toVertex);
} else {
toVertex.setWeight(toVertex.getWeight() + signalStrength);
}
newConnections.add(addConnection(fromVertex, toVertex, signalStrength, bucket.getDocCount()));
}
}
}
}
}
}
}
// Having let the signals from the last results rattle around the graph
// we have adjusted weights for the various vertices we encountered.
// Now we review these new additions and remove those with the
// weakest weights.
// A priority queue is used to trim vertices according to the size settings
// requested for each field.
private void trimNewAdditions(Hop currentHop, ArrayList<Connection> newConnections, ArrayList<Vertex> newVertices) {
Set<Vertex> evictions = new HashSet<>();
for (int k = 0; k < currentHop.getNumberVertexRequests(); k++) {
// For each of the fields
VertexRequest vr = currentHop.getVertexRequest(k);
if (newVertices.size() <= vr.size()) {
// Nothing to trim
continue;
}
// Get the top vertices for this field
VertexPriorityQueue pq = new VertexPriorityQueue(vr.size());
for (Vertex vertex : newVertices) {
if (vertex.getField().equals(vr.fieldName())) {
Vertex eviction = pq.insertWithOverflow(vertex);
if (eviction != null) {
evictions.add(eviction);
}
}
}
}
// Remove weak new nodes and their dangling connections from the main graph
if (evictions.size() > 0) {
for (Connection connection : newConnections) {
if (evictions.contains(connection.getTo())) {
connections.remove(connection.getId());
removeVertex(connection.getTo());
}
}
}
}
// TODO right now we only trim down to the best N vertices. We might also want to offer
// clients the option to limit to the best M connections. One scenario where this is required
// is if the "from" and "to" nodes are a client-supplied set of includes e.g. a list of
// music artists then the client may be wanting to draw only the most-interesting connections
// between them. See https://github.com/elastic/x-plugins/issues/518#issuecomment-160186424
// I guess clients could trim the returned connections (which all have weights) but I wonder if
// we can do something server-side here
// Helper method - compute the total signal of all scores in the search results
private double getExpandTotalSignalStrength(Hop lastHop, Hop currentHop, SingleBucketAggregation sample) {
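// Sums the significance scores (or doc counts) of all next-hop buckets, skipping self-joins.
// The result is the denominator used when normalizing individual bucket scores above.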
double totalSignalOutput = 0;
for (int j = 0; j < lastHop.getNumberVertexRequests(); j++) {
VertexRequest lastVr = lastHop.getVertexRequest(j);
Terms lastWaveTerms = sample.getAggregations().get("field" + j);
if (lastWaveTerms == null) {
continue;
}
List<? extends Terms.Bucket> buckets = lastWaveTerms.getBuckets();
for (org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket lastWaveTerm : buckets) {
for (int k = 0; k < currentHop.getNumberVertexRequests(); k++) {
VertexRequest vr = currentHop.getVertexRequest(k);
if (request.useSignificance()) {
// Signal is based on significance score
SignificantTerms significantTerms = lastWaveTerm.getAggregations().get("field" + k);
if (significantTerms != null) {
for (Bucket bucket : significantTerms.getBuckets()) {
if ((vr.fieldName().equals(lastVr.fieldName()))
&& (bucket.getKeyAsString().equals(lastWaveTerm.getKeyAsString()))) {
// don't count self joins (term A obviously co-occurs with term A)
continue;
} else {
totalSignalOutput += bucket.getSignificanceScore();
}
}
}
} else {
// Signal is based on popularity (number of
// documents)
Terms terms = lastWaveTerm.getAggregations().get("field" + k);
if (terms != null) {
for (org.elasticsearch.search.aggregations.bucket.terms.Terms.Bucket bucket : terms.getBuckets()) {
if ((vr.fieldName().equals(lastVr.fieldName()))
&& (bucket.getKeyAsString().equals(lastWaveTerm.getKeyAsString()))) {
// don't count self joins (term A obviously co-occurs with term A)
continue;
} else {
totalSignalOutput += bucket.getDocCount();
}
}
}
}
}
}
}
return totalSignalOutput;
}
});
}