utils/prompt_loader.py (52 lines of code) (raw):
# utils/prompt_loader.py
from pathlib import Path
from jinja2 import Environment, FileSystemLoader, TemplateNotFound
class PromptLoader:
def __init__(self):
self.env = Environment(
loader=FileSystemLoader(Path(__file__).parent.parent / 'prompt_templates/sql_generation'),
trim_blocks=True,
lstrip_blocks=True
)
def get_prompt(
self,
db_type: str,
context: dict,
limit: int = 100,
user_custom_prompt: str | None = None # New Custom Parameters
) -> str:
try:
template = self.env.get_template(f"{db_type.lower()}_prompt.jinja")
except TemplateNotFound:
template = self.env.get_template("base_prompt.jinja")
# Inject custom prompts into the context
context.update({
'limit_clause': self._get_limit_clause(db_type),
'optimization_rules': self._get_optimization_rules(db_type),
'user_custom_prompt': user_custom_prompt,
'limit': limit
})
return template.render(context)
def _get_limit_clause(self, db_type: str) -> str:
clauses = {
'mysql': "LIMIT n",
'oracle': "ROWNUM <= n",
'sqlserver': "TOP n",
'hologres': "LIMIT n"
}
return clauses.get(db_type.lower(), "LIMIT 100")
def _get_optimization_rules(self, db_type: str) -> str:
rules = {
'hologres': "- 分析使用EXPLAIN ANALYZE获得的执行计划"
}
return rules.get(db_type.lower(), "")
def test_prompt_loading():
loader = PromptLoader()
# TEST MySQL
mysql_context = {
'meta_data': 'mock_metadata',
'query': 'mock_query',
'db_type': 'hologres'
}
mysql_prompt = loader.get_prompt('mysql', mysql_context)
print("MySQL Prompt Output:\n", mysql_prompt)
assert "LIMIT n" in mysql_prompt
if __name__ == '__main__':
test_prompt_loading()