| | """
|
| | FARA Backend Server for HuggingFace Space
|
| | Provides WebSocket communication and REST API for the React frontend
|
| | """
|
| |
|
| | import asyncio
|
| | import base64
|
| | import logging
|
| | import os
|
| |
|
| |
|
| | import sys
|
| | import tempfile
|
| | import uuid
|
| | from datetime import datetime
|
| | from typing import Dict, Optional
|
| |
|
| | import httpx
|
| | from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| | from fastapi.middleware.cors import CORSMiddleware
|
| | from fastapi.responses import JSONResponse
|
| | from playwright._impl._errors import TargetClosedError
|
| |
|
| | sys.path.insert(0, "/app")
|
| | from fara import FaraAgent
|
| | from fara.browser.browser_bb import BrowserBB
|
| |
|
| |
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Modal settings; an empty string means "not configured".
MODAL_TRACE_STORAGE_URL = os.environ.get("MODAL_TRACE_STORAGE_URL", "")
MODAL_TOKEN_ID = os.environ.get("MODAL_TOKEN_ID", "")
MODAL_TOKEN_SECRET = os.environ.get("MODAL_TOKEN_SECRET", "")


# Client configuration handed to every FaraAgent. The Modal proxy-auth
# headers are attached only when both Modal tokens are set.
ENDPOINT_CONFIG = {
    "model": os.environ.get("FARA_MODEL_NAME", "microsoft/Fara-7B"),
    "base_url": os.environ.get("FARA_ENDPOINT_URL"),
    "api_key": os.environ.get("FARA_API_KEY", "not-needed"),
    "default_headers": (
        {
            "Modal-Key": MODAL_TOKEN_ID,
            "Modal-Secret": MODAL_TOKEN_SECRET,
        }
        if MODAL_TOKEN_ID and MODAL_TOKEN_SECRET
        else None
    ),
}

if not ENDPOINT_CONFIG["base_url"]:
    # Surface the misconfiguration at startup instead of only when the first
    # task tries to talk to the model.
    logger.warning(
        "FARA_ENDPOINT_URL is not set; the agent client will be created "
        "without a base_url"
    )


AVAILABLE_MODELS = ["microsoft/Fara-7B"]

app = FastAPI(title="FARA Backend")


# NOTE(review): wildcard origins combined with allow_credentials=True means
# any site may issue credentialed requests (Starlette appears to echo the
# request Origin in that case instead of "*") — confirm this is intended and
# consider restricting allow_origins to the frontend's origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Open WebSocket connections, keyed by a per-connection uuid.
active_connections: Dict[str, WebSocket] = {}
# Agent sessions, keyed by trace id.
active_sessions: Dict[str, "FaraSession"] = {}
|
| |
|
| |
|
class FaraSession:
    """Manages a single FARA agent session.

    One session is created per ``user_task`` message. It owns a headless
    browser (``BrowserBB``) and a ``FaraAgent`` for the duration of the task
    and streams ``agent_start`` / ``agent_progress`` / ``agent_complete`` /
    ``agent_error`` events to the WebSocket it was created with.
    """

    # Round limit passed to FaraAgent and reported to the frontend as
    # ``maxSteps``; kept in one place so the two values cannot drift apart.
    MAX_STEPS = 50

    def __init__(self, trace_id: str, websocket: WebSocket):
        self.trace_id = trace_id
        self.websocket = websocket
        self.agent: Optional[FaraAgent] = None
        self.browser_manager: Optional[BrowserBB] = None
        # Temporary directory used as the agent's downloads/screenshots folder.
        self.screenshots_dir: Optional[str] = None
        self.is_running = False
        # Cooperative stop flag, checked at the start of every agent round.
        self.should_stop = False
        self.step_count = 0
        self.start_time: Optional[datetime] = None
        self.total_input_tokens = 0
        self.total_output_tokens = 0

    async def initialize(self, start_page: str = "https://www.bing.com/"):
        """Create the browser and agent, then initialize the agent.

        Args:
            start_page: URL the agent is configured to start from.

        Returns:
            True once ``FaraAgent.initialize()`` has completed.
        """
        self.screenshots_dir = tempfile.mkdtemp(prefix="fara_screenshots_")

        self.browser_manager = BrowserBB(
            headless=True,
            viewport_height=900,
            viewport_width=1440,
            page_script_path=None,
            browser_channel="chromium",
            browser_data_dir=None,
            downloads_folder=self.screenshots_dir,
            to_resize_viewport=True,
            single_tab_mode=True,
            animate_actions=False,
            use_browser_base=False,
            logger=logger,
        )

        self.agent = FaraAgent(
            browser_manager=self.browser_manager,
            client_config=ENDPOINT_CONFIG,
            start_page=start_page,
            downloads_folder=self.screenshots_dir,
            save_screenshots=True,
            max_rounds=self.MAX_STEPS,
        )

        await self.agent.initialize()
        return True

    async def send_event(self, event: dict):
        """Send *event* as JSON over the session's WebSocket.

        Errors (e.g. the client has disconnected) are logged and swallowed so
        that a failed send never aborts the agent run by itself.
        """
        try:
            await self.websocket.send_json(event)
        except Exception as e:
            logger.error(f"Error sending event: {e}")

    async def get_screenshot_base64(self) -> Optional[str]:
        """Capture the active page as a ``data:image/png;base64,...`` URL.

        If the page is closed mid-capture, one retry is made against whatever
        page is active afterwards. Returns None when there is no agent or
        page, or when capturing fails (failures are logged).
        """
        if not self.agent:
            return None
        try:
            return await self._capture_active_page()
        except TargetClosedError:
            logger.warning(
                "Page closed while getting screenshot, attempting recovery..."
            )
            try:
                return await self._capture_active_page()
            except Exception as e:
                logger.error(f"Recovery screenshot failed: {e}")
        except Exception as e:
            logger.error(f"Error getting screenshot: {e}")
        return None

    async def _capture_active_page(self) -> Optional[str]:
        """Screenshot the active page and encode it as a PNG data URL (None if no page)."""
        page = self._get_active_page()
        if not page or not self.agent:
            return None
        png_bytes = await self.agent._playwright_controller.get_screenshot(page)
        return f"data:image/png;base64,{base64.b64encode(png_bytes).decode()}"

    def _get_active_page(self):
        """Return the currently active page.

        Prefers the most recently opened page of the browser context (the
        last entry of ``_context.pages``), falling back to the agent's own
        ``_page``; None when there is no agent.
        """
        if (
            self.agent
            and self.agent.browser_manager
            and self.agent.browser_manager._context
        ):
            pages = self.agent.browser_manager._context.pages
            if pages:
                return pages[-1]
        return self.agent._page if self.agent else None

    async def run_task(self, instruction: str, model_id: str):
        """Run one task end-to-end and stream its progress over the WebSocket.

        Sends ``agent_start``, initializes the browser/agent, runs the
        streaming loop, and always releases resources afterwards. Any
        exception is reported to the client as an ``agent_error`` event.
        """
        self.is_running = True
        self.should_stop = False
        self.step_count = 0
        self.start_time = datetime.now()
        self.total_input_tokens = 0
        self.total_output_tokens = 0

        try:
            await self.send_event(
                {
                    "type": "agent_start",
                    "agentTrace": {
                        "id": self.trace_id,
                        "instruction": instruction,
                        "modelId": model_id,
                        "timestamp": self.start_time.isoformat(),
                        "isRunning": True,
                        "traceMetadata": {
                            "traceId": self.trace_id,
                            "inputTokensUsed": 0,
                            "outputTokensUsed": 0,
                            "duration": 0,
                            "numberOfSteps": 0,
                            "maxSteps": self.MAX_STEPS,
                            "completed": False,
                        },
                    },
                }
            )

            await self.initialize()
            await self._run_agent_with_streaming(instruction)

        except Exception as e:
            logger.exception("Error running agent task")
            await self.send_event({"type": "agent_error", "error": str(e)})
        finally:
            self.is_running = False
            await self.close()

    async def _run_agent_with_streaming(self, user_message: str):
        """Run the agent loop for *user_message*, streaming one event per step.

        Seeds the chat history with the current scaled screenshot plus the
        user message, then alternates model calls and action execution for up
        to ``agent.max_rounds`` rounds. Finishes with an ``agent_complete``
        event whose ``final_state`` is ``"success"`` (stop action),
        ``"max_steps_reached"``, or ``"stopped"`` (stop requested). A failing
        step is reported as ``agent_error`` and ends the run.

        Raises:
            RuntimeError: if the agent has no page after initialization, or a
                captcha is not solved within the timeout.
        """
        agent = self.agent

        # NOTE(review): run_task() already initialized the agent via
        # initialize(); this second call is kept from the original and is
        # presumably a no-op on an initialized agent — confirm in FaraAgent.
        await agent.initialize()
        if agent._page is None:
            # Explicit raise rather than assert, which is stripped under -O.
            raise RuntimeError("Page should be initialized")

        scaled_screenshot = await agent._get_scaled_screenshot()

        if agent.save_screenshots:
            await agent._playwright_controller.get_screenshot(
                agent._page,
                path=os.path.join(
                    agent.downloads_folder, f"screenshot{agent._num_actions}.png"
                ),
            )

        from fara.types import ImageObj, UserMessage

        agent._chat_history.append(
            UserMessage(
                content=[ImageObj.from_pil(scaled_screenshot), user_message],
                is_original=True,
            )
        )

        is_stop_action = False

        for i in range(agent.max_rounds):
            if self.should_stop:
                await self.send_event(
                    {
                        "type": "agent_complete",
                        "traceMetadata": self._get_metadata(),
                        "final_state": "stopped",
                    }
                )
                return

            is_first_round = i == 0
            step_start_time = datetime.now()

            await self._wait_for_captcha(agent)

            try:
                # The initial screenshot is only sent with the first model call.
                function_call, raw_response = await agent.generate_model_call(
                    is_first_round, scaled_screenshot if is_first_round else None
                )

                thoughts, action_dict = agent._parse_thoughts_and_action(raw_response)
                action_args = action_dict.get("arguments", {})
                action = action_args["action"]

                logger.info(
                    f"\nThought #{i + 1}: {thoughts}\nAction #{i + 1}: {action}"
                )

                (
                    is_stop_action,
                    action_description,
                ) = await self._execute_action_with_recovery(agent, function_call)

                self._sync_active_page(agent)
                screenshot_base64 = await self.get_screenshot_base64()

            except TargetClosedError as e:
                logger.error(f"Unrecoverable page error: {e}")
                await self.send_event(
                    {
                        "type": "agent_error",
                        "error": f"Browser page closed unexpectedly: {str(e)}",
                    }
                )
                return
            except Exception as e:
                logger.exception(f"Error in agent step {i + 1}")
                await self.send_event({"type": "agent_error", "error": str(e)})
                return

            step_duration = (datetime.now() - step_start_time).total_seconds()
            # Token counts are rough estimates: a fixed placeholder for input
            # and ~4 characters per token for the raw model output.
            step_input_tokens = 1000
            step_output_tokens = len(raw_response) // 4

            self.total_input_tokens += step_input_tokens
            self.total_output_tokens += step_output_tokens
            self.step_count += 1

            step = {
                "stepId": str(uuid.uuid4()),
                "traceId": self.trace_id,
                "stepNumber": self.step_count,
                "thought": thoughts,
                "actions": [
                    {
                        "function_name": action,
                        "description": action_description,
                        "parameters": action_args,
                    }
                ],
                "image": screenshot_base64,
                "duration": step_duration,
                "inputTokensUsed": step_input_tokens,
                "outputTokensUsed": step_output_tokens,
                "timestamp": datetime.now().isoformat(),
            }

            await self.send_event(
                {
                    "type": "agent_progress",
                    "agentStep": step,
                    "traceMetadata": self._get_metadata(),
                }
            )

            if is_stop_action:
                break

        final_state = "success" if is_stop_action else "max_steps_reached"
        await self.send_event(
            {
                "type": "agent_complete",
                "traceMetadata": self._get_metadata(completed=True),
                "final_state": final_state,
            }
        )

    async def _wait_for_captcha(self, agent):
        """Block (up to 60s) while the browser reports an unsolved captcha.

        Raises:
            RuntimeError: if the captcha is still unsolved after the timeout.
        """
        if agent.browser_manager._captcha_event.is_set():
            return
        logger.info("Waiting 60s for captcha to finish...")
        captcha_solved = await agent.wait_for_captcha_with_timeout(60)
        if not captcha_solved and not agent.browser_manager._captcha_event.is_set():
            raise RuntimeError("Captcha timed out")

    async def _execute_action_with_recovery(self, agent, function_call):
        """Execute one action, tolerating a page closed by navigation.

        Returns:
            ``(is_stop_action, action_description)``. If the page closed and
            a different page is now active, the agent is re-pointed at it and
            the step is reported as a navigation.

        Raises:
            TargetClosedError: if the page closed and no other page is active.
        """
        try:
            (
                is_stop_action,
                _new_screenshot,
                action_description,
            ) = await agent.execute_action(function_call)
        except TargetClosedError:
            logger.warning(
                "Page closed during action execution, attempting recovery..."
            )
            new_page = self._get_active_page()
            if not new_page or new_page == agent._page:
                raise
            logger.info("Recovered with new active page")
            agent._page = new_page
            # Give the newly active page a moment before continuing.
            await asyncio.sleep(1)
            return False, "Action completed (page navigation occurred)"
        return is_stop_action, action_description

    def _sync_active_page(self, agent):
        """Point ``agent._page`` at the currently active page if it changed."""
        active_page = self._get_active_page()
        if active_page and active_page != agent._page:
            logger.info("Updating agent page reference to active page")
            agent._page = active_page

    def _get_metadata(self, completed: bool = False) -> dict:
        """Build the ``traceMetadata`` payload from the session's counters.

        ``duration`` is seconds elapsed since ``start_time`` (0 before start).
        """
        duration = 0
        if self.start_time:
            duration = (datetime.now() - self.start_time).total_seconds()

        return {
            "traceId": self.trace_id,
            "inputTokensUsed": self.total_input_tokens,
            "outputTokensUsed": self.total_output_tokens,
            "duration": duration,
            "numberOfSteps": self.step_count,
            "maxSteps": self.MAX_STEPS,
            "completed": completed,
        }

    async def stop(self):
        """Request the agent to stop; honoured at the start of the next round."""
        self.should_stop = True

    async def close(self):
        """Release the agent/browser and delete the temporary directory.

        Safe to call repeatedly and concurrently: references are detached
        before any ``await``, so a second caller finds nothing left to close.
        Also sets the stop flag so a still-running loop ends at its next round.
        """
        self.should_stop = True

        agent, self.agent = self.agent, None
        self.browser_manager = None
        if agent:
            try:
                await agent.close()
            except Exception as e:
                logger.error(f"Error closing agent: {e}")

        screenshots_dir, self.screenshots_dir = self.screenshots_dir, None
        if screenshots_dir and os.path.exists(screenshots_dir):
            import shutil

            try:
                shutil.rmtree(screenshots_dir)
            except Exception as e:
                logger.error(f"Error cleaning up screenshots: {e}")
|
| |
|
| |
|
@app.get("/api/models")
async def get_models():
    """List the model identifiers the frontend may choose from."""
    return JSONResponse(AVAILABLE_MODELS)
|
| |
|
| |
|
@app.post("/api/traces")
async def store_trace(trace_data: dict):
    """
    Store a task trace by forwarding to the Modal trace storage endpoint.
    This keeps Modal credentials on the server side.

    Responses:
        * 503 if the storage URL or Modal credentials are not configured.
        * Upstream JSON body on any 2xx reply.
        * 502 if a 2xx reply is not valid JSON.
        * Upstream status code with an error payload on non-2xx replies.
        * 504 on timeout, 500 on any other failure.
    """
    if not MODAL_TRACE_STORAGE_URL:
        logger.warning("Modal trace storage URL not configured")
        return JSONResponse(
            status_code=503,
            content={"success": False, "error": "Trace storage not configured"},
        )

    if not MODAL_TOKEN_ID or not MODAL_TOKEN_SECRET:
        logger.warning("Modal proxy auth credentials not configured")
        return JSONResponse(
            status_code=503,
            content={"success": False, "error": "Modal auth not configured"},
        )

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                MODAL_TRACE_STORAGE_URL,
                json=trace_data,
                headers={
                    "Content-Type": "application/json",
                    "Modal-Key": MODAL_TOKEN_ID,
                    "Modal-Secret": MODAL_TOKEN_SECRET,
                },
            )

        # Accept any 2xx (e.g. 201 Created), not only 200.
        if not 200 <= response.status_code < 300:
            error_text = response.text
            logger.error(
                f"Failed to store trace: {response.status_code} - {error_text}"
            )
            return JSONResponse(
                status_code=response.status_code,
                content={
                    "success": False,
                    "error": f"Modal API error: {error_text}",
                },
            )

        try:
            result = response.json()
        except ValueError:
            # Covers JSON decode errors (a ValueError subclass) on a 2xx reply.
            logger.error(
                f"Trace storage returned a non-JSON body: {response.text[:200]}"
            )
            return JSONResponse(
                status_code=502,
                content={
                    "success": False,
                    "error": "Invalid response from trace storage",
                },
            )

        stored_id = (
            result.get("trace_id", "unknown") if isinstance(result, dict) else "unknown"
        )
        logger.info(f"Trace stored successfully: {stored_id}")
        return JSONResponse(content=result)
    except httpx.TimeoutException:
        logger.error("Timeout storing trace to Modal")
        return JSONResponse(
            status_code=504,
            content={"success": False, "error": "Timeout connecting to trace storage"},
        )
    except Exception as e:
        logger.exception("Error storing trace")
        return JSONResponse(
            status_code=500, content={"success": False, "error": str(e)}
        )
|
| |
|
| |
|
@app.get("/api/random-question")
async def get_random_question():
    """Pick one of the built-in example tasks at random."""
    import random

    examples = (
        "Search for the latest news about AI agents",
        "Find the weather forecast for San Francisco",
        "Go to GitHub and search for 'computer use agent'",
        "Find the top trending repositories on GitHub today",
        "Search for Python tutorials on YouTube",
        "Look up the current stock price of Microsoft",
        "Find the schedule for upcoming SpaceX launches",
        "Search for healthy breakfast recipes",
    )
    picked = random.choice(examples)
    payload = {"question": picked}
    return JSONResponse(content=payload)
|
| |
|
| |
|
# Strong references to running session tasks. The event loop holds only weak
# references to tasks, so a fire-and-forget task could otherwise be
# garbage-collected before it finishes.
_background_tasks = set()


def _release_session(trace_id: str, session: "FaraSession") -> None:
    """Remove *session* from active_sessions if it is still registered under *trace_id*."""
    if active_sessions.get(trace_id) is session:
        del active_sessions[trace_id]


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time communication.

    Handles client messages by ``type``:
      * ``user_task`` - start a new FaraSession for ``trace`` in the background
      * ``stop_task`` - ask the session for ``trace_id`` to stop
      * ``ping``      - reply with ``pong``
    On disconnect, every session bound to this socket is stopped and closed.
    """
    await websocket.accept()

    connection_id = str(uuid.uuid4())
    active_connections[connection_id] = websocket

    try:
        # Inside the try so a failed send still unregisters the connection.
        await websocket.send_json(
            {
                "type": "heartbeat",
                "uuid": str(uuid.uuid4()),
                "timestamp": datetime.now().isoformat(),
            }
        )

        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                logger.warning(
                    f"Ignoring non-object WebSocket message of type {type(data).__name__}"
                )
                continue
            message_type = data.get("type")

            if message_type == "user_task":
                trace = data.get("trace")
                if not isinstance(trace, dict):
                    trace = {}
                trace_id = trace.get("id") or str(uuid.uuid4())
                instruction = trace.get("instruction", "")
                model_id = trace.get("modelId", "microsoft/Fara-7B")

                previous = active_sessions.get(trace_id)
                if previous is not None:
                    # A reused trace id would otherwise orphan the old session
                    # (and its browser) outside the registry.
                    await previous.stop()

                session = FaraSession(trace_id, websocket)
                active_sessions[trace_id] = session

                task = asyncio.create_task(session.run_task(instruction, model_id))
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)
                # Drop finished sessions so long-lived connections don't accumulate them.
                task.add_done_callback(
                    lambda _task, tid=trace_id, s=session: _release_session(tid, s)
                )

            elif message_type == "stop_task":
                trace_id = data.get("trace_id")
                if trace_id and trace_id in active_sessions:
                    await active_sessions[trace_id].stop()

            elif message_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        active_connections.pop(connection_id, None)

        # Snapshot before awaiting: other connections may add or remove
        # sessions while close() yields, which would break live iteration.
        # Identity (not ==) so only sessions bound to this exact socket match.
        owned = [
            (tid, s) for tid, s in active_sessions.items() if s.websocket is websocket
        ]
        for tid, owned_session in owned:
            await owned_session.stop()
            await owned_session.close()
            _release_session(tid, owned_session)
|
| |
|
| |
|
@app.get("/api/health")
async def health_check():
    """Report a static status payload showing the server is responding."""
    status = {"status": "healthy"}
    return status
|
| |
|
| |
|
if __name__ == "__main__":
    import uvicorn

    # Port may be overridden via $PORT; defaults to the original 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))
|
| |
|