Spaces:
Running
Running
| """Shared runtime utilities for the Modal and HF Jobs training adapters. | |
| Both adapters need to: wait for the env server, optionally spin up a vLLM | |
| subprocess, and resolve resume checkpoints. This module centralises that logic | |
| so neither adapter file duplicates it. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import subprocess | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import requests | |
| def wait_for_env_server(env_url: str, retries: int = 30, delay: int = 2) -> None: | |
| """Poll the VeriRL environment server until its /health endpoint responds. | |
| Args: | |
| env_url: Base URL of the VeriRL environment server. | |
| retries: Maximum number of poll attempts before raising. | |
| delay: Seconds to wait between each attempt. | |
| Raises: | |
| RuntimeError: If the server does not respond within ``retries * delay`` seconds. | |
| """ | |
| print(f"[VeriRL] Waiting for env server at {env_url} ...") | |
| for _ in range(retries): | |
| try: | |
| if requests.get(f"{env_url}/health", timeout=5).status_code == 200: | |
| print("[VeriRL] Env server ready.") | |
| return | |
| except Exception: | |
| pass | |
| time.sleep(delay) | |
| raise RuntimeError( | |
| f"VeriRL env server at {env_url} not reachable after {retries * delay}s" | |
| ) | |
| def set_single_node_dist_env() -> None: | |
| """Set PyTorch distributed env vars for single-node, single-process training. | |
| Must be called before any CUDA context is opened. Configures RANK, | |
| LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, and | |
| PYTORCH_CUDA_ALLOC_CONF for GRPOTrainer's internal process group. | |
| """ | |
| os.environ.update({ | |
| "RANK": "0", | |
| "LOCAL_RANK": "0", | |
| "WORLD_SIZE": "1", | |
| "MASTER_ADDR": "localhost", | |
| "MASTER_PORT": "12355", | |
| "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", | |
| }) | |
| def latest_checkpoint(root: str | Path) -> str | None: | |
| """Return the path of the highest-numbered ``checkpoint-N`` directory, or None. | |
| Args: | |
| root: Directory to search for ``checkpoint-N`` subdirectories. | |
| Returns: | |
| Absolute path string to the latest checkpoint, or ``None`` if none exist. | |
| """ | |
| root = Path(root) | |
| checkpoints: list[tuple[int, Path]] = [] | |
| for candidate in root.glob("checkpoint-*"): | |
| if not candidate.is_dir(): | |
| continue | |
| try: | |
| step = int(candidate.name.rsplit("-", 1)[1]) | |
| except (IndexError, ValueError): | |
| continue | |
| checkpoints.append((step, candidate)) | |
| if not checkpoints: | |
| return None | |
| return str(max(checkpoints, key=lambda item: item[0])[1]) | |
| def start_vllm_server( | |
| vllm_model: str, | |
| max_model_len: int, | |
| port: int = 8001, | |
| log_path: str = "/tmp/vllm_server.log", | |
| ) -> subprocess.Popen: | |
| """Launch a ``trl vllm-serve`` subprocess on GPU 1 and wait until it is healthy. | |
| Strips PyTorch distributed env vars from the subprocess environment so | |
| vLLM's own ``dist.init_process_group`` does not conflict with the training | |
| TCPStore running at MASTER_PORT. | |
| Args: | |
| vllm_model: HuggingFace model ID or local path for vLLM to serve. | |
| max_model_len: Maximum token sequence length for the KV cache. | |
| port: HTTP port the vLLM server listens on. | |
| log_path: File path for combined vLLM stdout/stderr. | |
| Returns: | |
| The running ``subprocess.Popen`` handle for the vLLM server. | |
| Raises: | |
| RuntimeError: If the process exits early or fails to start within 360 s. | |
| """ | |
| trl_bin = str(Path(sys.executable).parent / "trl") | |
| trl_ver = subprocess.run( | |
| [sys.executable, "-c", "import trl; print(trl.__version__)"], | |
| capture_output=True, | |
| text=True, | |
| ) | |
| print(f"[VeriRL] Starting vLLM server on GPU 1, port {port} ...") | |
| print(f"[VeriRL] trl binary: {trl_bin} version: {trl_ver.stdout.strip()}") | |
| _DIST_KEYS = { | |
| "RANK", "LOCAL_RANK", "WORLD_SIZE", | |
| "MASTER_ADDR", "MASTER_PORT", | |
| "TORCHELASTIC_RESTART_COUNT", "TORCHELASTIC_MAX_RESTARTS", | |
| } | |
| vllm_env = {k: v for k, v in os.environ.items() if k not in _DIST_KEYS} | |
| vllm_env.update({"CUDA_VISIBLE_DEVICES": "1", "PYTHONUNBUFFERED": "1"}) | |
| vllm_log = open(log_path, "w") | |
| proc = subprocess.Popen( | |
| [ | |
| trl_bin, "vllm-serve", | |
| "--model", vllm_model, | |
| "--port", str(port), | |
| "--gpu-memory-utilization", "0.9", | |
| "--max-model-len", str(max_model_len), | |
| ], | |
| env=vllm_env, | |
| stdout=vllm_log, | |
| stderr=subprocess.STDOUT, | |
| ) | |
| for i in range(180): # up to 360 s — first run downloads the model | |
| if proc.poll() is not None: | |
| vllm_log.flush() | |
| tail = open(log_path).read()[-3000:] | |
| raise RuntimeError( | |
| f"vLLM server exited early (code {proc.returncode}):\n{tail}" | |
| ) | |
| try: | |
| if requests.get(f"http://localhost:{port}/health", timeout=2).status_code == 200: | |
| print("[VeriRL] vLLM server ready.") | |
| return proc | |
| except Exception: | |
| pass | |
| if i % 30 == 29: | |
| vllm_log.flush() | |
| print(f"[VeriRL] vLLM still starting ({(i + 1) * 2}s) ...") | |
| time.sleep(2) | |
| proc.kill() | |
| tail = open(log_path).read()[-3000:] | |
| raise RuntimeError(f"vLLM server failed to start within 360s. Log:\n{tail}") | |
| def build_vllm_kwargs( | |
| gpu_count: int, | |
| vllm_model: str, | |
| max_model_len: int, | |
| vllm_port: int = 8001, | |
| ) -> dict: | |
| """Build the vLLM configuration kwargs dict for GRPOConfig. | |
| Chooses *server mode* when two or more GPUs are available (vLLM on GPU 1, | |
| training on GPU 0) and *colocate mode* otherwise. In colocate mode the | |
| context window is capped at 8192 to avoid OOM on a single card. | |
| Args: | |
| gpu_count: Number of available CUDA devices (``torch.cuda.device_count()``). | |
| vllm_model: HuggingFace model ID served by vLLM (unused in colocate mode). | |
| max_model_len: Maximum sequence length from the training config. | |
| vllm_port: Port the vLLM server listens on (server mode only). | |
| Returns: | |
| Dict ready to unpack as ``GRPOConfig(**vllm_kwargs)``. | |
| """ | |
| if gpu_count >= 2: | |
| return { | |
| "use_vllm": True, | |
| "vllm_mode": "server", | |
| "vllm_server_host": "localhost", | |
| "vllm_server_port": vllm_port, | |
| "vllm_gpu_memory_utilization": 0.9, | |
| "vllm_max_model_length": max_model_len, | |
| } | |
| return { | |
| "use_vllm": True, | |
| "vllm_mode": "colocate", | |
| "vllm_gpu_memory_utilization": 0.5, | |
| "vllm_max_model_length": min(max_model_len, 8192), | |
| } | |
| def resolve_resume_checkpoint( | |
| output_dir: str | Path, | |
| hub_repo_id: str, | |
| hf_token: str, | |
| ) -> str | None: | |
| """Resolve the VERIRL_RESUME_FROM_CHECKPOINT env var to a local checkpoint path. | |
| Resolution order: | |
| 1. Env var unset → return ``None`` (fresh start). | |
| 2. Env var is an explicit path (not ``'latest'``) → return it directly. | |
| 3. Search ``output_dir`` for the highest-numbered checkpoint. | |
| 4. Download from ``hub_repo_id`` and search the downloaded snapshot. | |
| Args: | |
| output_dir: Local directory where checkpoints are written. | |
| hub_repo_id: HuggingFace Hub repo to download from as a fallback. | |
| hf_token: HuggingFace token for authenticated Hub downloads. | |
| Returns: | |
| Absolute path to the checkpoint directory, or ``None`` for a fresh start. | |
| Raises: | |
| RuntimeError: If the env var is ``'latest'`` but no checkpoint is found. | |
| """ | |
| from huggingface_hub import snapshot_download | |
| requested = os.environ.get("VERIRL_RESUME_FROM_CHECKPOINT", "").strip() | |
| if not requested: | |
| return None | |
| if requested not in {"latest", "last-checkpoint"}: | |
| print(f"[VeriRL] Resuming GRPO from explicit checkpoint: {requested}") | |
| return requested | |
| local_latest = latest_checkpoint(output_dir) | |
| if local_latest: | |
| print(f"[VeriRL] Resuming GRPO from local checkpoint: {local_latest}") | |
| return local_latest | |
| resume_dir = Path(output_dir) / "hub_resume" | |
| print(f"[VeriRL] Downloading checkpoints from {hub_repo_id} ...") | |
| snapshot_download( | |
| repo_id=hub_repo_id, | |
| token=hf_token, | |
| local_dir=resume_dir, | |
| allow_patterns=["last-checkpoint/**", "checkpoint-*/**"], | |
| ) | |
| last_checkpoint = resume_dir / "last-checkpoint" | |
| if last_checkpoint.is_dir(): | |
| print(f"[VeriRL] Resuming GRPO from Hub checkpoint: {last_checkpoint}") | |
| return str(last_checkpoint) | |
| hub_latest = latest_checkpoint(resume_dir) | |
| if hub_latest: | |
| print(f"[VeriRL] Resuming GRPO from Hub checkpoint: {hub_latest}") | |
| return hub_latest | |
| raise RuntimeError( | |
| f"VERIRL_RESUME_FROM_CHECKPOINT={requested!r}, but no checkpoint was found " | |
| f"locally in {output_dir} or on Hub at {hub_repo_id}" | |
| ) | |