"""
UncGPT-69 hybrid (mamba+MoE+BitNet) — Gradio chat demo on a HF Space GPU.
Loads the step-16000 stage-1 (16K context) checkpoint from the private model repo
Reza2kn/uncgpt-69-hybrid-stage1-step16000 and runs inference on the Space's
GPU (CUDA + triton-kernel chunked SSD for Mamba2).
Boot order:
1. Install mamba-ssm + causal-conv1d (skip CUDA build path; triton kernels
are JIT-compiled at first call so don't need a matching nvcc).
2. Pull model snapshot from HF.
3. Import model code (mamba_ssm import is now satisfied).
4. Launch Gradio.
"""
import os
import subprocess
import sys
import time
from pathlib import Path
# ----- 1. install triton-only mamba-ssm path BEFORE importing the model -----
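# Small helper: shell out to pip, merging any extra env vars into the current
# environment (used below to pass the *_SKIP_CUDA_BUILD flags to the builds).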
def _pip_install(*pkgs, env=None):
full_env = os.environ.copy()
if env:
full_env.update(env)
cmd = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *pkgs]
print(f"$ {' '.join(cmd)}")
subprocess.check_call(cmd, env=full_env)
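# Skip the pip installs entirely if the triton SSD kernel already imports.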
def _ensure_mamba_ssm():
try:
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined # noqa
print("mamba_ssm already importable, skipping install")
return
except Exception as e:
print(f"mamba_ssm not yet importable ({e!r}); installing")
# build deps
_pip_install("packaging", "ninja", "triton")
# causal-conv1d (Python-only; we don't actually use the CUDA conv kernel)
_pip_install("--no-build-isolation", "causal-conv1d",
env={"CAUSAL_CONV1D_SKIP_CUDA_BUILD": "TRUE"})
# mamba-ssm with skip-cuda-build → triton kernels only
_pip_install("--no-build-isolation", "mamba-ssm",
env={"MAMBA_SKIP_CUDA_BUILD": "TRUE"})
# transformers needs hf_hub<1.0; pin to 0.36.2 like the cluster did
_pip_install("huggingface_hub==0.36.2")
# smoke
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined # noqa
print("mamba_ssm installed and importable")
_ensure_mamba_ssm()
# ----- 2. pull model snapshot -----
import torch
import yaml
import sentencepiece as spm
import gradio as gr
import spaces # ZeroGPU decorator
from huggingface_hub import snapshot_download
MODEL_REPO = "Reza2kn/uncgpt-69-hybrid-stage1-step16000"
WORKSPACE = Path("/tmp/uncgpt-space")
WORKSPACE.mkdir(exist_ok=True, parents=True)
print("=== Pulling model repo from HF ===")
local_dir = snapshot_download(
repo_id=MODEL_REPO,
repo_type="model",
local_dir=str(WORKSPACE),
token=os.environ.get("HF_TOKEN"),
)
print(f" -> {local_dir}")
# ----- 3. import model + load checkpoint -----
sys.path.insert(0, str(WORKSPACE))
from model.unc_gpt import UncGPTConfig # noqa: E402
from model import build_model # noqa: E402
cfg_path = WORKSPACE / "configs/stage_01_16k.yaml"
mcfg = UncGPTConfig.from_yaml(cfg_path)
print(f"arch={mcfg.architecture} d={mcfg.d_model} L={mcfg.n_layers} vocab={mcfg.vocab_size} ctx={mcfg.ctx_train}")
# On ZeroGPU the GPU isn't available at import-time — we load on CPU and move
# the model inside @spaces.GPU on first call. (Loading is fast since the
# checkpoint is bf16-friendly.)
print("loading checkpoint to CPU...")
model = build_model(mcfg, ctx=mcfg.ctx_train)
sd = torch.load(WORKSPACE / "uncgpt69_step2000.pt", map_location="cpu")
state = sd.get("model", sd)
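# Checkpoints saved from a DDP-wrapped model carry a "module." prefix on every
# key; strip it before loading into the bare model.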
state = {k.replace("module.", ""): v for k, v in state.items()}
missing, unexpected = model.load_state_dict(state, strict=False)
print(f" missing={len(missing)} unexpected={len(unexpected)}")
model.eval()
_GPU_READY = False # flips to True after first @spaces.GPU call moves model
sp = spm.SentencePieceProcessor()
sp.load(str(WORKSPACE / "tokenizer/uncgpt_v6_6069.model"))
print(f"tokenizer vocab={sp.get_piece_size()}")
# ----- 4. inference + gradio (ZeroGPU-decorated) -----
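# Temperature / top-k sampling over the last-position logits; temperature ~0
# falls back to greedy argmax, and top_k=0 disables the top-k filter.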
def _sample(logits, temperature, top_k):
if temperature <= 0.001:
return int(torch.argmax(logits).item())
scaled = logits / max(temperature, 1e-4)
if top_k > 0 and top_k < scaled.size(-1):
topk_vals, topk_idx = torch.topk(scaled, top_k)
probs = torch.softmax(topk_vals, dim=-1)
pick = torch.multinomial(probs, 1).item()
return int(topk_idx[pick].item())
probs = torch.softmax(scaled, dim=-1)
return int(torch.multinomial(probs, 1).item())
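# ZeroGPU only attaches a GPU while a @spaces.GPU-decorated call is running
# (here for up to 120 s), so the weights are moved to cuda/bf16 lazily on the
# first request.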
@spaces.GPU(duration=120)
def _generate_gpu(prompt: str, max_new: int, temperature: float, top_k: int):
"""No-cache streaming generation. Each step re-encodes the entire prefix
(O(L) per token, O(L²) total). Slower than cached but COHERENT — the
cache step path has a state-drift bug (mamba2.step() vs prefill diverge)
that we'll fix in a follow-up. For short demo prompts (~50 tok) producing
~80 tok responses, this is a couple of seconds per response on H200.
"""
global model, _GPU_READY
if not _GPU_READY:
print("moving model to cuda + bf16...")
model = model.to("cuda").to(torch.bfloat16)
_GPU_READY = True
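    # Encode with a leading BOS; reject prompts that leave no room below ctx_train.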
ids = [sp.bos_id()] + sp.encode(prompt)
if len(ids) >= mcfg.ctx_train - 4:
yield "[prompt too long for ctx_train]", 0
return
out_ids: list[int] = []
t0 = time.time()
with torch.no_grad():
for step in range(max_new):
x = torch.tensor([ids], dtype=torch.long, device="cuda")
last_logits = model(x)["logits"][0, -1].float()
nxt = _sample(last_logits, temperature, top_k)
if nxt == sp.eos_id():
break
out_ids.append(nxt)
ids.append(nxt)
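            # Decode and push a partial result every other token (and on the
            # final step) to keep the number of Gradio updates down.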
if step % 2 == 0 or step == max_new - 1:
text = sp.decode(out_ids)
elapsed = max(time.time() - t0, 1e-3)
yield text, len(out_ids) / elapsed
if len(ids) >= mcfg.ctx_train - 4:
break
if out_ids:
text = sp.decode(out_ids)
elapsed = max(time.time() - t0, 1e-3)
yield text, len(out_ids) / elapsed
def generate(prompt: str, max_new: int, temperature: float, top_k: int):
"""Thin pass-through so gradio doesn't see the @spaces.GPU decoration directly."""
yield from _generate_gpu(prompt, max_new, temperature, top_k)
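# Single-turn prompt in the training format (<user>...</user><uncle>...</uncle>);
# the opening <uncle> tag is supplied and the model generates the reply.
# Chat history is not fed back into the prompt.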
TEMPLATE = "<user>{user}</user><uncle>"
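# Streams partial completions into the newest chat turn, appending a live tok/s readout.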
def chat_fn(user_msg, history, max_new, temperature, top_k):
    prompt = TEMPLATE.format(user=user_msg.strip())
    history = history or []
    history.append([user_msg, ""])
    for partial, tps in generate(prompt, int(max_new), float(temperature), int(top_k)):
        history[-1][1] = partial + f"\n\n_{tps:.1f} tok/s_"
        yield history
examples = [
"Hey Uncle, can you help me find a pharmacy near me?",
"I need to email Maria about Mom's appointment on Friday.",
"What's the weather looking like for this afternoon?",
"Can you set up a calendar event for the doctor visit Tuesday at 3pm?",
"I'm feeling overwhelmed today.",
]
with gr.Blocks(title="UncGPT-69 hybrid (stage 1, step 16000)") as demo:
gr.Markdown(
"# UncGPT-69 hybrid — stage 1 16K, step 16000\n"
"47M-param Jamba-style hybrid (10 Mamba-2 + 2 MQA) + MoE (1+6 experts, top-2) + BitNet 1.58b weights, "
"vocab 6069, ctx_train 16384. Trained 16k steps on 8×A100 + 4×L40, final lm_loss 1.26 (held-out PPL 4.87). \n"
"Per-utterance template: `<user>...</user><uncle>...</uncle>`. \n"
"_Note: running in no-cache mode (re-encode each step) due to a known step-state drift bug in the mamba/KV cache path. Slower but correct._"
)
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(height=420, show_copy_button=True)
msg = gr.Textbox(label="Your message", placeholder="Hey Uncle, ...")
with gr.Row():
send = gr.Button("Send", variant="primary")
clear = gr.Button("Clear")
with gr.Column(scale=1):
max_new = gr.Slider(8, 256, value=80, step=8, label="max new tokens")
temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature (0 = greedy)")
top_k = gr.Slider(0, 200, value=40, step=5, label="top_k (0 = none)")
gr.Examples(examples, inputs=msg, label="Try one")
    send.click(chat_fn, [msg, chatbot, max_new, temperature, top_k], [chatbot]).then(lambda: "", outputs=msg)
    msg.submit(chat_fn, [msg, chatbot, max_new, temperature, top_k], [chatbot]).then(lambda: "", outputs=msg)
clear.click(lambda: [], outputs=chatbot)
if __name__ == "__main__":
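    # queue() provides the request queue that the streaming (generator) handlers run through.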
demo.queue().launch()