in retail/recommendation-system/bqml-scann/tfx_pipeline/scann_evaluator.py [0:0]
def Do(self,
       input_dict: Dict[Text, List[types.Artifact]],
       output_dict: Dict[Text, List[types.Artifact]],
       exec_properties: Dict[Text, Any]) -> None:
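  """Evaluates a ScaNN index against exact nearest-neighbor matches.

  Samples query embeddings from the 'train' split, compares the index's
  approximate matches with brute-force exact matches, and blesses the
  index only if it meets the 'min_recall' and 'max_latency' thresholds
  supplied in exec_properties.
  """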
  if 'examples' not in input_dict:
    raise ValueError('Examples is missing from input dict.')
  if 'schema' not in input_dict:
    raise ValueError('Schema is missing from input dict.')
  if 'model' not in input_dict:
    raise ValueError('Model is missing from input dict.')
  if 'evaluation' not in output_dict:
    raise ValueError('Evaluation is missing from output dict.')
  if 'blessing' not in output_dict:
    raise ValueError('Blessing is missing from output dict.')
  self._log_startup(input_dict, output_dict, exec_properties)

  embedding_files_pattern = io_utils.all_files_pattern(
      artifact_utils.get_split_uri(input_dict['examples'], 'train'))
  schema_file_path = os.path.join(
      artifact_utils.get_single_instance(input_dict['schema']).uri,
      'schema.pbtxt')
  vocabulary, embeddings = scann_indexer.load_embeddings(
      embedding_files_pattern, schema_file_path)
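  # `vocabulary` and `embeddings` are assumed to be row-aligned: row i of
  # `embeddings` is the embedding of item vocabulary[i], which is how the
  # matchers below map neighbor indices back to item IDs.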
  num_embeddings = embeddings.shape[0]
  logging.info(f'{num_embeddings} embeddings are loaded.')

  num_queries = int(min(num_embeddings * QUERIES_SAMPLE_RATIO, MAX_NUM_QUERIES))
  logging.info(f'Sampling {num_queries} query embeddings for evaluation...')
  # Sample without replacement so that no query is counted twice in the
  # metrics; num_queries stays within num_embeddings as long as
  # QUERIES_SAMPLE_RATIO <= 1.
  query_embedding_indices = np.random.choice(
      num_embeddings, num_queries, replace=False)
  query_embeddings = np.take(embeddings, query_embedding_indices, axis=0)
  # Compute ground-truth neighbors with the exact matcher.
  exact_matcher = item_matcher.ExactMatcher(embeddings, vocabulary)
  exact_matches = []
  logging.info('Computing exact matches for the queries...')
  for query in query_embeddings:
    exact_matches.append(exact_matcher.match(query, NUM_NEIGBHOURS))
  logging.info('Exact matches are computed.')
  del num_embeddings, exact_matcher
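  # NOTE: The exact matcher presumably scans every embedding per query
  # (brute force), so it supplies the ground truth but not a latency
  # baseline; only the ScaNN pass below is timed.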
  # Load the ScaNN index matcher produced by the indexer.
  index_artifact = artifact_utils.get_single_instance(input_dict['model'])
  ann_matcher = item_matcher.ScaNNMatcher(
      os.path.join(index_artifact.uri, 'serving_model_dir'))

  scann_matches = []
  logging.info('Computing ScaNN matches for the queries...')
  start_time = time.time()
  for query in query_embeddings:
    scann_matches.append(ann_matcher.match(query, NUM_NEIGBHOURS))
  end_time = time.time()
  logging.info('ScaNN matches are computed.')
  # Average end-to-end latency per query, in seconds.
  elapsed_time = end_time - start_time
  current_latency = elapsed_time / num_queries

  # Recall@k: the fraction of a query's exact neighbors that ScaNN also
  # returned, averaged over all sampled queries.
  current_recall = 0.0
  for exact, approx in zip(exact_matches, scann_matches):
    current_recall += len(set(exact).intersection(set(approx))) / NUM_NEIGBHOURS
  current_recall /= num_queries
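  # Example: with NUM_NEIGBHOURS = 10, a query for which ScaNN returns 8 of
  # the 10 exact neighbors contributes 0.8 to the sum; dividing by
  # num_queries yields the average recall@10 across the sampled queries.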
  metrics = {
      'recall': current_recall,
      'latency': current_latency,
  }
  min_recall = exec_properties['min_recall']
  max_latency = exec_properties['max_latency']
  logging.info(
      f'Average latency per query achieved: {current_latency}. '
      f'Maximum latency allowed: {max_latency}')
  logging.info(
      f'Recall achieved: {current_recall}. '
      f'Minimum recall allowed: {min_recall}')

  # Validate the index: it must meet both the latency and the recall bars.
  valid = (current_latency <= max_latency) and (current_recall >= min_recall)
  logging.info(f'Model is valid: {valid}')
  # Output the evaluation artifact: metrics serialized as JSON.
  evaluation = artifact_utils.get_single_instance(output_dict['evaluation'])
  evaluation.set_string_custom_property('index_model_uri', index_artifact.uri)
  evaluation.set_int_custom_property('index_model_id', index_artifact.id)
  io_utils.write_string_file(
      os.path.join(evaluation.uri, 'metrics'), json.dumps(metrics))
  # Output the blessing artifact: an empty BLESSED or NOT_BLESSED marker
  # file plus a 'blessed' custom property, mirroring the convention of the
  # standard TFX Evaluator.
  blessing = artifact_utils.get_single_instance(output_dict['blessing'])
  blessing.set_string_custom_property('index_model_uri', index_artifact.uri)
  blessing.set_int_custom_property('index_model_id', index_artifact.id)
  if valid:
    io_utils.write_string_file(os.path.join(blessing.uri, 'BLESSED'), '')
    blessing.set_int_custom_property('blessed', 1)
  else:
    io_utils.write_string_file(os.path.join(blessing.uri, 'NOT_BLESSED'), '')
    blessing.set_int_custom_property('blessed', 0)
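  # A downstream pusher step can gate deployment on the 'blessed' custom
  # property or on the presence of the BLESSED marker file written above.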