| """ | |
| UncGPT-69 hybrid (mamba+MoE+BitNet) — Gradio chat demo on a HF Space GPU. | |
| Loads the step-2000 stage-1-16K checkpoint from the private model repo | |
| Reza2kn/uncgpt-69-hybrid-stage1-step2000, runs inference on the Space's | |
| GPU (CUDA + triton-kernel chunked SSD for Mamba2). | |
| Boot order: | |
| 1. Install mamba-ssm + causal-conv1d (skip CUDA build path; triton kernels | |
| are JIT-compiled at first call so don't need a matching nvcc). | |
| 2. Pull model snapshot from HF. | |
| 3. Import model code (mamba_ssm import is now satisfied). | |
| 4. Launch Gradio. | |
| """ | |
import os
import subprocess
import sys
import time
from pathlib import Path


# ----- 1. install triton-only mamba-ssm path BEFORE importing the model -----
def _pip_install(*pkgs, env=None):
    """Run `pip install` in a subprocess, with optional extra env vars."""
    full_env = os.environ.copy()
    if env:
        full_env.update(env)
    cmd = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *pkgs]
    print(f"$ {' '.join(cmd)}")
    subprocess.check_call(cmd, env=full_env)
def _ensure_mamba_ssm():
    try:
        from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined  # noqa
        print("mamba_ssm already importable, skipping install")
        return
    except Exception as e:
        print(f"mamba_ssm not yet importable ({e!r}); installing")
    # build deps
    _pip_install("packaging", "ninja", "triton")
    # causal-conv1d (Python-only path; we don't actually use the CUDA conv kernel)
    _pip_install("--no-build-isolation", "causal-conv1d",
                 env={"CAUSAL_CONV1D_SKIP_CUDA_BUILD": "TRUE"})
    # mamba-ssm with skip-cuda-build → triton kernels only
    _pip_install("--no-build-isolation", "mamba-ssm",
                 env={"MAMBA_SKIP_CUDA_BUILD": "TRUE"})
    # transformers needs hf_hub<1.0; pin to 0.36.2 like the cluster did
    _pip_install("huggingface_hub==0.36.2")
    # smoke test: the import must succeed now
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined  # noqa
    print("mamba_ssm installed and importable")


_ensure_mamba_ssm()
# ----- 2. pull model snapshot -----
import torch
import yaml
import sentencepiece as spm
import gradio as gr
import spaces  # ZeroGPU decorator
from huggingface_hub import snapshot_download

MODEL_REPO = "Reza2kn/uncgpt-69-hybrid-stage1-step16000"
WORKSPACE = Path("/tmp/uncgpt-space")
WORKSPACE.mkdir(exist_ok=True, parents=True)

print("=== Pulling model repo from HF ===")
local_dir = snapshot_download(
    repo_id=MODEL_REPO,
    repo_type="model",
    local_dir=str(WORKSPACE),
    token=os.environ.get("HF_TOKEN"),
)
print(f" -> {local_dir}")
# ----- 3. import model + load checkpoint -----
sys.path.insert(0, str(WORKSPACE))
from model.unc_gpt import UncGPTConfig  # noqa: E402
from model import build_model  # noqa: E402

cfg_path = WORKSPACE / "configs/stage_01_16k.yaml"
mcfg = UncGPTConfig.from_yaml(cfg_path)
print(f"arch={mcfg.architecture} d={mcfg.d_model} L={mcfg.n_layers} "
      f"vocab={mcfg.vocab_size} ctx={mcfg.ctx_train}")

# On ZeroGPU the GPU isn't available at import time — we load on CPU and move
# the model inside @spaces.GPU on the first call. (Loading is fast since the
# checkpoint is bf16-friendly.)
print("loading checkpoint to CPU...")
model = build_model(mcfg, ctx=mcfg.ctx_train)
sd = torch.load(WORKSPACE / "uncgpt69_step2000.pt", map_location="cpu")
state = sd.get("model", sd)
# Strip the DDP "module." prefix if the checkpoint was saved from a wrapped model.
state = {k.removeprefix("module."): v for k, v in state.items()}
missing, unexpected = model.load_state_dict(state, strict=False)
print(f" missing={len(missing)} unexpected={len(unexpected)}")
model.eval()
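# Cheap sanity check: the parameter count should land near the 47M quoted in
# the demo header below.
n_params = sum(p.numel() for p in model.parameters())
print(f"params: {n_params / 1e6:.1f}M")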
_GPU_READY = False  # flips to True after the first @spaces.GPU call moves the model

sp = spm.SentencePieceProcessor()
sp.load(str(WORKSPACE / "tokenizer/uncgpt_v6_6069.model"))
print(f"tokenizer vocab={sp.get_piece_size()}")
# ----- 4. inference + gradio (ZeroGPU-decorated) -----
def _sample(logits, temperature, top_k):
    """Pick the next token id: greedy at ~zero temperature, else (top-k) sampling."""
    if temperature <= 0.001:
        return int(torch.argmax(logits).item())
    scaled = logits / max(temperature, 1e-4)
    if 0 < top_k < scaled.size(-1):
        topk_vals, topk_idx = torch.topk(scaled, top_k)
        probs = torch.softmax(topk_vals, dim=-1)
        pick = torch.multinomial(probs, 1).item()
        return int(topk_idx[pick].item())
    probs = torch.softmax(scaled, dim=-1)
    return int(torch.multinomial(probs, 1).item())
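# Note on top-k: logits outside the top k are dropped before the softmax, so
# probabilities are renormalised over the k survivors only.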
@spaces.GPU
def _generate_gpu(prompt: str, max_new: int, temperature: float, top_k: int):
    """No-cache streaming generation. Each step re-encodes the entire prefix
    (O(L) per token, O(L²) total). Slower than the cached path but COHERENT —
    the cached step path has a state-drift bug (mamba2.step() vs. prefill
    diverge) that we'll fix in a follow-up. For short demo prompts (~50 tok)
    producing ~80-tok responses, this is a couple of seconds per response on
    an H200.
    """
    global model, _GPU_READY
    if not _GPU_READY:
        print("moving model to cuda + bf16...")
        model = model.to("cuda").to(torch.bfloat16)
        _GPU_READY = True
    ids = [sp.bos_id()] + sp.encode(prompt)
    if len(ids) >= mcfg.ctx_train - 4:
        yield "[prompt too long for ctx_train]", 0
        return

    out_ids: list[int] = []
    t0 = time.time()
    with torch.no_grad():
        for step in range(max_new):
            x = torch.tensor([ids], dtype=torch.long, device="cuda")
            last_logits = model(x)["logits"][0, -1].float()
            nxt = _sample(last_logits, temperature, top_k)
            if nxt == sp.eos_id():
                break
            out_ids.append(nxt)
            ids.append(nxt)
            # Stream a partial decode every other step (and on the final step).
            if step % 2 == 0 or step == max_new - 1:
                text = sp.decode(out_ids)
                elapsed = max(time.time() - t0, 1e-3)
                yield text, len(out_ids) / elapsed
            if len(ids) >= mcfg.ctx_train - 4:
                break
    # Final yield so the last tokens always land even if the loop broke early.
    if out_ids:
        text = sp.decode(out_ids)
        elapsed = max(time.time() - t0, 1e-3)
        yield text, len(out_ids) / elapsed
def generate(prompt: str, max_new: int, temperature: float, top_k: int):
    """Thin pass-through so gradio doesn't see the @spaces.GPU decoration directly."""
    yield from _generate_gpu(prompt, max_new, temperature, top_k)
| TEMPLATE = "<user>{user}</user><uncle>" | |
def chat_fn(user_msg, history, max_new, temperature, top_k):
    prompt = TEMPLATE.format(user=user_msg.strip())
    history = history or []
    history.append([user_msg, ""])
    for partial, tps in generate(prompt, int(max_new), float(temperature), int(top_k)):
        history[-1][1] = partial + f"\n\n_{tps:.1f} tok/s_"
        yield history, history
examples = [
    "Hey Uncle, can you help me find a pharmacy near me?",
    "I need to email Maria about Mom's appointment on Friday.",
    "What's the weather looking like for this afternoon?",
    "Can you set up a calendar event for the doctor visit Tuesday at 3pm?",
    "I'm feeling overwhelmed today.",
]
with gr.Blocks(title="UncGPT-69 hybrid (stage 1, step 16000)") as demo:
    gr.Markdown(
        "# UncGPT-69 hybrid — stage 1 16K, step 16000\n"
        "47M-param Jamba-style hybrid (10 Mamba-2 + 2 MQA) + MoE (1+6 experts, top-2) + BitNet 1.58b weights, "
        "vocab 6069, ctx_train 16384. Trained 16k steps on 8×A100 + 4×L40, final lm_loss 1.26 (held-out PPL 4.87). \n"
        "Per-utterance template: `<user>...</user><uncle>...</uncle>`. \n"
        "_Note: running in no-cache mode (re-encode each step) due to a known step-state drift bug in the mamba/KV cache path. Slower but correct._"
    )
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=420, show_copy_button=True)
            msg = gr.Textbox(label="Your message", placeholder="Hey Uncle, ...")
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")
        with gr.Column(scale=1):
            max_new = gr.Slider(8, 256, value=80, step=8, label="max new tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature (0 = greedy)")
            top_k = gr.Slider(0, 200, value=40, step=5, label="top_k (0 = none)")
            gr.Examples(examples, inputs=msg, label="Try one")

    send.click(chat_fn, [msg, chatbot, max_new, temperature, top_k], [chatbot, chatbot]).then(lambda: "", outputs=msg)
    msg.submit(chat_fn, [msg, chatbot, max_new, temperature, top_k], [chatbot, chatbot]).then(lambda: "", outputs=msg)
    clear.click(lambda: [], outputs=chatbot)

if __name__ == "__main__":
    demo.queue().launch()