projects/conversational-commerce-agent/data-ingestion/import_to_retail_search.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Importing the converted Flipkart dataset to Google Cloud Search for Retail.
"""
import argparse
import logging
import os
import time

from google.cloud import retail_v2, storage
from google.cloud.retail_v2.types import GcsSource, ImportErrorsConfig as ErrorsConfig

logging.basicConfig(level=logging.INFO)


def split_jsonl(input_file:str, output_prefix:str, max_lines:int=500) -> list[str]:
    """
    Splits a large JSONL file into smaller files.

    Args:
        input_file: Path to the input JSONL file.
        output_prefix: Prefix for the output files (e.g., 'output_').
        max_lines: Maximum number of lines per output file.

    Returns:
        List of output file names.
    """
    file_names = []
    folder = os.path.dirname(input_file)
    with open(input_file, "r", encoding="utf-8") as infile:
        lines = infile.readlines()

    chunks = [
        lines[i:i + max_lines] for i in range(0, len(lines), max_lines)
    ]
    for index, chunk in enumerate(chunks):
        fn = f"{folder}/{output_prefix}{index}.jsonl"
        file_names.append(fn)
        with open(fn, "w", encoding="utf-8") as f:
            f.write("".join(chunk))
    return file_names


def upload_dataset_to_gcs(
        gcs_bucket:str, project_id:str, input_file:str) -> list[str]:
    """
    Upload Search for Retail dataset files to a GCS bucket.

    Args:
        gcs_bucket (str): Target GCS Bucket name.
        project_id (str): Google Cloud Project ID.
        input_file (str): File to be uploaded.

    Returns:
        List of GCS URLs of the uploaded files.
    """
    files = split_jsonl(
        input_file=input_file,
        output_prefix="flipkart-retail-search-"
    )
    client = storage.Client(project=project_id)
    bucket = client.get_bucket(gcs_bucket)
    gcs = []
    for file in files:
        fn = file.split("/")[-1]
        blob = bucket.blob(fn)
        blob.upload_from_filename(file)
        gcs.append(f"gs://{gcs_bucket}/{fn}")
    return gcs


def set_default_branch(project_number:str, branch:str="0"):
    """
    Set the default branch of the Search for Retail service.

    Args:
        project_number (str): Google Cloud Project number.
        branch (str): Branch ID. Must be one of 0, 1, 2.
    """
    client = retail_v2.CatalogServiceClient()
    request = retail_v2.SetDefaultBranchRequest(
        catalog=f"projects/{project_number}/locations/global/" + \
            "catalogs/default_catalog",
        branch_id=branch
    )
    client.set_default_branch(request=request)


def update_product_level():
    """
    Update the product level configuration of the default catalog.
    """
    # Update Product Level before importing data
    # https://cloud.google.com/retail/docs/upload-catalog#json
    client = retail_v2.CatalogServiceClient()
    catalog = retail_v2.Catalog()
    catalog.name = "default_catalog"
    catalog.display_name = "default_catalog"
    catalog.product_level_config = retail_v2.types.ProductLevelConfig()
    request = retail_v2.UpdateCatalogRequest(
        catalog=catalog,
    )
    response = client.update_catalog(request=request)
    logging.info("Update Catalog: %s", response)


def prepare_arguments() -> dict:
    """
    Configure and parse commandline arguments.

    Returns:
        A dict holding the commandline arguments.
""" parser = argparse.ArgumentParser( description="Converting Flipkart dataset to " + \ "Search for Retail data format." ) parser.add_argument("-i", "--input", help="Search for Retail data file path.", required=True) parser.add_argument("-g", "--gcs-bucket", help="Search for Retail import GCS bucket name.", required=True) parser.add_argument("-n", "--project-number", help="Search for Retail Project number.", required=True) parser.add_argument("-b", "--branch", help="Search for Retail Branch.", required=True) parser.add_argument("--set-default-branch", dest="set_default_branch", action="store_true") parser.add_argument("--no-set-default-branch", dest="set_default_branch", action="store_false") parser.set_defaults(set_default_branch=False) args = vars(parser.parse_args()) return { "input_file": args["input"], "gcs_bucket": args["gcs_bucket"], "project_number": args["project_number"], "branch": args["branch"], "set_default_branch": args["set_default_branch"] } def import_products(gcs_errors_path:str, gcs_url:str, project_number:str, branch:str) -> None: """ Import products to Search for Retail. Args: gcs_errors_path (str): GCS path to store import errors. gcs_url (str): GCS URL of the input file. project_number (str): Google Cloud Project number. branch (str): Retail Search Branch Id. """ # Create a client client = retail_v2.ProductServiceClient() # Initialize request argument(s) input_config = retail_v2.ProductInputConfig( gcs_source=GcsSource(input_uris=[gcs_url]) ) request = retail_v2.ImportProductsRequest( parent=(f"projects/{project_number}/locations/global/" f"catalogs/default_catalog/branches/{branch}"), input_config=input_config, errors_config=ErrorsConfig(gcs_prefix=gcs_errors_path) ) # Make the request operation = client.import_products(request=request) response = operation.result() logging.info(response) if __name__ == "__main__": params = prepare_arguments() gcs_files = upload_dataset_to_gcs( params["gcs_bucket"], params["project_number"], params["input_file"]) for gcs_file in gcs_files: logging.info("* Processing %s", gcs_file) import_products( gcs_errors_path=f"""gs://{params["gcs_bucket"]}/errors""", gcs_url=gcs_file, project_number=params["project_number"], branch=params["branch"]) time.sleep(2) if params["set_default_branch"]: set_default_branch(params["project_number"], params["branch"])