synchronized void expand()

in x-pack/plugin/graph/src/main/java/org/elasticsearch/xpack/graph/action/TransportGraphExploreAction.java [179:551]


        /**
         * Runs the next hop of the graph exploration, or completes the request when there is nothing left to explore.
         * <p>
         * The response is handed to {@code listener} without issuing a search when the final hop has been reached,
         * when the previous hop recorded no findings, or when the request's timeout budget is already used up.
         * Otherwise one search is issued whose query matches the current hop's guiding query plus either the
         * client-supplied includes or the vertex terms found by the previous hop. Its aggregation tree is
         * {@code sample -> terms (previous hop, one per field) -> terms/significant_terms (current hop, one per field)}.
         * The search callback adds and scores the new vertices and connections, trims the weakest additions, and then
         * calls this method again for the following hop.
         *
         * @param timedOut {@code true} if a search already executed for this request reported a timeout; the flag is
         *                 reported in the response and is carried forward (never reset) across hops
         */
        synchronized void expand(boolean timedOut) {
            Map<String, Set<Vertex>> lastHopFindings = hopFindings.get(currentHopNumber);
            if ((currentHopNumber >= (request.getHopNumbers() - 1)) || (lastHopFindings == null) || (lastHopFindings.size() == 0)) {
                // Either we gathered no leads from the last hop or we have
                // reached the final hop
                listener.onResponse(buildResponse(timedOut));
                return;
            }
            Hop lastHop = request.getHop(currentHopNumber);
            currentHopNumber++;
            Hop currentHop = request.getHop(currentHopNumber);

            final SearchRequest searchRequest = new SearchRequest(request.indices()).indicesOptions(request.indicesOptions());
            if (request.routing() != null) {
                searchRequest.routing(request.routing());
            }

            BoolQueryBuilder rootBool = QueryBuilders.boolQuery();

            // A single sample pool of docs is built at the root of the aggs tree.
            // For quality's sake it might have made more sense to sample top docs
            // for each of the terms from the previous hop (e.g. an initial query for "beatles"
            // may have separate doc-sample pools for significant root terms "john", "paul", "yoko" etc)
            // but I found this dramatically slowed down execution - each pool typically had different docs which
            // each had non-overlapping sets of terms that needed frequencies looking up for significant terms.
            // A common sample pool reduces the specialization that can be given to each root term but
            // ultimately is much faster to run because of the shared vocabulary in a single sample set.
            final AggregationBuilder sampleAgg;
            if (request.sampleDiversityField() != null) {
                DiversifiedAggregationBuilder diversifiedSampleAgg = AggregationBuilders.diversifiedSampler("sample")
                    .shardSize(request.sampleSize());
                diversifiedSampleAgg.field(request.sampleDiversityField());
                diversifiedSampleAgg.maxDocsPerValue(request.maxDocsPerDiversityValue());
                sampleAgg = diversifiedSampleAgg;
            } else {
                sampleAgg = AggregationBuilders.sampler("sample").shardSize(request.sampleSize());
            }

            // Add any user-supplied criteria to the root query as a must clause
            rootBool.must(currentHop.guidingQuery());

            // Build a MUST clause that matches one of either
            // a:) include clauses supplied by the client or
            // b:) vertex terms from the previous hop.
            BoolQueryBuilder sourceTermsOrClause = QueryBuilders.boolQuery();
            addUserDefinedIncludesToQuery(currentHop, sourceTermsOrClause);
            addBigOrClause(lastHopFindings, sourceTermsOrClause);

            rootBool.must(sourceTermsOrClause);

            // Now build the agg tree that will channel the content ->
            // base agg is terms agg for terms from last wave (one per field),
            // under each is a sig_terms agg to find next candidates (again, one per field)...
            for (int fieldNum = 0; fieldNum < lastHop.getNumberVertexRequests(); fieldNum++) {
                VertexRequest lastVr = lastHop.getVertexRequest(fieldNum);
                Set<Vertex> lastWaveVerticesForField = lastHopFindings.get(lastVr.fieldName());
                if (lastWaveVerticesForField == null) {
                    continue;
                }
                SortedSet<BytesRef> terms = new TreeSet<>();
                for (Vertex v : lastWaveVerticesForField) {
                    terms.add(new BytesRef(v.getTerm()));
                }
                TermsAggregationBuilder lastWaveTermsAgg = AggregationBuilders.terms("field" + fieldNum)
                    .includeExclude(new IncludeExclude(null, null, terms, null))
                    .shardMinDocCount(1)
                    .field(lastVr.fieldName())
                    .minDocCount(1)
                    // Map execution mode used because Sampler agg keeps us
                    // focused on smaller sets of high quality docs and therefore
                    // examine smaller volumes of terms
                    .executionHint("map")
                    .size(terms.size());
                sampleAgg.subAggregation(lastWaveTermsAgg);
                for (int f = 0; f < currentHop.getNumberVertexRequests(); f++) {
                    VertexRequest vr = currentHop.getVertexRequest(f);
                    int size = vr.size();
                    if (vr.fieldName().equals(lastVr.fieldName())) {
                        // We have the potential for self-loops as we are looking at the same field so add 1 to the requested size
                        // because we need to eliminate fieldA:termA -> fieldA:termA links that are likely to be in the results.
                        size++;
                    }

                    // The term filter is the same whichever kind of aggregation is used for the next wave:
                    // client-supplied includes take precedence over excludes.
                    // Originally I thought users would always want the same number of results as listed in
                    // the include clause but it may be they only want the most significant e.g. in the lastfm
                    // example of plotting a single user's tastes and how that maps into a network showing only
                    // the most interesting band connections. So the requested size is deliberately NOT
                    // overridden with the number of includes.
                    IncludeExclude nextWaveFilter = null;
                    if (vr.hasIncludeClauses()) {
                        nextWaveFilter = new IncludeExclude(null, null, vr.includeValuesAsSortedSet(), null);
                    } else if (vr.hasExcludeClauses()) {
                        nextWaveFilter = new IncludeExclude(null, null, null, vr.excludesAsSortedSet());
                    }

                    if (request.useSignificance()) {
                        SignificantTermsAggregationBuilder nextWaveSigTerms = AggregationBuilders.significantTerms("field" + f)
                            .field(vr.fieldName())
                            .minDocCount(vr.minDocCount())
                            .shardMinDocCount(vr.shardMinDocCount())
                            .executionHint("map")
                            .size(size);
                        // Had some issues with no significant terms being returned when asking for small
                        // number of final results (eg 1) and only one shard. Setting shard_size higher helped.
                        if (size < 10) {
                            nextWaveSigTerms.shardSize(10);
                        }
                        // Alternative choices of significance algo (PercentageScore, GND, ChiSquare)
                        // didn't seem to be improvements, so the default heuristic is kept.
                        if (nextWaveFilter != null) {
                            nextWaveSigTerms.includeExclude(nextWaveFilter);
                        }
                        lastWaveTermsAgg.subAggregation(nextWaveSigTerms);
                    } else {
                        TermsAggregationBuilder nextWavePopularTerms = AggregationBuilders.terms("field" + f)
                            .field(vr.fieldName())
                            .minDocCount(vr.minDocCount())
                            .shardMinDocCount(vr.shardMinDocCount())
                            // Map execution mode used because Sampler agg keeps us
                            // focused on smaller sets of high quality docs and therefore
                            // examine smaller volumes of terms
                            .executionHint("map")
                            .size(size);
                        if (nextWaveFilter != null) {
                            nextWavePopularTerms.includeExclude(nextWaveFilter);
                        }
                        lastWaveTermsAgg.subAggregation(nextWavePopularTerms);
                    }
                }
            }

            // Execute the search
            SearchSourceBuilder source = new SearchSourceBuilder().query(rootBool).aggregation(sampleAgg).size(0);
            if (request.timeout() != null) {
                // Actual resolution of timer is granularity of the interval
                // configured globally for updating estimated time.
                long timeRemainingMillis = startTime + request.timeout().millis() - threadPool.relativeTimeInMillis();
                if (timeRemainingMillis <= 0) {
                    // Budget exhausted before this hop could even start
                    listener.onResponse(buildResponse(true));
                    return;
                }

                source.timeout(TimeValue.timeValueMillis(timeRemainingMillis));
            }
            searchRequest.source(source);

            logger.trace("executing expansion graph search request");
            client.search(searchRequest, new DelegatingActionListener<>(listener) {
                @Override
                public void onResponse(SearchResponse searchResponse) {
                    addShardFailures(searchResponse.getShardFailures());

                    List<Connection> newConnections = new ArrayList<>();
                    List<Vertex> newVertices = new ArrayList<>();
                    SingleBucketAggregation sample = searchResponse.getAggregations().get("sample");

                    // We think of the total scores as the energy-level pouring
                    // out of all the last hop's connections.
                    // Each new node encountered is given a score which is
                    // normalized between zero and one based on
                    // what percentage of the total scores its own score
                    // provides
                    double totalSignalOutput = getExpandTotalSignalStrength(lastHop, currentHop, sample);

                    // Signal output can be zero if we did not encounter any new
                    // terms as part of this stage
                    if (totalSignalOutput > 0) {
                        addAndScoreNewVertices(lastHop, currentHop, sample, totalSignalOutput, newConnections, newVertices);

                        trimNewAdditions(currentHop, newConnections, newVertices);
                    }

                    // Potentially run another round of queries to perform next "hop" - will terminate if no new additions.
                    // A timeout seen on an earlier hop must not be forgotten just because this hop finished in time,
                    // otherwise partial results could be reported as complete.
                    expand(timedOut || searchResponse.isTimedOut());
                }

                // True when the candidate target (toField:toTerm) is the very vertex the signal originates
                // from (fromField:fromTerm). Term A obviously co-occurs with term A, so such links carry no information.
                private boolean isSelfJoin(String fromField, String fromTerm, String toField, String toTerm) {
                    return toField.equals(fromField) && toTerm.equals(fromTerm);
                }

                // Add new vertices and apportion share of total signal along
                // connections. Vertices created here are appended to newVertices and every
                // connection recorded here is appended to newConnections.
                private void addAndScoreNewVertices(
                    Hop lastHop,
                    Hop currentHop,
                    SingleBucketAggregation sample,
                    double totalSignalOutput,
                    List<Connection> newConnections,
                    List<Vertex> newVertices
                ) {
                    // As we travel further out into the graph we apply a
                    // decay to the signals being propagated down the various channels.
                    final double decay = 0.95d;

                    // Gather all matching terms into the graph and propagate
                    // signals
                    for (int j = 0; j < lastHop.getNumberVertexRequests(); j++) {
                        VertexRequest lastVr = lastHop.getVertexRequest(j);
                        Terms lastWaveTerms = sample.getAggregations().get("field" + j);
                        if (lastWaveTerms == null) {
                            // There were no terms from the previous phase that needed pursuing
                            continue;
                        }
                        for (Terms.Bucket lastWaveTerm : lastWaveTerms.getBuckets()) {
                            Vertex fromVertex = getVertex(lastVr.fieldName(), lastWaveTerm.getKeyAsString());
                            for (int k = 0; k < currentHop.getNumberVertexRequests(); k++) {
                                VertexRequest vr = currentHop.getVertexRequest(k);
                                if (request.useSignificance()) {
                                    SignificantTerms significantTerms = lastWaveTerm.getAggregations().get("field" + k);
                                    if (significantTerms == null) {
                                        continue;
                                    }
                                    for (Bucket bucket : significantTerms.getBuckets()) {
                                        String toTerm = bucket.getKeyAsString();
                                        if (isSelfJoin(fromVertex.getField(), fromVertex.getTerm(), vr.fieldName(), toTerm)) {
                                            // Avoid self-joins
                                            continue;
                                        }
                                        // Share of the total signal, decayed by the weight attached to the source vertex
                                        double signalStrength = (bucket.getSignificanceScore() / totalSignalOutput) * Math.min(
                                            decay,
                                            fromVertex.getWeight()
                                        );

                                        Vertex toVertex = getVertex(vr.fieldName(), toTerm);
                                        if (toVertex == null) {
                                            toVertex = addVertex(
                                                vr.fieldName(),
                                                toTerm,
                                                signalStrength,
                                                currentHopNumber,
                                                bucket.getSupersetDf(),
                                                bucket.getSubsetDf()
                                            );
                                            newVertices.add(toVertex);
                                        } else {
                                            toVertex.setWeight(toVertex.getWeight() + signalStrength);
                                            // We cannot (without further querying) determine an accurate number
                                            // for the foreground count of the toVertex term - if we sum the values
                                            // from each fromVertex term we may actually double-count occurrences so
                                            // the best we can do is take the maximum foreground value we have observed
                                            toVertex.setFg(Math.max(toVertex.getFg(), bucket.getSubsetDf()));
                                        }
                                        newConnections.add(addConnection(fromVertex, toVertex, signalStrength, bucket.getDocCount()));
                                    }
                                } else {
                                    Terms terms = lastWaveTerm.getAggregations().get("field" + k);
                                    if (terms == null) {
                                        continue;
                                    }
                                    for (Terms.Bucket bucket : terms.getBuckets()) {
                                        String toTerm = bucket.getKeyAsString();
                                        if (isSelfJoin(fromVertex.getField(), fromVertex.getTerm(), vr.fieldName(), toTerm)) {
                                            // Avoid self-joins: they are excluded from totalSignalOutput as well, and the
                                            // requested agg size was bumped by one precisely so they can be dropped here.
                                            continue;
                                        }
                                        // Share of the total signal, decayed by the weight attached to the source vertex
                                        double signalStrength = (bucket.getDocCount() / totalSignalOutput) * Math.min(
                                            decay,
                                            fromVertex.getWeight()
                                        );

                                        Vertex toVertex = getVertex(vr.fieldName(), toTerm);
                                        if (toVertex == null) {
                                            // No background/foreground statistics are available in popularity mode
                                            toVertex = addVertex(vr.fieldName(), toTerm, signalStrength, currentHopNumber, 0, 0);
                                            newVertices.add(toVertex);
                                        } else {
                                            toVertex.setWeight(toVertex.getWeight() + signalStrength);
                                        }
                                        newConnections.add(addConnection(fromVertex, toVertex, signalStrength, bucket.getDocCount()));
                                    }
                                }
                            }
                        }
                    }
                }

                // Having let the signals from the last results rattle around the graph
                // we have adjusted weights for the various vertices we encountered.
                // Now we review these new additions and remove those with the
                // weakest weights.
                // A priority queue is used to trim vertices according to the size settings
                // requested for each field.
                private void trimNewAdditions(Hop currentHop, List<Connection> newConnections, List<Vertex> newVertices) {
                    Set<Vertex> evictions = new HashSet<>();

                    for (int k = 0; k < currentHop.getNumberVertexRequests(); k++) {
                        // For each of the fields
                        VertexRequest vr = currentHop.getVertexRequest(k);
                        if (newVertices.size() <= vr.size()) {
                            // Nothing to trim: even if every new vertex belonged to this field the limit is not exceeded
                            continue;
                        }
                        // Get the top vertices for this field
                        VertexPriorityQueue pq = new VertexPriorityQueue(vr.size());
                        for (Vertex vertex : newVertices) {
                            if (vertex.getField().equals(vr.fieldName())) {
                                // Whatever falls out of the bounded queue did not make the cut
                                Vertex eviction = pq.insertWithOverflow(vertex);
                                if (eviction != null) {
                                    evictions.add(eviction);
                                }
                            }
                        }
                    }
                    // Remove weak new nodes and their dangling connections from the main graph
                    if (evictions.isEmpty() == false) {
                        for (Connection connection : newConnections) {
                            if (evictions.contains(connection.getTo())) {
                                connections.remove(connection.getId());
                                removeVertex(connection.getTo());
                            }
                        }
                    }
                }
                // TODO right now we only trim down to the best N vertices. We might also want to offer
                // clients the option to limit to the best M connections. One scenario where this is required
                // is if the "from" and "to" nodes are a client-supplied set of includes e.g. a list of
                // music artists then the client may be wanting to draw only the most-interesting connections
                // between them. See https://github.com/elastic/x-plugins/issues/518#issuecomment-160186424
                // I guess clients could trim the returned connections (which all have weights) but I wonder if
                // we can do something server-side here

                // Helper method - compute the total signal of all scores in the search results.
                // Self-joins (term A -> term A on the same field) are left out of the total.
                private double getExpandTotalSignalStrength(Hop lastHop, Hop currentHop, SingleBucketAggregation sample) {
                    double totalSignalOutput = 0;
                    for (int j = 0; j < lastHop.getNumberVertexRequests(); j++) {
                        VertexRequest lastVr = lastHop.getVertexRequest(j);
                        Terms lastWaveTerms = sample.getAggregations().get("field" + j);
                        if (lastWaveTerms == null) {
                            continue;
                        }
                        for (Terms.Bucket lastWaveTerm : lastWaveTerms.getBuckets()) {
                            for (int k = 0; k < currentHop.getNumberVertexRequests(); k++) {
                                VertexRequest vr = currentHop.getVertexRequest(k);
                                if (request.useSignificance()) {
                                    // Signal is based on significance score
                                    SignificantTerms significantTerms = lastWaveTerm.getAggregations().get("field" + k);
                                    if (significantTerms == null) {
                                        continue;
                                    }
                                    for (Bucket bucket : significantTerms.getBuckets()) {
                                        // don't count self joins (term A obviously co-occurs with term A)
                                        if (isSelfJoin(
                                            lastVr.fieldName(),
                                            lastWaveTerm.getKeyAsString(),
                                            vr.fieldName(),
                                            bucket.getKeyAsString()
                                        ) == false) {
                                            totalSignalOutput += bucket.getSignificanceScore();
                                        }
                                    }
                                } else {
                                    // Signal is based on popularity (number of
                                    // documents)
                                    Terms terms = lastWaveTerm.getAggregations().get("field" + k);
                                    if (terms == null) {
                                        continue;
                                    }
                                    for (Terms.Bucket bucket : terms.getBuckets()) {
                                        // don't count self joins (term A obviously co-occurs with term A)
                                        if (isSelfJoin(
                                            lastVr.fieldName(),
                                            lastWaveTerm.getKeyAsString(),
                                            vr.fieldName(),
                                            bucket.getKeyAsString()
                                        ) == false) {
                                            totalSignalOutput += bucket.getDocCount();
                                        }
                                    }
                                }
                            }
                        }
                    }
                    return totalSignalOutput;
                }
            });
        }