Spaces:
Sleeping
Sleeping
| import os, json, time, uuid, threading, gc | |
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse, JSONResponse, Response | |
| import uvicorn | |
# Application instance with a permissive CORS policy: any origin may call the
# API (without credentials) and preflight results may be cached for 86400 s.
app = FastAPI()

_CORS_POLICY = dict(
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=86400,
)
app.add_middleware(CORSMiddleware, **_CORS_POLICY)
# Static CORS headers attached by hand to the responses built in this module.
CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": ", ".join(
        ("GET", "POST", "OPTIONS", "HEAD")
    ),
    "Access-Control-Allow-Headers": ", ".join(
        ("Content-Type", "Authorization", "X-Requested-With")
    ),
    "Access-Control-Max-Age": "86400",
}
async def preflight(path: str):
    """Return an empty 204 response carrying the module's CORS headers.

    ``path`` is accepted but never read: every path gets the same reply.
    """
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed source -- confirm it is actually registered with ``app``.
    empty_reply = Response(status_code=204, headers=CORS_HEADERS)
    return empty_reply
# Shared state: written by the background init() thread, read by the handlers.
llm = None           # replaced by the loaded Llama instance once init() succeeds
status = "starting"  # init() moves it to "downloading", "loading", "ready" or "error: ..."

# Where the model comes from (Hugging Face repo + file) and where it is stored.
MODEL_REPO = "mradermacher/DarkGPT-model-GGUF"
MODEL_FILENAME = "DarkGPT-model.Q2_K.gguf"
MODEL_PATH = f"/app/models/{MODEL_FILENAME}"

# Model identifier reported in API responses.
MODEL_ID = "DarkGPT-model"
# Stop sequences passed to the model on every request.
STOP_TOKENS = [
    # Turn delimiters of the prompt format produced by build_prompt().
    "<|sim_e||>",
    "<|sim_s||>",
    # Further end-of-turn / end-of-text markers (presumably inherited from
    # other chat templates -- confirm they are all still needed).
    "<|im_end|>",
    "<|im_start|>",
    "<|end|>",
    "<|eot_id|>",
    "</s>",
    "<|endoftext|>",
    # A new line that opens another "<|" marker.
    "\n<|",
]
def _content_to_text(content):
    """Flatten a message ``content`` value into a stripped string.

    Accepts ``None`` (empty string), a plain string, or a list of parts where
    each part is either a string or a dict with a string ``"text"`` entry.
    Parts without text are skipped; text parts are joined with newlines.
    Any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, (list, tuple)):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "\n".join(pieces).strip()
    return str(content).strip()


def build_prompt(messages):
    """Render chat messages into the prompt format expected by the model.

    Each message becomes ``<|sim_s||>role\\ncontent<|sim_e||>\\n``; the prompt
    ends with an opened assistant turn so that the model completes it.

    Args:
        messages: iterable of dicts with optional ``"role"`` and ``"content"``
            entries. A missing or empty role falls back to ``"user"``; a
            missing, ``None`` or list-valued content is flattened to text
            instead of raising.

    Returns:
        The full prompt string.
    """
    turns = []
    for message in messages:
        role = message.get("role") or "user"
        text = _content_to_text(message.get("content"))
        turns.append(f"<|sim_s||>{role}\n{text}<|sim_e||>\n")
    # Open the assistant turn for the model to complete.
    turns.append("<|sim_s||>assistant\n")
    return "".join(turns)
def init():
    """Download the model file if it is missing, then load it.

    Meant to run on a background thread. Progress is published through the
    module-level ``status`` ("downloading" -> "loading" -> "ready") and the
    loaded model through ``llm``. Any failure is caught: ``status`` becomes
    ``"error: <message>"`` and the message plus traceback are printed, so the
    function never raises.
    """
    global llm, status
    try:
        gc.collect()
        # Imported lazily so the module can be imported before the heavy
        # dependencies are touched.
        from huggingface_hub import hf_hub_download

        # Derive the directory from MODEL_PATH so the two cannot drift apart.
        model_dir = os.path.dirname(MODEL_PATH)
        os.makedirs(model_dir, exist_ok=True)
        if not os.path.exists(MODEL_PATH):
            status = "downloading"
            print(f"Downloading {MODEL_FILENAME}...")
            hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILENAME,
                local_dir=model_dir,
                local_dir_use_symlinks=False,
            )
            print("Download done.")
            gc.collect()

        status = "loading"
        from llama_cpp import Llama

        # NOTE(review): chat() allows up to 500 generated tokens while the
        # context is 512 tokens, leaving very little room for the prompt --
        # confirm this budget is intended.
        llm = Llama(
            model_path=MODEL_PATH,
            n_ctx=512,
            n_threads=2,
            n_batch=128,
            verbose=False,
            use_mmap=True,
            use_mlock=False,
            low_vram=True,
            # No chat_format: the prompt is built by hand (see build_prompt).
        )
        gc.collect()
        status = "ready"
        print("DarkGPT ready!")
    except Exception as e:
        # Top-level boundary of the background thread: record the failure in
        # ``status`` and keep the traceback in the logs for diagnosis.
        import traceback

        status = f"error: {e}"
        print(f"INIT ERROR: {e}")
        traceback.print_exc()
# Load the model in the background so the server can answer status requests
# while init() is still running; daemon=True keeps it from blocking shutdown.
_init_thread = threading.Thread(target=init, daemon=True)
_init_thread.start()
def root():
    """Return the server name, the current ``status`` and a ``ready`` flag."""
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed source -- confirm it is actually registered with ``app``.
    current = status
    is_ready = current == "ready"
    return {"status": current, "server": "DarkGPT", "ready": is_ready}
def root_head():
    """Return an empty 200 response (presumably the HEAD variant of root)."""
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed source -- confirm it is actually registered with ``app``.
    ok_reply = Response(status_code=200)
    return ok_reply
def health():
    """Return the current model-loading state as ``{"status": ...}``."""
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed source -- confirm it is actually registered with ``app``.
    payload = dict(status=status)
    return payload
def list_models():
    """Return a one-entry model list describing ``MODEL_ID``.

    ``created`` is the current Unix time, so it differs between calls.
    """
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed source -- confirm it is actually registered with ``app``.
    entry = {
        "id": MODEL_ID,
        "object": "model",
        "owned_by": "darkgpt",
        "created": int(time.time()),
    }
    return {"object": "list", "data": [entry]}
def _number_or(value, default):
    """Return ``value`` if it is an int or float (bools excluded), else ``default``."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return default
    return value


def _bad_request(message):
    """Build a 400 JSON error response that carries the CORS headers."""
    return JSONResponse(
        {"error": {"message": message, "type": "invalid_request_error"}},
        status_code=400, headers=CORS_HEADERS,
    )


async def chat(req: Request):
    """Handle a chat-completion request and answer in an OpenAI-like shape.

    Request JSON fields read: ``messages``, ``stop``, ``max_tokens``,
    ``temperature``, ``top_p``, ``top_k``, ``repeat_penalty``, ``stream``.
    Sampling fields that are missing, ``null`` or not numeric fall back to
    their defaults instead of raising.

    Returns:
        * 503 JSON error while the model is not ready;
        * 400 JSON error when the body is not a JSON object or ``messages``
          is not a list;
        * a ``text/event-stream`` response when ``stream`` is truthy;
        * otherwise a single ``chat.completion`` JSON object, or a 500 JSON
          error if inference fails.
    """
    # NOTE(review): no route decorator is visible above this handler in the
    # reviewed source -- confirm it is actually registered with ``app``.
    if status != "ready":
        waiting_messages = {
            "downloading": "⏳ Téléchargement du modèle (2-4min)...",
            "loading": "⚙️ Chargement du modèle...",
        }
        msg = waiting_messages.get(status, f"Serveur indisponible: {status}")
        return JSONResponse(
            {"error": {"message": msg, "status": status}},
            status_code=503, headers=CORS_HEADERS,
        )

    try:
        body = await req.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return _bad_request("Request body must be valid JSON.")
    if not isinstance(body, dict):
        return _bad_request("Request body must be a JSON object.")

    messages = body.get("messages") or []
    if not isinstance(messages, list):
        return _bad_request("'messages' must be a list.")
    messages = [m for m in messages if isinstance(m, dict)]

    # Keep every system message plus at most the 6 most recent other messages.
    sys_msgs = [m for m in messages if m.get("role") == "system"]
    other_msgs = [m for m in messages if m.get("role") != "system"][-6:]
    prompt = build_prompt(sys_msgs + other_msgs)

    # "stop" may be absent, null, a string or a list. Empty strings are
    # dropped because an empty stop sequence would match immediately.
    client_stops = body.get("stop") or []
    if isinstance(client_stops, str):
        client_stops = [client_stops]
    elif not isinstance(client_stops, list):
        client_stops = []
    client_stops = [s for s in client_stops if isinstance(s, str) and s]
    # dict.fromkeys de-duplicates while preserving order.
    stop_tokens = list(dict.fromkeys(STOP_TOKENS + client_stops))

    gen_kwargs = dict(
        prompt=prompt,
        max_tokens=min(int(_number_or(body.get("max_tokens"), 400)), 500),
        temperature=max(0.1, min(_number_or(body.get("temperature"), 0.5), 0.9)),
        top_p=_number_or(body.get("top_p"), 0.85),
        top_k=int(_number_or(body.get("top_k"), 40)),
        repeat_penalty=max(_number_or(body.get("repeat_penalty"), 1.2), 1.1),
        stop=stop_tokens,
    )

    # One id / timestamp for the whole completion, shared by all stream chunks.
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())

    def make_chunk(content, finish=None):
        """Build one streaming chunk; ``content=None`` yields an empty delta."""
        delta = {} if content is None else {"content": content}
        return {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": MODEL_ID,
            "choices": [{"delta": delta, "index": 0, "finish_reason": finish}],
        }

    if body.get("stream", False):
        def gen():
            try:
                previous, repeats = "", 0
                finish = "stop"
                for chunk in llm(stream=True, **gen_kwargs):
                    choice = chunk["choices"][0]
                    if choice.get("finish_reason"):
                        finish = choice["finish_reason"]
                    content = choice.get("text", "")
                    if not content:
                        continue
                    # Cut the stream when the same piece repeats more than
                    # 3 times in a row (degenerate loop).
                    if content == previous:
                        repeats += 1
                        if repeats > 3:
                            break
                    else:
                        repeats = 0
                    previous = content
                    yield f"data: {json.dumps(make_chunk(content))}\n\n"
                # Final chunk tells the client why generation ended.
                yield f"data: {json.dumps(make_chunk(None, finish))}\n\n"
                yield "data: [DONE]\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                gc.collect()

        return StreamingResponse(
            gen(), media_type="text/event-stream",
            headers={**CORS_HEADERS, "Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    # NOTE(review): this call is synchronous inside an async handler, so the
    # event loop is blocked for the whole generation. Moving it to a worker
    # thread would need a lock around ``llm`` -- confirm before changing.
    try:
        result = llm(**gen_kwargs)
    except Exception as e:
        # Return a JSON error with CORS headers rather than letting the
        # exception escape the handler.
        return JSONResponse(
            {"error": {"message": str(e), "type": "server_error"}},
            status_code=500, headers=CORS_HEADERS,
        )
    finally:
        gc.collect()

    choice = result["choices"][0]
    # OpenAI-compatible response body.
    resp = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": MODEL_ID,
        "choices": [{
            "message": {"role": "assistant", "content": choice.get("text", "")},
            "index": 0,
            "finish_reason": choice.get("finish_reason") or "stop",
        }],
        # Use the counts reported by the model when present.
        "usage": result.get("usage") or {
            "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
        },
    }
    return JSONResponse(resp, headers=CORS_HEADERS)
# Script entry point: serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)