QueryGenerator.py (117 lines of code) (raw):
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from pandas.api.types import is_list_like
import itertools
import heapq
from functools import total_ordering
@total_ordering
class Query:
def __init__(self, qid, idx, query, parameters, param_idx=0):
self.data_idx = idx
self.qid = qid
self.query = query
self.parameters = parameters
self.param_idx = param_idx # used to order queries
def __iter__(self):
yield self.data_idx
yield self.qid
yield self.query
yield self.parameters
def info(self):
return {
'query_idx': self.data_idx,
'qid': self.qid,
'query': self.query,
'query_parameters': self.parameters,
}
def __str__(self):
return f"{self.query} {self.parameters} {self.data_idx}"
def __lt__(self, b):
return self.data_idx < b.data_idx
def __eq__(self,b):
return self.data_idx == b.data_idx
class QueryGenerator:
name = "generic_QG"
END = 1000000000
def __init__(self):
self.prepared = False
def connectDataGenerator(self, data_generator):
self.data_generator = data_generator
def getName(self):
return self.name
def getID(self):
return self.name
class DataGeneratorSeq:
def __init__(self, data_generator=None, length=None, by=None):
self.data_generator = data_generator
self.length = length
self.by = by
def genSeq(self):
if self.length is not None:
n = len(self.data_generator)
assert(n > 1)
by = (n-1) / self.length
else:
n = QueryGenerator.END
by = self.by
i = 1
while i*by < len(self.data_generator):
yield int(i*by)
i += 1
if i*by != n-1:
yield QueryGenerator.END
class ConfigQueryGenerator(QueryGenerator):
name = 'config_QG'
def __init__(self, queries, indices=[QueryGenerator.END], parameters=None):
self.queries = queries
self.indices = indices
self.query_parameters = parameters
def genQueries(self):
qid = 0
if isinstance(self.queries, list):
queries = self.queries
else:
queries = [self.queries]
if is_list_like(self.query_parameters) and not isinstance(self.query_parameters, dict):
query_parameters = self.query_parameters
else:
query_parameters = [self.query_parameters]
if isinstance(self.indices, DataGeneratorSeq):
indices = self.indices.genSeq()
else:
indices = self.indices
for q, idx, params in itertools.product(queries, indices, query_parameters):
yield Query(qid, idx, q, params)
qid += 1
def connectDataGenerator(self, data_generator):
self.data_generator = data_generator
if isinstance(self.indices, DataGeneratorSeq):
self.indices.data_generator = data_generator
class ChainQueryGenerators(QueryGenerator):
name = 'chained_QG'
def __init__(self, generators=[]):
super().__init__()
self.query_generators = generators
self.heap = []
def genQueries(self):
# reassign the qid's and ensure queries are ordered by idx
iters = [qg.genQueries() for qg in self.query_generators]
for i, it in enumerate(iters):
q = next(it, None)
if q is not None:
self.heap.append((q, i))
heapq.heapify(self.heap)
qid = 0
while self.heap:
q, i = heapq.heappop(self.heap)
q.qid = qid
qid += 1
yield q
nextq = next(iters[i], None)
if nextq is not None:
heapq.heappush(self.heap, (nextq, i))
def connectDataGenerator(self, data_generator):
for qg in self.query_generators:
qg.connectDataGenerator(data_generator)
class TopKQueryGenerator(ConfigQueryGenerator):
def __init__(self, k_values, num_queries, data_generator=None, **kwargs):
super().__init__(queries="topk",
indices=DataGeneratorSeq(data_generator, length=num_queries),
parameters=k_values)