migrate.py (150 lines of code) (raw):

"""Copy an Azure AI Search index to a target index, re-embedding vector fields.

The script reads the source index definition, recreates it on the target
service using the vector-search configuration from ``vectors.json`` (with
vectorizer API keys injected from the environment), then pages through every
document of the source index, embeds the text fields listed in
``vector_mapping.json`` into their mapped vector fields, uploads the documents
to the target index, and finally compares source and target document counts.
"""
import json
import os
import time

import tqdm
from azure.ai.inference import EmbeddingsClient
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import SearchIndex, VectorSearch
from dotenv import load_dotenv

# Value passed as ``top`` to every document query. The script warns that without
# a sortable+filterable key only this many records can be copied, so larger
# indexes are read with key-ordered continuation queries.
MAX_RESULTS_PER_QUERY = 100000

load_dotenv("credentials.env")  # take environment variables from credentials.env
# Variables not used here do not need to be set in your .env file.


def _credential_from_env(key_var):
    """Return an AzureKeyCredential if env var ``key_var`` is non-empty, else DefaultAzureCredential."""
    api_key = os.environ.get(key_var, "")
    return AzureKeyCredential(api_key) if api_key else DefaultAzureCredential()


source_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"]
source_credential = _credential_from_env("AZURE_SEARCH_ADMIN_KEY")
source_index_name = os.environ["AZURE_SEARCH_INDEX"]

# Default to the source service when no target endpoint is configured
# (i.e. copy the index within the same service).
target_endpoint = os.environ.get("AZURE_TARGET_SEARCH_SERVICE_ENDPOINT") or source_endpoint
target_credential = _credential_from_env("AZURE_TARGET_SEARCH_ADMIN_KEY")
target_index_name = os.environ["AZURE_TARGET_SEARCH_INDEX"]


def initialize_embedding_model():
    """Create the EmbeddingsClient used to embed source text.

    Endpoints whose URL contains "openai" are treated as Azure OpenAI and are
    given ``AZURE_AI_EMBEDDINGS_API_VERSION`` as ``api_version``; any other
    endpoint is created without an explicit API version.
    """
    endpoint = os.environ["AZURE_AI_EMBEDDINGS_ENDPOINT"]
    client_kwargs = {
        "endpoint": endpoint,
        "credential": AzureKeyCredential(os.environ["AZURE_AI_EMBEDDINGS_KEY"]),
    }
    if "openai" in endpoint:
        print("initializing embeddings client AOAI")
        client_kwargs["api_version"] = os.environ["AZURE_AI_EMBEDDINGS_API_VERSION"]
    else:
        print("initializing embeddings client not AOAI")
    return EmbeddingsClient(**client_kwargs)


def get_embedding(embeddings_model, text):
    """Return the embedding vector for a single piece of ``text``."""
    response = embeddings_model.embed(input=[text])
    return response.data[0]["embedding"]


def create_clients(endpoint, credential, index_name):
    """Return a (SearchClient, SearchIndexClient) pair for ``index_name`` on ``endpoint``."""
    search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)
    index_client = SearchIndexClient(endpoint=endpoint, credential=credential)
    return search_client, index_client


def total_count(search_client):
    """Return the total number of documents matched by a match-all query."""
    response = search_client.search(include_total_count=True, search_text="*", top=0)
    return response.get_count()


def search_results_with_filter(search_client, key_field_name):
    """Yield all documents, page by page, ordered by ``key_field_name``.

    Each query is capped at MAX_RESULTS_PER_QUERY results, so once a query's
    pages are exhausted a follow-up query fetches only documents whose key is
    greater than the last key seen. Iteration stops when a query yields no
    documents. Empty pages are never yielded. The key field must be sortable
    and filterable.
    """
    key_filter = None  # no filter on the first query
    while True:
        response = search_client.search(
            search_text="*",
            top=MAX_RESULTS_PER_QUERY,
            order_by=key_field_name,
            filter=key_filter,
        ).by_page()
        # Reset per query: if this query returns nothing we must stop, not
        # re-issue the previous filter forever.
        last_item = None
        for page in response:
            page = list(page)
            if page:
                last_item = page[-1]
                yield page
        if last_item is None:
            break
        key_filter = f"{key_field_name} gt '{last_item[key_field_name]}'"


def search_results_without_filter(search_client):
    """Yield pages from a single match-all query (at most MAX_RESULTS_PER_QUERY documents)."""
    for page in search_client.search(search_text="*", top=MAX_RESULTS_PER_QUERY).by_page():
        yield list(page)


def add_api_key(vectors):
    """Inject vectorizer API keys from the environment into a vector-search config dict.

    ``azureOpenAI`` vectorizers receive env var ``azureOpenAI_API_KEY`` as
    ``azureOpenAIParameters.apiKey``; ``customWebApi`` vectorizers receive
    ``customWebApi_API_KEY`` as the ``x-functions-key`` HTTP header. Other kinds
    are left unchanged. The dict is modified in place and also returned.
    """
    for vectorizer in vectors.get("vectorizers") or []:
        kind = vectorizer.get("kind")
        if kind == "azureOpenAI":
            vectorizer["azureOpenAIParameters"]["apiKey"] = os.environ["azureOpenAI_API_KEY"]
        elif kind == "customWebApi":
            params = vectorizer["customWebApiParameters"]
            headers = params.get("httpHeaders") or {}
            headers["x-functions-key"] = os.environ["customWebApi_API_KEY"]
            params["httpHeaders"] = headers
    return vectors


def _load_json(path):
    """Read and parse the JSON file at ``path``, closing the file afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _prepare_target_fields(source_fields, vector_mapping):
    """Return ``(target_fields, key_field, non_retrievable_fields)`` for the target index.

    Vector fields whose name appears as a ``target`` in ``vector_mapping`` get
    their ``vector_search_dimensions`` replaced by that entry's ``vector_length``
    (the source field objects are modified in place).

    Raises:
        Exception: if no field is marked as the key.
    """
    key_field = None
    non_retrievable_fields = []
    # Later mapping entries win, matching a sequential overwrite.
    vector_lengths = {entry["target"]: entry["vector_length"] for entry in vector_mapping}
    for field in source_fields:
        if field.hidden:
            non_retrievable_fields.append(field)
        if field.key:
            key_field = field
        if field.vector_search_dimensions is not None and field.name in vector_lengths:
            field.vector_search_dimensions = vector_lengths[field.name]
    if key_field is None:
        raise Exception("Key Field Not Found")
    return list(source_fields), key_field, non_retrievable_fields


def _embed_document(document, vector_mapping, embeddings_model):
    """Set each mapped ``target`` field of ``document`` to the embedding of its ``source`` field."""
    for entry in vector_mapping:
        document[entry["target"]] = get_embedding(embeddings_model, document[entry["source"]])
    return document


def backup_and_restore_index(source_endpoint, source_key, source_index_name, target_endpoint, target_key, target_index_name):
    """Recreate the source index on the target and copy all documents with fresh embeddings.

    Returns:
        (source_search_client, target_search_client, all_documents) where
        ``all_documents`` is the (already consumed) page generator.
    """
    source_search_client, source_index_client = create_clients(source_endpoint, source_key, source_index_name)
    target_search_client, target_index_client = create_clients(target_endpoint, target_key, target_index_name)

    # Target vector profiles, with vectorizer API keys taken from the environment.
    vector_search = VectorSearch.from_dict(add_api_key(_load_json("vectors.json")))
    # Each entry maps a text "source" field to a vector "target" field (plus "vector_length").
    vector_mapping = _load_json("vector_mapping.json")
    embeddings_model = initialize_embedding_model()

    source_index = source_index_client.get_index(name=source_index_name)
    target_fields, key_field, non_retrievable_fields = _prepare_target_fields(source_index.fields, vector_mapping)

    if non_retrievable_fields:
        print(f"WARNING: The following fields are not marked as retrievable and cannot be backed up and restored: {', '.join(f.name for f in non_retrievable_fields)}")

    target_index = SearchIndex(
        name=target_index_name,
        fields=target_fields,
        vector_search=vector_search,
        semantic_search=source_index.semantic_search,
    )
    target_index_client.create_or_update_index(target_index)

    document_count = total_count(source_search_client)
    can_use_filter = key_field.sortable and key_field.filterable
    if not can_use_filter:
        print("WARNING: The key field is not filterable or not sortable. A maximum of 100,000 records can be backed up and restored.")

    all_documents = (
        search_results_with_filter(source_search_client, key_field.name)
        if can_use_filter
        else search_results_without_filter(source_search_client)
    )
    print("Backing up and restoring documents:")
    failed_keys = []
    with tqdm.tqdm(total=document_count) as progress_bar:
        for page in all_documents:
            if not page:  # nothing to upload; avoid sending an empty batch
                continue
            new_page = [_embed_document(document, vector_mapping, embeddings_model) for document in page]
            results = target_search_client.upload_documents(documents=new_page)
            progress_bar.update(len(results))
            for item in results:
                if not item.succeeded:
                    # IndexingResult carries the document key and error text directly.
                    failed_keys.append(item.key)
                    print(f"Document upload error: {item.error_message}")

    if failed_keys:
        print(f"Failed documents: {len(failed_keys)}")
        print(f"Failed document keys: {failed_keys}")
    else:
        print("All documents uploaded successfully.")

    print(f"Successfully backed up '{source_index_name}' and restored to '{target_index_name}'")
    return source_search_client, target_search_client, all_documents


def verify_counts(source_search_client, target_search_client):
    """Print the source and target document counts and whether they match."""
    source_document_count = source_search_client.get_document_count()
    # Presumably gives the target index time to reflect the recent uploads before counting.
    time.sleep(10)
    target_document_count = target_search_client.get_document_count()
    print(f"Source document count: {source_document_count}")
    print(f"Target document count: {target_document_count}")
    if source_document_count == target_document_count:
        print("Document counts match.")
    else:
        print("Document counts do not match.")


if __name__ == "__main__":
    print("Starting migration script")
    source_search_client, target_search_client, all_documents = backup_and_restore_index(
        source_endpoint, source_credential, source_index_name,
        target_endpoint, target_credential, target_index_name,
    )
    # Call the verify_counts function with the search_clients returned by the backup_and_restore_index function
    verify_counts(source_search_client, target_search_client)
    print("Migration script completed")