def compute_pmi()

in retail/recommendation-system/bqml-scann/tfx_pipeline/bq_components.py [0:0]


def compute_pmi(
  project_id: Parameter[str],
  bq_dataset: Parameter[str],
  min_item_frequency: Parameter[int],
  max_group_size: Parameter[int],
  item_cooc: OutputArtifact[Dataset]):
  """Call the BigQuery `sp_ComputePMI` procedure and record its output table.

  Builds a small BigQuery script that declares two INT64 variables, sets them
  from the given parameters, and calls `<bq_dataset>.sp_ComputePMI` with them.
  The call blocks until the query job has finished.

  Args:
    project_id: Project the BigQuery client is created for.
    bq_dataset: Dataset that holds the `sp_ComputePMI` procedure.
    min_item_frequency: Passed to the procedure as its first argument.
    max_group_size: Passed to the procedure as its second argument.
    item_cooc: Output artifact; receives the `bq_dataset` and
      `bq_result_table` string custom properties.
  """
  # The table name is fixed here rather than reported back by the procedure;
  # presumably sp_ComputePMI writes its result there -- verify against the
  # procedure definition.
  output_table = 'item_cooc'

  # Both thresholds are interpolated directly into the script text.
  script = f'''
      DECLARE min_item_frequency INT64;
      DECLARE max_group_size INT64;

      SET min_item_frequency = {min_item_frequency};
      SET max_group_size = {max_group_size};

      CALL {bq_dataset}.sp_ComputePMI(min_item_frequency, max_group_size);
  '''

  logging.info('Starting computing PMI...')

  # result() waits for the job to complete before we report success.
  bigquery.Client(project=project_id).query(script).result()

  logging.info(f'Items PMI computation completed. Output in {bq_dataset}.{output_table}.')

  # Publish the location of the output table through the artifact metadata.
  output_location = {
    'bq_dataset': bq_dataset,
    'bq_result_table': output_table,
  }
  for property_name, property_value in output_location.items():
    item_cooc.set_string_custom_property(property_name, property_value)