genai-for-marketing/backend_apis/app/utils_codey.py (109 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility module for the Codey-related text-to-SQL demo.
"""
from google.cloud import bigquery
def get_tags_from_table(
    datacatalog_client,
    dataset_id: str,
    table_id: str,
    project_id: str,
    tag_template_name: str
):
    """Describes a BigQuery table's columns using its Data Catalog tags.

    Looks up the Data Catalog entry linked to the table and lists its tags.
    For each tag whose ``template`` equals ``tag_template_name``, emits one
    line with the column name, data type, primary/foreign-key flags and
    description read from the tag's fields.

    Args:
        datacatalog_client:
            Data Catalog client exposing ``lookup_entry`` and ``list_tags``.
        dataset_id:
            The ID of the BigQuery dataset that contains the table.
        table_id:
            The ID of the BigQuery table.
        project_id:
            The ID of the Google Cloud project.
        tag_template_name:
            Compared for exact equality with each tag's ``template`` field
            (presumably the template's full resource name - confirm).
    Returns:
        A string of newline-terminated lines, one per matching tag, or
        ``''`` when no tag matches. Assumes every matching tag defines the
        ``description``, ``data_type``, ``is_primary_key`` and
        ``is_foreign_key`` fields.
    """
    # Lookup Data Catalog's Entry referring to the table.
    resource_name = (
        f"//bigquery.googleapis.com/projects/{project_id}"
        f"/datasets/{dataset_id}/tables/{table_id}"
    )
    table_entry = datacatalog_client.lookup_entry(
        request={"linked_resource": resource_name}
    )
    full_table_name = f"`{project_id}.{dataset_id}.{table_id}`"
    # Collect the lines and join once instead of repeated string ``+=``.
    lines = []
    for tag in datacatalog_client.list_tags(parent=table_entry.name):
        if tag.template != tag_template_name:
            continue
        fields = tag.fields
        desc = fields["description"].string_value
        data_type = fields["data_type"].string_value
        pk = fields["is_primary_key"].bool_value
        fk = fields["is_foreign_key"].bool_value
        lines.append(
            f"Full table name: {full_table_name} "
            f"- Column: {tag.column} "
            f"- Data Type: {data_type} "
            f"- Primary Key: {pk} "
            f"- Foreign Key: {fk} "
            f"- Description: {desc}\n"
        )
    return "".join(lines)
def get_metadata_from_dataset(
    bqclient,
    datacatalog_client,
    query: str,
    project_id: str,
    dataset_id: str,
    tag_template_name: str
):
    """Builds a schema description for every table listed by ``query``.

    Args:
        bqclient:
            BigQuery client used to run ``query``.
        datacatalog_client:
            Data Catalog client, passed through to ``get_tags_from_table``.
        query:
            BigQuery query whose result rows expose a ``table_name`` column,
            one row per table to describe.
        project_id:
            The ID of the Google Cloud project containing the dataset.
        dataset_id:
            The ID of the BigQuery dataset.
        tag_template_name:
            Tag template whose tags describe each table's columns; passed
            through to ``get_tags_from_table``.
    Returns:
        A list with one string per result row. Each string is a
        ``[SCHEMA details for table `project.dataset.table`]`` header line
        followed by that table's tag lines.
    """
    rows = bqclient.query(query).result()  # API request
    metadata = []
    for row in rows:
        table_name = row.table_name
        full_table_path = f'`{project_id}.{dataset_id}.{table_name}`'
        header = f'[SCHEMA details for table {full_table_path}]\n'
        metadata.append(header + get_tags_from_table(
            datacatalog_client,
            dataset_id,
            table_name,
            project_id,
            tag_template_name))
    return metadata
def get_full_context_from_list(metadata: list):
    """Concatenates metadata strings into a single context string.

    Args:
        metadata: Strings (e.g. the list returned by
            ``get_metadata_from_dataset``), joined in order with no
            separator.
    Returns:
        The concatenation of all entries; ``''`` for an empty list.
    Raises:
        TypeError: If any entry is not a string.
    """
    # str.join is linear; repeated ``+=`` on strings can be quadratic.
    return "".join(metadata)
def generate_prompt(
    question: str,
    metadata: list,
    prompt: str,
    project_id: str
):
    """Builds the text-to-SQL prompt for a GoogleSQL (BigQuery) query.

    Args:
        question:
            The natural-language question, appended after ``[Q]:``.
        metadata:
            List of strings (one schema description per table), concatenated
            in order to form the context.
        prompt:
            Template filled with ``str.format``. The first positional
            placeholder receives the context, and up to seven further
            positional placeholders each receive ``project_id``.
        project_id:
            Google Cloud project ID substituted into the template.
    Returns:
        The formatted template, followed by the question on a ``[Q]:`` line
        and a trailing ``[SQL]:`` cue.
    """
    # The template may reference the project ID several times; supply it for
    # up to seven positional placeholders (str.format ignores unused extra
    # positional arguments).
    project_id_args = [project_id] * 7
    context = "".join(metadata)
    return (f"{prompt.format(context, *project_id_args)} \n"
            f"[Q]: {question} \n"
            "[SQL]:")
def _extract_sql(llm_text: str) -> str:
    """Strips Markdown code fences and any text before the first ``SELECT``."""
    code = llm_text.replace("```", "")
    start = code.find("SELECT")
    # str.find returns -1 when "SELECT" is absent; slicing from -1 would keep
    # only the last character, so fall back to the whole fence-stripped text.
    return code[start:] if start != -1 else code


def generate_sql_and_query(
    llm,
    datacatalog_client,
    prompt_template: str,
    query_metadata: str,
    question: str,
    project_id: str,
    dataset_id: str,
    tag_template_name: str,
    bqclient: bigquery.Client):
    """Generates a GoogleSQL query with an LLM and executes it on BigQuery.

    Steps:
        1. Collects table metadata with ``get_metadata_from_dataset``.
        2. Builds a prompt with ``generate_prompt``.
        3. Asks ``llm`` for a completion and extracts the SQL from its text.
        4. Runs that SQL with ``bqclient``.

    Args:
        llm:
            Model object exposing
            ``predict(prompt=..., max_output_tokens=..., temperature=...)``
            whose result has a ``text`` attribute.
        datacatalog_client:
            Data Catalog client used to read the column tags.
        prompt_template:
            Template passed to ``generate_prompt`` as ``prompt``.
        query_metadata:
            BigQuery query listing the tables to describe (rows with a
            ``table_name`` column).
        question:
            Natural-language question to translate into SQL.
        project_id:
            The ID of the BigQuery project.
        dataset_id:
            The ID of the BigQuery dataset.
        tag_template_name:
            The name of the tag template describing the columns.
        bqclient:
            A BigQuery client object.
    Returns:
        A tuple ``(result, gen_code, prompt)``:
            - result: the query rows as a list of dicts.
            - gen_code: the SQL text that was executed.
            - prompt: the prompt sent to the LLM.
    Raises:
        Errors from the Data Catalog, LLM and BigQuery clients propagate
        unchanged, e.g. ``google.api_core.exceptions.NotFound`` for a missing
        dataset/table or ``google.api_core.exceptions.BadRequest`` for
        invalid SQL.
    """
    metadata = get_metadata_from_dataset(
        bqclient=bqclient,
        datacatalog_client=datacatalog_client,
        query=query_metadata,
        project_id=project_id,
        dataset_id=dataset_id,
        tag_template_name=tag_template_name)
    prompt = generate_prompt(
        question,
        metadata,
        prompt_template,
        project_id)
    response = llm.predict(
        prompt=prompt,
        max_output_tokens=1024,
        temperature=0.3
    )
    gen_code = _extract_sql(response.text)
    # NOTE(review): LLM-generated SQL is executed verbatim. Ensure the
    # BigQuery credentials are read-only, or validate the statement, so a
    # generated DML/DDL statement cannot modify data.
    query_job = bqclient.query(gen_code)
    result = [dict(row.items()) for row in query_job]
    return result, gen_code, prompt