in nl2sql_src/metadata_update_bq.py [0:0]
def generate_metadata(project_id,
location,
dataset_name,
model_name="text-bison-32k",
model_parameters=None):
vertexai.init(project=project_id, location=location)
parameters = {
"max_output_tokens": 8192,
"temperature": 0.9,
"top_p": 1
}
model = TextGenerationModel.from_pretrained(model_name)
client = bigquery.Client(project=project_id)
data = {}
for table_ref in client.list_tables(dataset_name):
try:
table = client.get_table(table_ref)
except Exception as e:
print(f"Error getting table {table_ref}: {e}")
continue
table_name = table.reference.table_id
table_description = table.description if table.description else ""
if not table_description:
prompt = f"""Describe the data in table '{table_name}'.
Do not include any markdown symmbols or special characters
while generating the description"""
response = model.predict(prompt=prompt)
table_description = response.text.strip()
table_description = table_description.replace('**', '')
columns = {}
schema = table.schema
prompt = f"Describe the following columns in table '{table_name}'"
for field in schema:
prompt += f"- Column Name: {field.name}, Type: {field.field_type}"
prompt += "\nPlease generate response in 'column name : Description'"
prompt += " format. Do not include any markdown symmbols or special "
prompt += "characters while generating the description"
response = model.predict(prompt=prompt, **parameters)
print(response)
response.text = response.text.replace('**', '')
column_descriptions = [
line for line in response.text.strip().split("\n") if line.strip()
]
for i, field in enumerate(schema):
if i < len(column_descriptions):
column_description = column_descriptions[i]
else:
column_description = "" # Handle missing descriptions
column_dict = {
'Name': field.name,
'Type': field.field_type,
'Description': column_description,
'Examples': f"Sample value for {field.name}"
}
columns[field.name] = column_dict
data[table_name] = {
"Name": table_name,
"Description": table_description,
"Columns": columns
}
return data