Ab-Romia commited on
Commit
d6c52da
Β·
verified Β·
1 Parent(s): d0501a3

Update app/rag_setup.py

Browse files
Files changed (1) hide show
  1. app/rag_setup.py +45 -117
app/rag_setup.py CHANGED
@@ -141,83 +141,58 @@ class OpenRouterLLM:
141
  self.client_ready = False
142
  return
143
 
144
- # Don't test connection during initialization to avoid blocking startup
145
- # Testing will happen when first needed
146
- self.client_ready = True
147
- logger.info("βœ… OpenRouter LLM initialized (connection will be tested on first use)")
148
- logger.info("=" * 60)
149
-
150
- def test_connection(self) -> dict:
151
- """Test the API connection with minimal request."""
152
- logger.info("πŸ” Testing OpenRouter API connection...")
153
-
154
- # Use a very simple test prompt and minimal tokens
155
- test_prompt = "Hi"
156
- response = self._make_api_request(
157
- prompt=test_prompt,
158
- max_tokens=5, # Minimal tokens for faster response
159
- timeout=15, # Short timeout for testing
160
- is_test=True # Flag to indicate this is a test
161
- )
162
 
163
- if response and "error" not in response:
164
- logger.info("βœ… OpenRouter connection test successful")
165
- return {"valid": True, "message": "API key is valid and working"}
166
- else:
167
- error_msg = response.get("error", "Unknown error") if response else "No response"
168
- logger.error(f"❌ OpenRouter connection test failed: {error_msg}")
169
- return {"valid": False, "message": f"API test failed: {error_msg}"}
170
 
171
- def _make_api_request(self, prompt: str, max_tokens: int = 2000, timeout: int = None, is_test: bool = False) -> dict:
172
- """Make a direct HTTP request to OpenRouter API with improved timeout handling."""
173
 
174
- # Calculate dynamic timeout based on context
175
  if timeout is None:
176
- if is_test:
177
- timeout = 15 # Short timeout for tests
178
- else:
179
- base_timeout = 60 # Reduced base timeout
180
- # More conservative timeout calculation
181
- token_timeout = max(10, max_tokens // 200) # ~1 second per 200 tokens
182
- prompt_timeout = max(5, len(prompt) // 2000) # ~1 second per 2000 characters
183
- timeout = min(base_timeout + token_timeout + prompt_timeout, 300) # Cap at 5 minutes
184
 
185
  logger.info(f"🌐 Making API request to OpenRouter")
186
  logger.info(f"πŸ“ Prompt length: {len(prompt)} characters")
187
  logger.info(f"🎯 Max tokens: {max_tokens}")
188
  logger.info(f"⏱️ Timeout: {timeout}s")
189
- logger.info(f"πŸ§ͺ Is test: {is_test}")
190
 
191
  headers = {
192
  "Authorization": f"Bearer {self.api_key}",
193
  "Content-Type": "application/json",
194
  "HTTP-Referer": "https://github.com/Ab-Romia/ContextIQ-RAG",
195
- "X-Title": "Context Aware AI",
196
- "User-Agent": "ContextIQ/1.0" # Add user agent
197
  }
198
 
199
- # Optimize payload for faster responses, especially for tests
200
  payload = {
201
  "model": self.model,
202
  "messages": [{"role": "user", "content": prompt}],
203
  "max_tokens": max_tokens,
 
 
204
  "stream": False,
 
 
 
205
  }
206
 
207
- # Add performance optimizations for tests
208
- if is_test:
209
- payload.update({
210
- "temperature": 0.1, # Lower temperature for faster, more deterministic responses
211
- "top_p": 0.5, # Lower top_p for faster generation
212
- })
213
- else:
214
- payload.update({
215
- "temperature": 0.7,
216
- "top_p": 0.9,
217
- "presence_penalty": 0.1,
218
- "frequency_penalty": 0.1,
219
- })
220
-
221
  # Log the request payload (without sensitive data)
222
  safe_payload = payload.copy()
223
  safe_payload["messages"] = [{"role": "user", "content": f"[CONTENT: {len(prompt)} chars]"}]
@@ -226,35 +201,12 @@ class OpenRouterLLM:
226
  try:
227
  start_time = time.time()
228
 
229
- # Use session with connection pooling and improved settings
230
  with requests.Session() as session:
231
- # Configure session for better performance
232
- session.headers.update(headers)
233
-
234
- # Configure adapters for retry and connection pooling
235
- from requests.adapters import HTTPAdapter
236
- from urllib3.util.retry import Retry
237
-
238
- # Define retry strategy for transient errors
239
- retry_strategy = Retry(
240
- total=2, # Reduced retries for faster failure
241
- backoff_factor=0.5,
242
- status_forcelist=[429, 500, 502, 503, 504, 408], # Include 408 timeout
243
- allowed_methods=["POST"]
244
- )
245
-
246
- adapter = HTTPAdapter(
247
- max_retries=retry_strategy,
248
- pool_connections=10,
249
- pool_maxsize=10
250
- )
251
- session.mount("http://", adapter)
252
- session.mount("https://", adapter)
253
-
254
  response = session.post(
255
  self.api_url,
 
256
  json=payload,
257
- timeout=(10, timeout), # (connection timeout, read timeout)
258
  )
259
 
260
  request_time = time.time() - start_time
@@ -280,47 +232,27 @@ class OpenRouterLLM:
280
  if completion_tokens >= max_tokens * 0.95: # If we used 95% of max tokens
281
  logger.warning(f"⚠️ Response may be truncated (used {completion_tokens}/{max_tokens} tokens)")
282
 
283
- if not is_test: # Don't log content for tests
284
- content_preview = content[:300] + "..." if len(content) > 300 else content
285
- logger.info(f"πŸ“„ Response preview: {content_preview}")
286
 
287
  return response_data
288
-
289
- elif response.status_code == 408:
290
- logger.error(f"⏱️ Request timed out (408) after {request_time:.2f}s")
291
- return {"error": f"Request timed out. Try again or reduce the request complexity."}
292
-
293
- elif response.status_code == 429:
294
- logger.error(f"🚦 Rate limited (429)")
295
- return {"error": "Rate limited. Please wait a moment and try again."}
296
-
297
- elif response.status_code in [500, 502, 503, 504]:
298
- logger.error(f"πŸ₯ Server error ({response.status_code})")
299
- return {"error": f"Server error ({response.status_code}). Please try again later."}
300
-
301
  else:
302
  logger.error(f"❌ API request failed with status {response.status_code}")
303
  logger.error(f"πŸ“„ Response text: {response.text}")
304
  return {"error": f"HTTP {response.status_code}: {response.text}"}
305
 
306
- except requests.exceptions.Timeout as e:
307
- logger.error(f"⏱️ Request timed out after {timeout}s: {e}")
308
- return {"error": f"Request timed out after {timeout}s. Please try again or reduce the request size."}
309
-
310
  except requests.exceptions.ConnectionError as e:
311
  logger.error(f"🌐 Connection error: {e}")
312
- return {"error": f"Connection error. Please check your internet connection."}
313
-
314
- except requests.exceptions.HTTPError as e:
315
- logger.error(f"🌐 HTTP error: {e}")
316
- return {"error": f"HTTP error: {str(e)}"}
317
-
318
  except Exception as e:
319
- logger.error(f"❌ Unexpected error in API request: {e}")
320
- return {"error": f"Unexpected error: {str(e)}"}
321
 
322
  def generate_content(self, prompt: str, max_tokens: int = 2000) -> str:
323
- """Generate content with improved error handling and timeout management."""
324
  logger.info("=" * 80)
325
  logger.info("🧠 LLM CONTENT GENERATION STARTED")
326
  logger.info("=" * 80)
@@ -371,9 +303,9 @@ class OpenRouterLLM:
371
  logger.error(error_msg)
372
  return error_msg
373
 
374
- max_retries = 2 # Reduced retries for faster failure
375
  retry_count = 0
376
- base_wait_time = 1 # Reduced wait time
377
 
378
  while retry_count <= max_retries:
379
  try:
@@ -385,7 +317,7 @@ class OpenRouterLLM:
385
 
386
  if retry_count > 0:
387
  # Reduce max_tokens on retries for faster responses
388
- current_max_tokens = max(500, max_tokens - (retry_count * 500))
389
  logger.info(f"πŸ”§ Retry attempt - reducing max_tokens to {current_max_tokens}")
390
 
391
  response = self._make_api_request(prompt, max_tokens=current_max_tokens, timeout=timeout)
@@ -397,9 +329,8 @@ class OpenRouterLLM:
397
  if "timeout" in error_msg.lower() or "408" in error_msg:
398
  logger.warning(f"⏱️ Timeout error on attempt {retry_count + 1}")
399
  if retry_count < max_retries:
400
- retry_count += 1
401
  continue
402
- elif "429" in error_msg or "rate limit" in error_msg.lower():
403
  logger.warning(f"🚦 Rate limit error on attempt {retry_count + 1}")
404
  wait_time = base_wait_time * (2 ** retry_count)
405
  logger.info(f"⏳ Waiting {wait_time}s for rate limit cooldown...")
@@ -459,7 +390,7 @@ class OpenRouterLLM:
459
  logger.info("=" * 80)
460
  return final_error
461
 
462
- wait_time = base_wait_time * retry_count
463
  logger.info(f"⏳ Waiting {wait_time:.1f}s before retry...")
464
  time.sleep(wait_time)
465
 
@@ -483,9 +414,6 @@ except Exception as e:
483
  class DummyLLM:
484
  def generate_content(self, prompt: str) -> str:
485
  return f"❌ AI model is not available. Initialization error: {str(e)}"
486
-
487
- def test_connection(self) -> dict:
488
- return {"valid": False, "message": f"Model not available: {str(e)}"}
489
 
490
  generation_model = DummyLLM()
491
  logger.warning("⚠️ Using dummy LLM due to initialization failure")
 
141
  self.client_ready = False
142
  return
143
 
144
+ # Test the connection with minimal tokens
145
+ try:
146
+ logger.info("πŸ” Testing OpenRouter connection...")
147
+ test_response = self._make_api_request("Hello", max_tokens=5)
148
+ if test_response and "error" not in test_response:
149
+ logger.info("βœ… OpenRouter connection test successful")
150
+ self.client_ready = True
151
+ else:
152
+ logger.error(f"❌ OpenRouter connection test failed: {test_response}")
153
+ self.client_ready = False
154
+ except Exception as e:
155
+ logger.error(f"❌ OpenRouter connection test failed: {e}")
156
+ self.client_ready = False
 
 
 
 
 
157
 
158
+ logger.info("=" * 60)
 
 
 
 
 
 
159
 
160
+ def _make_api_request(self, prompt: str, max_tokens: int = 2000, timeout: int = None) -> dict:
161
+ """Make a direct HTTP request to OpenRouter API with configurable token limits."""
162
 
163
+ # Calculate dynamic timeout based on max_tokens and prompt length
164
  if timeout is None:
165
+ base_timeout = 120
166
+ # More tokens = longer generation time
167
+ token_timeout = max(20, max_tokens // 100) # ~1 second per 100 tokens
168
+ prompt_timeout = max(10, len(prompt) // 1000) # ~1 second per 2000 characters
169
+ timeout = min(base_timeout + token_timeout + prompt_timeout, 600) # Cap at 5 minutes
 
 
 
170
 
171
  logger.info(f"🌐 Making API request to OpenRouter")
172
  logger.info(f"πŸ“ Prompt length: {len(prompt)} characters")
173
  logger.info(f"🎯 Max tokens: {max_tokens}")
174
  logger.info(f"⏱️ Timeout: {timeout}s")
 
175
 
176
  headers = {
177
  "Authorization": f"Bearer {self.api_key}",
178
  "Content-Type": "application/json",
179
  "HTTP-Referer": "https://github.com/Ab-Romia/ContextIQ-RAG",
180
+ "X-Title": "Context Aware AI"
 
181
  }
182
 
183
+ # Optimize payload for longer responses
184
  payload = {
185
  "model": self.model,
186
  "messages": [{"role": "user", "content": prompt}],
187
  "max_tokens": max_tokens,
188
+ "temperature": 0.7,
189
+ "top_p": 0.9,
190
  "stream": False,
191
+ # Add parameters to encourage complete responses
192
+ "presence_penalty": 0.1, # Slight penalty for repetition
193
+ "frequency_penalty": 0.1, # Slight penalty for frequency
194
  }
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # Log the request payload (without sensitive data)
197
  safe_payload = payload.copy()
198
  safe_payload["messages"] = [{"role": "user", "content": f"[CONTENT: {len(prompt)} chars]"}]
 
201
  try:
202
  start_time = time.time()
203
 
 
204
  with requests.Session() as session:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  response = session.post(
206
  self.api_url,
207
+ headers=headers,
208
  json=payload,
209
+ timeout=timeout
210
  )
211
 
212
  request_time = time.time() - start_time
 
232
  if completion_tokens >= max_tokens * 0.95: # If we used 95% of max tokens
233
  logger.warning(f"⚠️ Response may be truncated (used {completion_tokens}/{max_tokens} tokens)")
234
 
235
+ content_preview = content[:300] + "..." if len(content) > 300 else content
236
+ logger.info(f"πŸ“„ Response preview: {content_preview}")
 
237
 
238
  return response_data
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  else:
240
  logger.error(f"❌ API request failed with status {response.status_code}")
241
  logger.error(f"πŸ“„ Response text: {response.text}")
242
  return {"error": f"HTTP {response.status_code}: {response.text}"}
243
 
244
+ except requests.exceptions.Timeout:
245
+ logger.error(f"⏱️ API request timed out after {timeout}s")
246
+ return {"error": f"Request timed out after {timeout}s. Try reducing the context length or max tokens."}
 
247
  except requests.exceptions.ConnectionError as e:
248
  logger.error(f"🌐 Connection error: {e}")
249
+ return {"error": f"Connection error: {str(e)}"}
 
 
 
 
 
250
  except Exception as e:
251
+ logger.error(f"❌ API request failed: {e}")
252
+ return {"error": str(e)}
253
 
254
  def generate_content(self, prompt: str, max_tokens: int = 2000) -> str:
255
+ """Generate content with configurable token limits."""
256
  logger.info("=" * 80)
257
  logger.info("🧠 LLM CONTENT GENERATION STARTED")
258
  logger.info("=" * 80)
 
303
  logger.error(error_msg)
304
  return error_msg
305
 
306
+ max_retries = 3
307
  retry_count = 0
308
+ base_wait_time = 2
309
 
310
  while retry_count <= max_retries:
311
  try:
 
317
 
318
  if retry_count > 0:
319
  # Reduce max_tokens on retries for faster responses
320
+ current_max_tokens = max(1000, max_tokens - (retry_count * 500))
321
  logger.info(f"πŸ”§ Retry attempt - reducing max_tokens to {current_max_tokens}")
322
 
323
  response = self._make_api_request(prompt, max_tokens=current_max_tokens, timeout=timeout)
 
329
  if "timeout" in error_msg.lower() or "408" in error_msg:
330
  logger.warning(f"⏱️ Timeout error on attempt {retry_count + 1}")
331
  if retry_count < max_retries:
 
332
  continue
333
+ elif "429" in error_msg:
334
  logger.warning(f"🚦 Rate limit error on attempt {retry_count + 1}")
335
  wait_time = base_wait_time * (2 ** retry_count)
336
  logger.info(f"⏳ Waiting {wait_time}s for rate limit cooldown...")
 
390
  logger.info("=" * 80)
391
  return final_error
392
 
393
+ wait_time = base_wait_time * retry_count + (retry_count * 0.5)
394
  logger.info(f"⏳ Waiting {wait_time:.1f}s before retry...")
395
  time.sleep(wait_time)
396
 
 
414
  class DummyLLM:
415
  def generate_content(self, prompt: str) -> str:
416
  return f"❌ AI model is not available. Initialization error: {str(e)}"
 
 
 
417
 
418
  generation_model = DummyLLM()
419
  logger.warning("⚠️ Using dummy LLM due to initialization failure")