def Do()

in retail/recommendation-system/bqml-scann/tfx_pipeline/scann_evaluator.py [0:0]


  def Do(self,
         input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Evaluates a ScaNN index against exact matching and blesses it if valid.

    A sample of the loaded embeddings is used as queries. Each query is run
    through both the exact matcher and the ScaNN matcher; the overlap of the
    two result sets gives the recall, and the wall time of the ScaNN lookups
    gives the average per-query latency. The metrics are written as JSON to
    the evaluation artifact, and a `BLESSED` / `NOT_BLESSED` marker file is
    written to the blessing artifact.

    Args:
      input_dict: Must contain the keys 'examples' (embeddings; the 'train'
        split is read), 'model' (the ScaNN index, read from its
        'serving_model_dir' sub-directory) and 'schema' (directory holding
        'schema.pbtxt').
      output_dict: Must contain the keys 'evaluation' and 'blessing'.
      exec_properties: Must contain 'min_recall' and 'max_latency', the
        thresholds the index has to satisfy to be blessed.

    Raises:
      ValueError: If a required key is missing from `input_dict` or
        `output_dict`, or if no embeddings were loaded.
    """
    if 'examples' not in input_dict:
      raise ValueError('Examples is missing from input dict.')
    if 'model' not in input_dict:
      raise ValueError('Model is missing from input dict.')
    # The schema is read unconditionally below, so validate it up front with
    # the other required inputs rather than failing later with a KeyError.
    if 'schema' not in input_dict:
      raise ValueError('Schema is missing from input dict.')
    if 'evaluation' not in output_dict:
      raise ValueError('Evaluation is missing from output dict.')
    if 'blessing' not in output_dict:
      raise ValueError('Blessing is missing from output dict.')

    self._log_startup(input_dict, output_dict, exec_properties)

    embedding_files_pattern = io_utils.all_files_pattern(
      artifact_utils.get_split_uri(input_dict['examples'], 'train'))

    schema_file_path = artifact_utils.get_single_instance(
      input_dict['schema']).uri + '/schema.pbtxt'

    vocabulary, embeddings = scann_indexer.load_embeddings(
      embedding_files_pattern, schema_file_path)

    num_embeddings = embeddings.shape[0]
    logging.info(f'{num_embeddings} embeddings are loaded.')
    if num_embeddings == 0:
      raise ValueError('No embeddings were loaded; the index cannot be evaluated.')

    # Always evaluate at least one query: with few embeddings the sampled
    # fraction truncates to 0, which would divide by zero in the averages.
    num_queries = max(
      1, int(min(num_embeddings * QUERIES_SAMPLE_RATIO, MAX_NUM_QUERIES)))
    logging.info(f'Sampling {num_queries} query embeddings for evaluation...')
    query_embedding_indices = np.random.choice(num_embeddings, num_queries)
    query_embeddings = np.take(embeddings, query_embedding_indices, axis=0)

    # Load Exact matcher
    exact_matcher = item_matcher.ExactMatcher(embeddings, vocabulary)
    logging.info('Computing exact matches for the queries...')
    exact_matches = [
      exact_matcher.match(query, NUM_NEIGBHOURS) for query in query_embeddings]
    logging.info('Exact matches are computed.')
    # Neither the full embedding set nor the exact matcher is used past this
    # point; drop the local references before loading the ScaNN index.
    del embeddings, vocabulary, exact_matcher

    # Load ScaNN index matcher
    index_artifact = artifact_utils.get_single_instance(input_dict['model'])
    ann_matcher = item_matcher.ScaNNMatcher(index_artifact.uri + '/serving_model_dir')
    logging.info('Computing ScaNN matches for the queries...')
    # perf_counter is monotonic, so the measured latency cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    scann_matches = [
      ann_matcher.match(query, NUM_NEIGBHOURS) for query in query_embeddings]
    elapsed_time = time.perf_counter() - start_time
    logging.info('ScaNN matches are computed.')

    # Compute average latency
    current_latency = elapsed_time / num_queries

    # Compute recall: fraction of the exact neighbours that ScaNN also
    # returned, averaged over all queries.
    num_common = sum(
      len(set(exact).intersection(approx))
      for exact, approx in zip(exact_matches, scann_matches))
    current_recall = num_common / (NUM_NEIGBHOURS * num_queries)

    metrics = {
      'recall': current_recall,
      'latency': current_latency
    }

    min_recall = exec_properties['min_recall']
    max_latency = exec_properties['max_latency']

    logging.info(f'Average latency per query achieved {current_latency}. Maximum latency allowed: {max_latency}')
    logging.info(f'Recall acheived {current_recall}. Minimum recall allowed: {min_recall}')

    # Validate index latency and recall
    valid = (current_latency <= max_latency) and (current_recall >= min_recall)
    logging.info(f'Model is valid: {valid}')

    # Output the evaluation artifact.
    evaluation = artifact_utils.get_single_instance(output_dict['evaluation'])
    evaluation.set_string_custom_property('index_model_uri', index_artifact.uri)
    evaluation.set_int_custom_property('index_model_id', index_artifact.id)
    io_utils.write_string_file(
      os.path.join(evaluation.uri, 'metrics'), json.dumps(metrics))

    # Output the blessing artifact.
    blessing = artifact_utils.get_single_instance(output_dict['blessing'])
    blessing.set_string_custom_property('index_model_uri', index_artifact.uri)
    blessing.set_int_custom_property('index_model_id', index_artifact.id)

    if valid:
      io_utils.write_string_file(os.path.join(blessing.uri, 'BLESSED'), '')
      blessing.set_int_custom_property('blessed', 1)
    else:
      io_utils.write_string_file(os.path.join(blessing.uri, 'NOT_BLESSED'), '')
      blessing.set_int_custom_property('blessed', 0)