tools/hologres_excute_sql.py (112 lines of code) (raw):

import csv
import json
import re
from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from html import escape
from io import StringIO
from typing import Any

from dify_plugin import Tool
from dify_plugin.entities.tool import ToolInvokeMessage

from utils.alchemy_db_client import execute_sql

# Statement-leading keywords that the tool refuses to execute.
_RISK_KEYWORDS = frozenset({"DROP", "DELETE", "TRUNCATE", "ALTER", "UPDATE", "INSERT"})

# Patterns used by the risk check; compiled once at import time.
_BLOCK_COMMENT_RE = re.compile(r"/\*.*?\*/", re.DOTALL)
_LINE_COMMENT_RE = re.compile(r"--.*")
_STATEMENT_SPLIT_RE = re.compile(r";\s*")
_LEADING_KEYWORD_RE = re.compile(r"\s*([A-Za-z_]+)")


class HologresExcuteSqlTool(Tool):
    """Dify tool that runs one SQL string and returns the result as JSON, CSV, HTML or text."""

    def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
        """Validate the parameters, execute the SQL and yield the formatted result.

        Reads ``sql``, ``db_type``, ``host``, ``port``, ``db_name``, ``username``,
        ``password`` and the optional ``result_format`` (``json`` by default;
        ``csv`` and ``html`` produce blob messages; anything else produces a
        JSON-encoded text message) from ``tool_parameters``.

        Raises:
            ValueError: if the SQL is empty, starts with a blocked keyword, a
                connection parameter is missing, or executing/rendering fails.
        """
        sql = tool_parameters.get("sql")
        if not sql:
            raise ValueError("SQL statement cannot be empty")

        if self._contains_risk_commands(sql):
            raise ValueError("SQL statement contains risks")

        db_type = tool_parameters.get("db_type")
        host = tool_parameters.get("host")
        port = tool_parameters.get("port")
        database = tool_parameters.get("db_name")
        username = tool_parameters.get("username")
        password = tool_parameters.get("password")

        if not all([db_type, host, port, database, username, password]):
            raise ValueError("Database connection parameters cannot be empty")

        try:
            # Execute SQL statement (query or non-query)
            result = execute_sql(
                db_type, host, int(port), database, username, password, sql, ""
            )

            # Tell the user explicitly when nothing came back / nothing changed.
            if isinstance(result, list) and not result:
                yield self.create_text_message("No data found")
            elif isinstance(result, dict) and result.get("rowcount") == 0:
                yield self.create_text_message("No data affected")

            result_format = tool_parameters.get("result_format", "json")
            if result_format == "json":
                yield self.create_json_message({"status": "success", "result": result})
            elif result_format in ("csv", "html"):
                rows = self._as_rows(result)
                if rows:
                    render = self._handle_csv if result_format == "csv" else self._handle_html
                    yield from render(rows)
                elif not isinstance(result, list):
                    # An empty list was already reported above; anything else
                    # without rows (e.g. None) still needs a visible answer.
                    yield self.create_text_message("No data found")
            else:
                if result is not None:
                    message_text = json.dumps(
                        result,
                        ensure_ascii=False,
                        default=self._custom_serializer,
                    )
                else:
                    message_text = "No data found"
                yield self.create_text_message(message_text)
        except Exception as e:
            # Keep ValueError as the single failure type, but preserve the cause.
            raise ValueError(f"Database operation failed: {str(e)}") from e

    @staticmethod
    def _as_rows(data: Any) -> list[Any]:
        """Normalize a result into a list of rows for the table renderers.

        A dict (as checked for ``rowcount`` in ``_invoke``) becomes a one-row
        table, a list/tuple is used as-is, and anything else yields no rows.
        """
        if isinstance(data, dict):
            return [data]
        if isinstance(data, (list, tuple)):
            return list(data)
        return []

    def _handle_html(
        self, data: list[dict[str, Any]] | dict[str, Any] | None
    ) -> Generator[ToolInvokeMessage, None, None]:
        """Yield the result as an HTML table blob; yield nothing when there are no rows."""
        rows = self._as_rows(data)
        if not rows:
            return
        html_table = self._to_html_table(rows)
        yield self.create_blob_message(
            html_table.encode('utf-8'),
            meta={'mime_type': 'text/html', 'filename': 'result.html'},
        )

    def _handle_csv(
        self, data: list[dict[str, Any]] | dict[str, Any] | None
    ) -> Generator[ToolInvokeMessage, None, None]:
        """Yield the result as a CSV file blob; yield nothing when there are no rows.

        The header is taken from the keys of the first row. date/datetime
        values are written in ISO 8601 form.
        """
        rows = self._as_rows(data)
        if not rows:
            return

        output = StringIO(newline="")
        writer = csv.writer(output)
        writer.writerow(rows[0].keys())
        for row in rows:
            writer.writerow(
                [
                    self._custom_serializer(val) if isinstance(val, (date, datetime)) else val
                    for val in row.values()
                ]
            )

        # 'utf-8-sig' prepends the BOM itself; writing '\ufeff' into the buffer
        # as well would produce a double BOM and corrupt the first header name.
        yield self.create_blob_message(
            output.getvalue().encode('utf-8-sig'),
            meta={
                'mime_type': 'text/csv',
                'filename': 'result.csv',
                'encoding': 'utf-8-sig',
            },
        )

    def _html_cell(self, val: Any) -> str:
        """Return one cell value as HTML-escaped text (``None`` becomes an empty cell)."""
        if val is None:
            return ""
        if isinstance(val, (date, datetime, Decimal)):
            # Only these types are understood by _custom_serializer; it raises
            # TypeError for everything else, so plain values must bypass it.
            val = self._custom_serializer(val)
        return escape(str(val))

    def _to_html_table(self, data: list[dict]) -> str:
        """Build an HTML table string from a list of row dicts.

        Column headers come from the keys of the first row. Header and cell
        text is HTML-escaped so database content cannot inject markup.
        """
        parts = ["<table border='1'>"]
        if data:
            parts.append(
                "<tr>" + "".join(f"<th>{escape(str(col))}</th>" for col in data[0].keys()) + "</tr>"
            )
            for row in data:
                parts.append(
                    "<tr>"
                    + "".join(f"<td>{self._html_cell(val)}</td>" for val in row.values())
                    + "</tr>"
                )
        parts.append("</table>")
        return "".join(parts)

    def _contains_risk_commands(self, sql: str) -> bool:
        """Return True if any ``;``-separated statement starts with a blocked keyword.

        This is a best-effort check of each statement's first keyword only;
        it does not parse SQL and is not a substitute for database-side
        permissions.
        """
        # Replace comments with a space rather than nothing, so that
        # "DELETE/**/FROM t" cannot collapse into the unrecognized "DELETEFROM".
        sql = _BLOCK_COMMENT_RE.sub(' ', sql)
        sql = _LINE_COMMENT_RE.sub(' ', sql)

        for stmt in _STATEMENT_SPLIT_RE.split(sql):
            # Match only the leading run of letters so trailing punctuation
            # glued to the keyword does not hide it.
            match = _LEADING_KEYWORD_RE.match(stmt)
            if match and match.group(1).upper() in _RISK_KEYWORDS:
                return True
        return False

    def _custom_serializer(self, obj: Any) -> Any:
        """Convert date/datetime to an ISO 8601 string and Decimal to float.

        Intended as the ``default=`` hook for ``json.dumps``.

        Raises:
            TypeError: for any other type.
        """
        if isinstance(obj, (datetime, date)):
            return obj.isoformat()
        elif isinstance(obj, Decimal):
            return float(obj)
        raise TypeError(f"Unserializable type {type(obj)}")