demo-python/code/embeddings/cohere-embeddings/scripts/setup.py (250 lines of code) (raw):
import asyncio
from azure.ai.ml import MLClient
from azure.core.credentials import AzureNamedKeyCredential
from azure.mgmt.storage.aio import StorageManagementClient
from azure.identity.aio import AzureCliCredential
import azure.identity
from azure.storage.blob.aio import BlobServiceClient
from azure.search.documents.indexes.aio import SearchIndexClient, SearchIndexerClient
from azure.search.documents.indexes.models import (
SearchIndex,
SearchIndexerSkillset,
SearchIndexer,
SearchIndexerDataSourceConnection,
SearchIndexerDataContainer,
SearchIndexerDataSourceType,
SearchIndexerIndexProjections,
SearchIndexerIndexProjectionSelector,
SearchIndexerIndexProjectionsParameters,
IndexProjectionMode,
AzureMachineLearningSkill,
SplitSkill,
InputFieldMappingEntry,
OutputFieldMappingEntry,
SearchFieldDataType,
ScalarQuantizationCompressionConfiguration,
VectorSearchProfile,
VectorSearch,
SearchField,
SearchableField,
SimpleField,
AzureMachineLearningVectorizer,
AzureMachineLearningParameters,
HnswAlgorithmConfiguration,
LexicalAnalyzerName,
SemanticConfiguration,
SemanticField,
SemanticPrioritizedFields,
SemanticSearch
)
import os
import glob
# Absolute directory that contains this script; the sample data is located relative to it.
current_file_directory = os.path.dirname(os.path.abspath(__file__))
# Climb five directory levels from the script, then descend into data/benefitdocs.
samples_path = os.path.join(
    current_file_directory,
    *([".."] * 5),
    "data",
    "benefitdocs",
)
def create_credential():
    """Return an async Azure CLI credential for the async clients in this script.

    The tenant comes from the ``AZURE_TENANT_ID`` environment variable; when it
    is unset, ``None`` is passed as ``tenant_id``.
    """
    tenant = os.getenv("AZURE_TENANT_ID", None)
    return AzureCliCredential(tenant_id=tenant)
def create_sync_credential():
    """Return a synchronous Azure CLI credential (``azure.identity``, not ``azure.identity.aio``).

    The tenant comes from the ``AZURE_TENANT_ID`` environment variable; when it
    is unset, ``None`` is passed as ``tenant_id``.
    """
    tenant = os.getenv("AZURE_TENANT_ID", None)
    return azure.identity.AzureCliCredential(tenant_id=tenant)
async def main():
    """Run the whole setup: upload sample PDFs, then create index, skillset, data source and indexer.

    Progress is reported on stdout before each step. The search endpoint is read
    from ``AZURE_SEARCH_ENDPOINT``.
    """
    async with create_credential() as credential:
        # Step 1: push the sample documents into blob storage.
        print("Uploading sample documents...")
        storage_connection_string = await get_storage_connection_string()
        async with BlobServiceClient.from_connection_string(conn_str=storage_connection_string) as blob_service_client:
            await upload_documents(blob_service_client)

        # Step 2: create the index the indexer will write into.
        print("Creating index...")
        search_endpoint = os.getenv("AZURE_SEARCH_ENDPOINT")
        async with SearchIndexClient(endpoint=search_endpoint, credential=credential) as search_index_client:
            await create_index(search_index_client)

        # Step 3: skillset, data source and indexer all go through one indexer client.
        async with SearchIndexerClient(endpoint=search_endpoint, credential=credential) as search_indexer_client:
            print("Creating skillset...")
            await create_skillset(search_indexer_client)

            print("Creating datasource...")
            await create_datasource(search_indexer_client)

            print("Creating indexer...")
            await create_indexer(search_indexer_client)

    print("Done")
async def upload_documents(blob_service_client: BlobServiceClient):
    """Upload every ``*.pdf`` under ``samples_path`` to the configured blob container.

    The container name is read from ``AZURE_STORAGE_CONTAINER``. Each blob is
    named after the file's base name, and files whose blob already exists are
    skipped.

    Args:
        blob_service_client: Async blob service client for the target storage account.
    """
    container = blob_service_client.get_container_client(os.getenv("AZURE_STORAGE_CONTAINER"))
    pdf_pattern = os.path.join(samples_path, "*.pdf")
    for pdf_path in glob.glob(pdf_pattern):
        blob_name = os.path.basename(pdf_path)
        async with container.get_blob_client(blob_name) as blob:
            # Leave blobs that are already present untouched.
            if await blob.exists():
                continue
            with open(pdf_path, "rb") as pdf_file:
                await blob.upload_blob(data=pdf_file)
async def get_storage_credential():
    """Fetch the storage account's first access key and wrap it in a named-key credential.

    The subscription, resource group and account name come from
    ``AZURE_SUBSCRIPTION_ID``, ``AZURE_RESOURCE_GROUP`` and
    ``AZURE_STORAGE_ACCOUNT``.

    Returns:
        An ``AzureNamedKeyCredential`` named ``"key1"`` holding the value of the
        first key returned by the storage management API.
    """
    async with create_credential() as credential:
        subscription = os.getenv("AZURE_SUBSCRIPTION_ID")
        async with StorageManagementClient(credential=credential, subscription_id=subscription) as storage_client:
            key_listing = await storage_client.storage_accounts.list_keys(
                resource_group_name=os.getenv("AZURE_RESOURCE_GROUP"),
                account_name=os.getenv("AZURE_STORAGE_ACCOUNT"),
            )
            first_key_value = key_listing.keys[0].value
            return AzureNamedKeyCredential(name="key1", key=first_key_value)
async def get_storage_connection_string():
    """Build an account-key connection string for the configured storage account.

    The account name is read from ``AZURE_STORAGE_ACCOUNT`` and the key is
    obtained through ``get_storage_credential()``. The endpoint suffix is fixed
    to ``core.windows.net``.

    Returns:
        The connection string, including its trailing semicolon.
    """
    storage_credential = await get_storage_credential()
    _key_name, account_key = storage_credential.named_key
    account_name = os.getenv('AZURE_STORAGE_ACCOUNT')
    return (
        "DefaultEndpointsProtocol=https;"
        f"AccountName={account_name};"
        f"AccountKey={account_key};"
        "EndpointSuffix=core.windows.net;"
    )
def get_serverless_deployment():
    """Look up connection details for the serverless embedding endpoint.

    Reads ``AZUREAI_SERVERLESS_ENDPOINT_NAME`` and ``AZUREAI_SERVERLESS_MODEL``
    from the environment, then queries the Azure ML workspace identified by
    ``AZURE_SUBSCRIPTION_ID``, ``AZURE_RESOURCE_GROUP`` and
    ``AZUREAI_PROJECT_NAME`` for the endpoint's scoring URI and primary key.

    Returns:
        A ``(scoring_uri, authentication_key, model_name)`` tuple, where
        ``model_name`` is the last ``/``-separated segment of
        ``AZUREAI_SERVERLESS_MODEL``.

    Raises:
        ValueError: If ``AZUREAI_SERVERLESS_ENDPOINT_NAME`` or
            ``AZUREAI_SERVERLESS_MODEL`` is unset or empty.
    """
    endpoint_name = os.getenv("AZUREAI_SERVERLESS_ENDPOINT_NAME")
    model_id = os.getenv("AZUREAI_SERVERLESS_MODEL")

    # Validate before building any client so a missing setting fails fast with a
    # clear message instead of an AttributeError after the network round trips.
    required = (
        ("AZUREAI_SERVERLESS_ENDPOINT_NAME", endpoint_name),
        ("AZUREAI_SERVERLESS_MODEL", model_id),
    )
    missing = [variable for variable, value in required if not value]
    if missing:
        raise ValueError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    workspace_ml_client = MLClient(
        create_sync_credential(),
        subscription_id=os.getenv("AZURE_SUBSCRIPTION_ID"),
        resource_group_name=os.getenv("AZURE_RESOURCE_GROUP"),
        workspace_name=os.getenv("AZUREAI_PROJECT_NAME")
    )
    serverless_endpoints = workspace_ml_client.serverless_endpoints
    scoring_uri = serverless_endpoints.get(name=endpoint_name).scoring_uri
    authentication_key = serverless_endpoints.get_keys(name=endpoint_name).primary_key
    # Keep only the trailing segment of the "/"-separated model identifier.
    model_name = model_id.split("/")[-1]
    return (scoring_uri, authentication_key, model_name)
async def create_index(search_index_client: SearchIndexClient):
    """Create or update the search index that holds text chunks and their embeddings.

    The index is named by ``AZURE_SEARCH_INDEX`` and contains:

    * ``id`` (key) and ``document_id``: filterable strings using the keyword analyzer.
    * ``content``: searchable chunk text.
    * ``embedding``: 1024-dimension vector field, not stored, bound to the
      ``approximateProfile`` vector profile.
    * ``metadata_storage_path``: filterable string.

    The vector profile ties together an HNSW algorithm configuration, a scalar
    quantization compression configuration and an Azure ML vectorizer pointing
    at the serverless endpoint returned by ``get_serverless_deployment()``.
    A semantic configuration prioritising ``content`` is set as the default.

    Args:
        search_index_client: Async client used to submit the index definition.
    """
    scoring_uri, authentication_key, model_name = get_serverless_deployment()

    # Names that must match between the vector field, the profile and its parts.
    profile_name = "approximateProfile"
    algorithm_name = "approximateConfiguration"
    vectorizer_name = "cohere"
    compression_name = "scalarQuantization"
    semantic_config_name = "semantic-config"

    id_field = SearchableField(
        name="id",
        type=SearchFieldDataType.String,
        key=True,
        filterable=True,
        analyzer_name=LexicalAnalyzerName.KEYWORD
    )
    document_id_field = SearchableField(
        name="document_id",
        type=SearchFieldDataType.String,
        key=False,
        filterable=True,
        analyzer_name=LexicalAnalyzerName.KEYWORD
    )
    content_field = SearchField(
        name="content",
        type=SearchFieldDataType.String,
        searchable=True,
        filterable=False,
        facetable=False,
        sortable=False
    )
    embedding_field = SearchField(
        name="embedding",
        type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
        searchable=True,
        stored=False,
        vector_search_dimensions=1024,
        vector_search_profile_name=profile_name
    )
    storage_path_field = SimpleField(
        name="metadata_storage_path",
        type=SearchFieldDataType.String,
        filterable=True
    )

    cohere_vectorizer = AzureMachineLearningVectorizer(
        name=vectorizer_name,
        aml_parameters=AzureMachineLearningParameters(
            scoring_uri=scoring_uri,
            authentication_key=authentication_key,
            model_name=model_name
        )
    )
    vector_search = VectorSearch(
        profiles=[
            VectorSearchProfile(
                name=profile_name,
                algorithm_configuration_name=algorithm_name,
                vectorizer=vectorizer_name,
                compression_configuration_name=compression_name
            )
        ],
        algorithms=[HnswAlgorithmConfiguration(name=algorithm_name)],
        vectorizers=[cohere_vectorizer],
        compressions=[ScalarQuantizationCompressionConfiguration(name=compression_name)]
    )

    semantic_search = SemanticSearch(
        default_configuration_name=semantic_config_name,
        configurations=[
            SemanticConfiguration(
                name=semantic_config_name,
                prioritized_fields=SemanticPrioritizedFields(
                    content_fields=[SemanticField(field_name="content")]
                )
            )
        ]
    )

    index = SearchIndex(
        name=os.getenv("AZURE_SEARCH_INDEX"),
        fields=[
            id_field,
            document_id_field,
            content_field,
            embedding_field,
            storage_path_field,
        ],
        vector_search=vector_search,
        semantic_search=semantic_search
    )
    await search_index_client.create_or_update_index(index)
async def create_skillset(search_indexer_client: SearchIndexerClient):
    """Create or update the skillset that chunks documents and embeds each chunk.

    The skillset is named by ``AZURE_SEARCH_SKILLSET`` and holds two skills:

    * A split skill that breaks ``/document/content`` into pages of at most
      2000 characters with 500 characters of overlap, written to ``/document/pages``.
    * An Azure ML skill that posts each page to ``<scoring_uri>/v1/embed`` and
      stores the response's ``embeddings`` output as ``aml_vector_object``.

    Index projections write one search document per page into the index named
    by ``AZURE_SEARCH_INDEX``, keyed to the parent through ``document_id``, and
    skip indexing of the parent documents themselves.

    Args:
        search_indexer_client: Async client used to submit the skillset definition.
    """
    scoring_uri, authentication_key, _ = get_serverless_deployment()
    target_index = os.getenv("AZURE_SEARCH_INDEX")
    page_context = "/document/pages/*"

    embedding_inputs = [
        InputFieldMappingEntry(name="texts", source="=[$(/document/pages/*)]"),
        InputFieldMappingEntry(name="input_type", source="='search_document'"),
        # NOTE(review): 'NONE' is sent as the truncate option, so nothing here asks
        # the endpoint to trim long inputs. Presumably over-length chunks are
        # rejected rather than truncated; confirm against the Cohere embed API.
        InputFieldMappingEntry(name="truncate", source="='NONE'"),
        InputFieldMappingEntry(name="embedding_types", source="=['float']")
    ]
    embedding_skill = AzureMachineLearningSkill(
        description="Skill to generate embeddings via Cohere",
        context=page_context,
        scoring_uri=f"{scoring_uri}/v1/embed",
        authentication_key=authentication_key,
        inputs=embedding_inputs,
        outputs=[
            OutputFieldMappingEntry(name="embeddings", target_name="aml_vector_object")
        ]
    )

    split_skill = SplitSkill(
        description="Split skill to chunk documents",
        text_split_mode="pages",
        context="/document",
        maximum_page_length=2000,
        page_overlap_length=500,
        inputs=[
            InputFieldMappingEntry(name="text", source="/document/content"),
        ],
        outputs=[
            OutputFieldMappingEntry(name="textItems", target_name="pages")
        ]
    )

    # Each page becomes its own search document carrying vector, text and source path.
    page_mappings = [
        InputFieldMappingEntry(
            name="embedding",
            source="/document/pages/*/aml_vector_object/float/0"
        ),
        InputFieldMappingEntry(
            name="content",
            source="/document/pages/*"
        ),
        InputFieldMappingEntry(
            name="metadata_storage_path",
            source="/document/metadata_storage_path"
        )
    ]
    projections = SearchIndexerIndexProjections(
        selectors=[
            SearchIndexerIndexProjectionSelector(
                target_index_name=target_index,
                parent_key_field_name="document_id",
                source_context=page_context,
                mappings=page_mappings
            )
        ],
        parameters=SearchIndexerIndexProjectionsParameters(
            projection_mode=IndexProjectionMode.SKIP_INDEXING_PARENT_DOCUMENTS
        )
    )

    skillset = SearchIndexerSkillset(
        name=os.getenv("AZURE_SEARCH_SKILLSET"),
        skills=[embedding_skill, split_skill],
        index_projections=projections
    )
    await search_indexer_client.create_or_update_skillset(skillset)
async def create_datasource(search_indexer_client: SearchIndexerClient):
    """Create or update the Azure Blob data source connection read by the indexer.

    The connection is named by ``AZURE_SEARCH_DATASOURCE``, targets the
    container named by ``AZURE_STORAGE_CONTAINER`` and authenticates with the
    connection string from ``get_storage_connection_string()``.

    Args:
        search_indexer_client: Async client used to submit the data source definition.
    """
    datasource_name = os.getenv("AZURE_SEARCH_DATASOURCE")
    storage_connection_string = await get_storage_connection_string()
    blob_container = SearchIndexerDataContainer(name=os.getenv("AZURE_STORAGE_CONTAINER"))

    datasource = SearchIndexerDataSourceConnection(
        name=datasource_name,
        type=SearchIndexerDataSourceType.AZURE_BLOB,
        connection_string=storage_connection_string,
        container=blob_container
    )
    await search_indexer_client.create_or_update_data_source_connection(datasource)
async def create_indexer(search_indexer_client: SearchIndexerClient):
    """Create or update the indexer that links the data source, skillset and index.

    All four resource names are read from the environment:
    ``AZURE_SEARCH_INDEXER``, ``AZURE_SEARCH_DATASOURCE``,
    ``AZURE_SEARCH_INDEX`` and ``AZURE_SEARCH_SKILLSET``.

    Args:
        search_indexer_client: Async client used to submit the indexer definition.
    """
    indexer_settings = {
        "name": os.getenv("AZURE_SEARCH_INDEXER"),
        "data_source_name": os.getenv("AZURE_SEARCH_DATASOURCE"),
        "target_index_name": os.getenv("AZURE_SEARCH_INDEX"),
        "skillset_name": os.getenv("AZURE_SEARCH_SKILLSET"),
    }
    indexer = SearchIndexer(**indexer_settings)
    await search_indexer_client.create_or_update_indexer(indexer)
if __name__ == "__main__":
    # Run the async setup only when executed as a script, not when imported.
    asyncio.run(main())