File size: 5,526 Bytes
1bd1563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""SQL validation and sanitization utilities."""

import re
from typing import List, Optional, Tuple

from logger.logging import get_logger

logger = get_logger(__name__)

# Dangerous SQL operations that should never appear in generated queries.
# validate_sql reports the first entry (in list order) that a query contains.
BLOCKED_OPERATIONS = [
    "DROP",
    "DELETE",
    "UPDATE",
    "INSERT",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "EXEC",
    "EXECUTE",
    "GRANT",
    "REVOKE",
    "ATTACH",
    "DETACH",
    "VACUUM",
    "REINDEX",
    "PRAGMA",
]

# Allowed tables (must match schema.sql)
ALLOWED_TABLES = [
    "customers",
    "products",
    "orders",
    "order_items",
    "reviews",
    "inventory_log",
]

# Single-quoted string literals; removed before any keyword or table scan so
# that literal text cannot trigger a match.
_STRING_LITERAL_RE = re.compile(r"'[^']*'")

# Any run of semicolons and whitespace at the very end of the statement.
_TRAILING_TERMINATOR_RE = re.compile(r"[;\s]+\Z")

# Start of a WITH clause (optionally recursive).
_WITH_RE = re.compile(r"\bWITH\s+(?:RECURSIVE\s+)?")

# One CTE header: name, optional column list, AS [[NOT] MATERIALIZED], "(".
_CTE_HEAD_RE = re.compile(
    r"\s*(\w+)\s*(?:\([^()]*\)\s*)?\bAS\b\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\("
)

# Comma separating two CTE definitions.
_CTE_SEPARATOR_RE = re.compile(r"\s*,")

# The FROM keyword inside EXTRACT(field FROM source); group 1 is everything
# before that FROM so the keyword alone can be dropped.
_EXTRACT_FROM_RE = re.compile(r"(\bEXTRACT\s*\([^)]*?)\bFROM\b")

# Word following FROM/JOIN unless it is followed by "(" (function call).
# The \b after the capture stops the regex from backtracking to a shorter
# prefix of the name just to satisfy the negative lookahead.
_TABLE_REF_RE = re.compile(r"\b(?:FROM|JOIN)\s+(\w+)\b(?!\s*\()")

# SQL keywords that may follow FROM/JOIN without being table names.
_NON_TABLE_KEYWORDS = frozenset(
    {
        "select",
        "where",
        "and",
        "or",
        "not",
        "null",
        "as",
        "on",
        "in",
        "is",
        "by",
        "asc",
        "desc",
        "case",
        "when",
        "then",
        "else",
        "end",
        "between",
        "like",
        "having",
        "union",
        "all",
        "exists",
        "each",
        "lateral",
    }
)


def _skip_parenthesized(text: str, open_idx: int) -> int:
    """Return the index just past the ")" closing the "(" at *open_idx*.

    Returns ``len(text)`` when the parenthesis is never closed.
    """
    depth = 0
    for idx in range(open_idx, len(text)):
        char = text[idx]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return idx + 1
    return len(text)


def _collect_cte_names(sql_clean: str) -> set:
    """Return the lower-cased names of every CTE defined in *sql_clean*.

    Each WITH clause is walked definition by definition, skipping every CTE
    body by balanced parentheses, so that all names in
    ``WITH a AS (...), b AS (...)`` are found rather than only the first.
    """
    names = set()
    for with_clause in _WITH_RE.finditer(sql_clean):
        pos = with_clause.end()
        while True:
            head = _CTE_HEAD_RE.match(sql_clean, pos)
            if head is None:
                break
            names.add(head.group(1).lower())
            # head ends right after the "(" that opens the CTE body.
            pos = _skip_parenthesized(sql_clean, head.end() - 1)
            separator = _CTE_SEPARATOR_RE.match(sql_clean, pos)
            if separator is None:
                break
            pos = separator.end()
    return names


def validate_sql(sql: str, allowed_tables: Optional[List[str]] = None) -> Tuple[bool, str]:
    """Validate a generated SQL query for safety.

    The checks run in this order: the query must start with SELECT or WITH,
    must not contain a blocked operation, must not contain more than one
    statement, must not contain comments, and must only reference allowed
    tables after FROM/JOIN.

    NOTE: this is a regex heuristic, not a SQL parser. Only the word directly
    after FROM/JOIN is inspected, so later entries of a comma-separated FROM
    list and quoted identifiers are not checked, and names followed by "("
    (function calls) are skipped without validation.

    Args:
        sql: The SQL text to validate.
        allowed_tables: Table names the query may reference (compared
            case-insensitively). Defaults to ALLOWED_TABLES.

    Returns:
        (is_valid, error_message); error_message is "" when the query is valid.
    """
    if allowed_tables is None:
        allowed_tables = ALLOWED_TABLES

    sql_upper = sql.upper().strip()

    # Must start with SELECT or WITH (for CTEs)
    if not sql_upper.startswith(("SELECT", "WITH")):
        return False, "Only SELECT queries are allowed"

    # Upper-cased query with string literals removed; shared by the keyword
    # scan and the table scan below.
    sql_clean = _STRING_LITERAL_RE.sub("", sql_upper)

    # Check for blocked operations. Whole-word match avoids false positives
    # such as "UPDATED_AT".
    for op in BLOCKED_OPERATIONS:
        if re.search(rf"\b{re.escape(op)}\b", sql_clean):
            return False, f"Blocked operation detected: {op}"

    # Check for multiple statements (SQL injection via semicolons). Trailing
    # semicolons are tolerated even when followed by whitespace or a newline.
    sql_no_strings = _STRING_LITERAL_RE.sub("", sql)
    if ";" in _TRAILING_TERMINATOR_RE.sub("", sql_no_strings):
        return False, "Multiple statements not allowed"

    # Check for comments (potential injection vector)
    if "--" in sql_no_strings or "/*" in sql_no_strings:
        return False, "SQL comments not allowed in generated queries"

    # Validate referenced tables. CTE names are legitimate FROM targets.
    cte_names = _collect_cte_names(sql_clean)

    # Drop the FROM keyword inside EXTRACT(field FROM source) before scanning.
    # Removing it from the text (rather than discarding the captured name
    # afterwards) keeps a real "FROM <same name>" elsewhere under validation.
    table_scan_text = _EXTRACT_FROM_RE.sub(r"\1", sql_clean)

    referenced_tables = {
        name.lower() for name in _TABLE_REF_RE.findall(table_scan_text)
    }
    referenced_tables -= cte_names
    referenced_tables -= _NON_TABLE_KEYWORDS

    allowed_set = {t.lower() for t in allowed_tables}
    invalid_tables = referenced_tables - allowed_set
    if invalid_tables:
        # Separate likely aliases (short, 1-3 chars) from truly unknown tables.
        # LLMs commonly alias "orders" as "o", "products" as "p", etc.
        likely_aliases = {t for t in invalid_tables if len(t) <= 3}
        truly_invalid = invalid_tables - likely_aliases

        if truly_invalid:
            # Sorted so the message is deterministic across runs.
            return False, f"Referenced disallowed tables: {', '.join(sorted(truly_invalid))}"

        if likely_aliases:
            logger.info(f"SQL uses table aliases: {likely_aliases} (allowed)")

    return True, ""


def sanitize_sql(sql: str) -> str:
    """Clean up SQL for display/execution.

    Strips surrounding whitespace, drops markdown code-fence lines when the
    text starts with a fence, and removes trailing semicolons.

    Args:
        sql: Raw SQL text, possibly wrapped in a markdown code block.

    Returns:
        The cleaned SQL string.
    """
    sql = sql.strip()

    # Unwrap markdown code blocks first: while the closing fence is still
    # present, the semicolon that ends the statement is not at the end of the
    # string and would survive the trailing-semicolon strip below.
    if sql.startswith("```"):
        kept_lines = [
            line for line in sql.split("\n") if not line.strip().startswith("```")
        ]
        sql = "\n".join(kept_lines).strip()

    # Remove trailing semicolons and any whitespace left in front of them
    return sql.rstrip(";").strip()


def extract_sql_from_response(text: str) -> str:
    """Extract SQL query from an LLM response that may contain markdown or explanations.

    Three strategies are tried in order, and the result of the first one that
    applies is passed through sanitize_sql:

    1. The contents of the first fenced code block (with or without an
       ``sql`` language tag).
    2. The first SELECT statement (optionally preceded by a WITH clause),
       running up to the next blank line or the end of the text.
    3. The whole text.

    Args:
        text: The raw LLM response.

    Returns:
        The extracted SQL string after sanitize_sql has been applied.
    """
    flags = re.DOTALL | re.IGNORECASE

    # Try to find SQL in code blocks
    fenced = re.search(r"```(?:sql)?\s*\n?(.*?)\n?```", text, flags)
    if fenced:
        return sanitize_sql(fenced.group(1))

    # Look for SELECT statement in the text. The \b anchors keep WITH/SELECT
    # from matching inside longer words such as "preselect" or "forthwith".
    statement = re.search(
        r"((?:\bWITH\s+.*?\s+AS\s*\(.*?\)\s*)?\bSELECT\s+.*?)(?:\n\n|\Z)",
        text,
        flags,
    )
    if statement:
        return sanitize_sql(statement.group(1))

    # Fallback: return the whole text sanitized
    return sanitize_sql(text)