"""
Routes for the Image Similarity Search API.

Defines every HTTP endpoint of the application.
"""

import base64
import io
import uuid
from typing import List, Optional

from fastapi import (
    APIRouter,
    FastAPI,
    File,
    Form,
    HTTPException,
    Path,
    Query,
    UploadFile,
)
from PIL import Image
from pydantic import BaseModel

from services.embedding_service import ImageEmbeddingModel
from services.vector_db_service import VectorDatabaseClient


class Base64ImageRequest(BaseModel):
    """JSON request body carrying a single base64-encoded image.

    ``image_data`` may be bare base64 text or a ``data:`` URL; the
    search-by-image-scan route drops an optional header (the text up to the
    first comma) before decoding.
    """

    image_data: str  # base64 payload, optionally prefixed with a data-URL header


# Upper bound on the caller-supplied ``limit`` for the search endpoints, so a
# single request cannot ask the vector database for an unbounded result set.
_MAX_SEARCH_LIMIT = 100


def _decode_base64_image(image_data: str) -> Image.Image:
    """Decode base64 text (optionally a ``data:`` URL) into an RGB PIL image.

    Args:
        image_data: Raw base64 text, or a data URL such as
            ``"data:image/png;base64,<payload>"``. When a comma is present,
            everything after the first comma is treated as the payload.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        ValueError: If the payload is not valid base64 or does not decode to
            an image Pillow can read.
    """
    # Drop an optional data-URL header; the base64 alphabet has no comma.
    head, sep, tail = image_data.partition(",")
    encoded = tail if sep else head
    try:
        image_bytes = base64.b64decode(encoded)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise ValueError("image_data is not valid base64") from err
    try:
        with Image.open(io.BytesIO(image_bytes)) as img:
            # convert() forces a full decode and returns an independent copy,
            # so the source image can be closed safely afterwards.
            return img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-data errors subclass OSError.
        raise ValueError("image_data is not a decodable image") from err


def register_routes(
    app: FastAPI,
    embedding_model: ImageEmbeddingModel,
    vector_db: VectorDatabaseClient,
):
    """Register all API routes on *app*.

    The handlers close over *embedding_model* and *vector_db*, so the same
    instances serve every request.

    Args:
        app: Application the routes are attached to.
        embedding_model: Used to compute embeddings (upload, PIL-image and
            folder variants are called below).
        vector_db: Used to store embeddings, run similarity searches and list
            collections.
    """

    @app.api_route("/", methods=["GET", "HEAD"])
    async def read_root():
        """Liveness endpoint; answers both GET and HEAD."""
        return {"status": "API running"}

    @app.post("/add-image/")
    async def add_image(
        file: UploadFile = File(...),
        item_name: str = Form(...),
        design_name: str = Form(...),
        item_price: float = Form(...),
    ):
        """Embed an uploaded image and store it with its product details.

        Returns a confirmation message and the newly generated UUID under
        ``"id"``.
        """
        embedding = await embedding_model.get_embedding_from_upload(file)
        image_id = str(uuid.uuid4())
        payload = {
            "filename": file.filename,
            "item_name": item_name,
            "design_name": design_name,
            "item_price": item_price,
        }
        # NOTE(review): vector_db.add_image is called synchronously inside an
        # async handler; if it does blocking I/O it stalls the event loop.
        vector_db.add_image(image_id, embedding, payload)
        return {"message": "Image added successfully", "id": image_id}

    @app.post("/add-images-from-folder/")
    async def add_images_from_folder(folder_path: str):
        """Compute embeddings for the images in a server-side folder and return them.

        This handler does not itself write to ``vector_db``; it only returns
        whatever ``get_embeddings_from_folder`` produces.
        """
        # SECURITY NOTE(review): folder_path comes from the client and is read
        # on the server, so any directory this process can access is reachable.
        # Restrict it to an allow-listed root before exposing this publicly.
        embeddings = embedding_model.get_embeddings_from_folder(folder_path)
        return {"embeddings": embeddings}

    @app.post("/search-by-image/")
    async def search_by_image(
        file: UploadFile = File(...),
        limit: int = Query(1, ge=1, le=_MAX_SEARCH_LIMIT),
    ):
        """Return the stored entries most similar to an uploaded image.

        ``limit`` defaults to 1, preserving the original single-match result.
        """
        embedding = await embedding_model.get_embedding_from_upload(file)
        return vector_db.search_by_vector(embedding, limit=limit)

    @app.post("/search-by-image-scan/")
    async def search_by_image_scan(
        request: Base64ImageRequest,
        limit: int = Query(1, ge=1, le=_MAX_SEARCH_LIMIT),
    ):
        """Return the stored entries most similar to a base64-encoded image.

        Responds with HTTP 400 when ``image_data`` is not valid base64 or does
        not decode to an image, rather than failing with HTTP 500.
        ``limit`` defaults to 1, preserving the original single-match result.
        """
        try:
            image = _decode_base64_image(request.image_data)
        except ValueError as err:
            raise HTTPException(status_code=400, detail=str(err)) from err
        embedding = embedding_model.get_embedding_from_pil(image)
        return vector_db.search_by_vector(embedding, limit=limit)

    @app.get("/collections")
    def list_collections():
        """List all available collections in the vector database."""
        return vector_db.list_collections()