| """SQL validation and sanitization utilities.""" |
|
|
| import re |
| from typing import List, Tuple |
|
|
| from logger.logging import get_logger |
|
|
| logger = get_logger(__name__) |
|
|
| |
# Statement keywords that must never appear in a generated query. Each entry
# is matched as a whole word against the upper-cased query text.
BLOCKED_OPERATIONS = [
    "DROP",
    "DELETE",
    "UPDATE",
    "INSERT",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "EXEC",
    "EXECUTE",
    "GRANT",
    "REVOKE",
    "ATTACH",
    "DETACH",
    "VACUUM",
    "REINDEX",
    "PRAGMA",
]


# Tables a query may read when the caller does not pass its own allowlist.
ALLOWED_TABLES = [
    "customers",
    "products",
    "orders",
    "order_items",
    "reviews",
    "inventory_log",
]


# A single-quoted string literal ('' is an escaped quote) or a double-quoted
# identifier. Both are scanned in one left-to-right pass so that an apostrophe
# inside a double-quoted identifier is never mistaken for the start of a
# string literal.
_QUOTED_TOKEN_RE = re.compile(
    r"'(?:[^']|'')*'"
    r'|"(?:[^"]|"")*"'
)

# Trailing statement terminators and whitespace.
_TRAILING_TERMINATOR_RE = re.compile(r"[;\s]+\Z")

# "<name> AS (", optionally with a column list or a [NOT] MATERIALIZED hint.
_CTE_NAME_RE = re.compile(
    r"(\w+)\s*(?:\([^()]*\))?\s*\bAS\b\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\("
)

# The FROM that belongs to "EXTRACT(<field> FROM <expr>)", not to a FROM clause.
_EXTRACT_FROM_RE = re.compile(r"\bEXTRACT\s*\(\s*\w*\s*FROM\b")

# First identifier after FROM/JOIN, optionally opened by a quote character.
# The \b after the name stops the engine from backtracking to a shorter name
# to satisfy the lookahead, so "FROM func(" is skipped instead of being
# captured as "fun".
_TABLE_REF_RE = re.compile(r'\b(?:FROM|JOIN)\s+["`\[]?(\w+)\b(?!\s*\()')

# Words that can follow FROM/JOIN without naming a table.
_NON_TABLE_KEYWORDS = frozenset(
    {
        "select",
        "where",
        "and",
        "or",
        "not",
        "null",
        "as",
        "on",
        "in",
        "is",
        "by",
        "asc",
        "desc",
        "case",
        "when",
        "then",
        "else",
        "end",
        "between",
        "like",
        "having",
        "union",
        "all",
        "exists",
        "each",
        "lateral",
    }
)


def _strip_string_literals(text: str) -> str:
    """Remove single-quoted literals from *text*, keeping quoted identifiers."""
    return _QUOTED_TOKEN_RE.sub(
        lambda m: "" if m.group(0).startswith("'") else m.group(0), text
    )


def _collect_cte_names(sql_clean: str) -> set:
    """Return the lower-cased names of every CTE defined in *sql_clean*."""
    if not re.search(r"\bWITH\b", sql_clean):
        return set()
    return {m.group(1).lower() for m in _CTE_NAME_RE.finditer(sql_clean)}


def _collect_table_refs(sql_clean: str) -> set:
    """Return the lower-cased identifiers that directly follow FROM or JOIN."""
    # Drop the FROM inside EXTRACT(...) in the text itself. Discarding the
    # operand by name afterwards would also hide a real table of that name.
    scan_text = _EXTRACT_FROM_RE.sub("EXTRACT(", sql_clean)
    return {name.lower() for name in _TABLE_REF_RE.findall(scan_text)}


def validate_sql(sql: str, allowed_tables: List[str] = None) -> Tuple[bool, str]:
    """Validate a generated SQL query for safety.

    The checks run in this order, on upper-cased text with single-quoted
    string literals removed:

    1. the query must start with SELECT or WITH;
    2. no word from BLOCKED_OPERATIONS may appear;
    3. no ``;`` may appear other than trailing terminators;
    4. no ``--`` or ``/*`` comment markers may appear;
    5. every identifier following FROM/JOIN must be an allowed table or a
       CTE defined in the query. Function-call sources (``FROM f(...)``) are
       not treated as tables.

    This is a regex heuristic, not a SQL parser: for example only the first
    table of a comma-separated FROM list is inspected.

    Args:
        sql: The query text to check.
        allowed_tables: Table names the query may reference; when None,
            ALLOWED_TABLES is used. Compared case-insensitively.

    Returns:
        (is_valid, error_message) -- error_message is "" when valid.
    """
    if allowed_tables is None:
        allowed_tables = ALLOWED_TABLES

    # A non-string can never be a valid query; reject it instead of raising.
    if not isinstance(sql, str):
        return False, "Only SELECT queries are allowed"

    sql_upper = sql.upper().strip()
    if not sql_upper.startswith(("SELECT", "WITH")):
        return False, "Only SELECT queries are allowed"

    # Strip literals once; every later check works on this text so that
    # keywords, semicolons or comment markers inside strings are ignored.
    sql_clean = _strip_string_literals(sql_upper)

    for op in BLOCKED_OPERATIONS:
        if re.search(rf"\b{re.escape(op)}\b", sql_clean):
            return False, f"Blocked operation detected: {op}"

    # Trailing terminators (with any surrounding whitespace) are harmless;
    # a semicolon anywhere else separates statements.
    if ";" in _TRAILING_TERMINATOR_RE.sub("", sql_clean):
        return False, "Multiple statements not allowed"

    if "--" in sql_clean or "/*" in sql_clean:
        return False, "SQL comments not allowed in generated queries"

    referenced_tables = _collect_table_refs(sql_clean)
    referenced_tables -= _collect_cte_names(sql_clean)
    referenced_tables -= _NON_TABLE_KEYWORDS

    allowed_set = {t.lower() for t in allowed_tables}
    invalid_tables = referenced_tables - allowed_set
    if invalid_tables:
        # NOTE(review): any unrecognised name of three characters or fewer is
        # assumed to be an alias and accepted, so a disallowed table with such
        # a short name is not rejected here. Kept for compatibility.
        likely_aliases = {t for t in invalid_tables if len(t) <= 3}
        truly_invalid = invalid_tables - likely_aliases

        if truly_invalid:
            # Sorted so the message is deterministic.
            return False, f"Referenced disallowed tables: {', '.join(sorted(truly_invalid))}"

        if likely_aliases:
            logger.info(f"SQL uses table aliases: {likely_aliases} (allowed)")

    return True, ""
|
|
|
|
def sanitize_sql(sql: str) -> str:
    """Clean up SQL for display/execution.

    Strips surrounding whitespace, removes markdown code-fence lines when the
    text starts with a fence, and drops trailing semicolons.

    Args:
        sql: Raw SQL text, possibly wrapped in a markdown code fence.

    Returns:
        The cleaned SQL text.
    """
    cleaned = sql.strip()

    # Remove the fence lines first: while the closing fence is still the last
    # thing in the text, a trailing ";" on the query line cannot be seen.
    if cleaned.startswith("```"):
        body_lines = [
            line for line in cleaned.split("\n") if not line.strip().startswith("```")
        ]
        cleaned = "\n".join(body_lines).strip()

    return cleaned.rstrip(";").strip()
|
|
|
|
def extract_sql_from_response(text: str) -> str:
    """Extract SQL query from an LLM response that may contain markdown or explanations.

    Tries, in order:

    1. the first non-empty fenced code block;
    2. the first ``SELECT`` / ``WITH ... SELECT`` run, ending at a blank line
       or the end of the text;
    3. the whole text.

    The chosen text is passed through ``sanitize_sql`` before being returned.

    Args:
        text: The raw model response.

    Returns:
        The extracted, sanitized SQL text.
    """
    # The opening fence may carry a language tag. A tag alone on the fence
    # line is dropped whatever it is (sql, sqlite, postgresql, ...), unless it
    # is SELECT/WITH, i.e. the query itself starting on the fence line. A bare
    # "sql" directly before inline content is still recognised as a tag.
    fence_pattern = (
        r"```[ \t]*"
        r"(?:(?!(?:select|with)\b)[\w+-]+(?=[ \t]*\r?\n)|sql)?"
        r"\s*\n?(.*?)\n?```"
    )
    for block in re.findall(fence_pattern, text, re.DOTALL | re.IGNORECASE):
        # Skip empty fences so a later block or the fallbacks can still match.
        if block.strip():
            return sanitize_sql(block)

    select_pattern = r"((?:WITH\s+.*?\s+AS\s*\(.*?\)\s*)?SELECT\s+.*?)(?:\n\n|\Z)"
    statements = re.findall(select_pattern, text, re.DOTALL | re.IGNORECASE)
    if statements:
        return sanitize_sql(statements[0])

    return sanitize_sql(text)
|
|