# verirl-env/training/runtime.py
"""Shared runtime utilities for the Modal and HF Jobs training adapters.
Both adapters need to: wait for the env server, optionally spin up a vLLM
subprocess, and resolve resume checkpoints. This module centralises that logic
so neither adapter file duplicates it.
"""
from __future__ import annotations
import os
import subprocess
import sys
import time
from pathlib import Path
import requests
def wait_for_env_server(env_url: str, retries: int = 30, delay: int = 2) -> None:
"""Poll the VeriRL environment server until its /health endpoint responds.
Args:
env_url: Base URL of the VeriRL environment server.
retries: Maximum number of poll attempts before raising.
delay: Seconds to wait between each attempt.
Raises:
RuntimeError: If the server does not respond within ``retries * delay`` seconds.
"""
print(f"[VeriRL] Waiting for env server at {env_url} ...")
for _ in range(retries):
try:
if requests.get(f"{env_url}/health", timeout=5).status_code == 200:
print("[VeriRL] Env server ready.")
return
except Exception:
pass
time.sleep(delay)
raise RuntimeError(
f"VeriRL env server at {env_url} not reachable after {retries * delay}s"
)
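# A minimal usage sketch (illustrative; the URL is an assumption, real adapters pass
# the address of the env server they started):
#
#     wait_for_env_server("http://localhost:8000", retries=60, delay=2)
#
# With the defaults the server gets roughly a minute to come up before this raises.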
def set_single_node_dist_env() -> None:
"""Set PyTorch distributed env vars for single-node, single-process training.
Must be called before any CUDA context is opened. Configures RANK,
LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, and
PYTORCH_CUDA_ALLOC_CONF for GRPOTrainer's internal process group.
"""
os.environ.update({
"RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12355",
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
})
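# Call-order sketch for an adapter entrypoint (illustrative): the env vars must be in
# place before the trainer opens a CUDA context, so an adapter would do
#
#     set_single_node_dist_env()
#     trainer = GRPOTrainer(...)   # first CUDA allocation happens inside the trainer
#
# rather than the reverse.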
def latest_checkpoint(root: str | Path) -> str | None:
"""Return the path of the highest-numbered ``checkpoint-N`` directory, or None.
Args:
root: Directory to search for ``checkpoint-N`` subdirectories.
Returns:
String path to the latest checkpoint directory (relative if ``root`` is relative), or ``None`` if none exist.
"""
root = Path(root)
checkpoints: list[tuple[int, Path]] = []
for candidate in root.glob("checkpoint-*"):
if not candidate.is_dir():
continue
try:
step = int(candidate.name.rsplit("-", 1)[1])
except (IndexError, ValueError):
continue
checkpoints.append((step, candidate))
if not checkpoints:
return None
return str(max(checkpoints, key=lambda item: item[0])[1])
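# Behaviour sketch for a hypothetical output directory containing
# checkpoint-100/, checkpoint-250/ and checkpoint-best/:
#
#     latest_checkpoint("outputs")  ->  "outputs/checkpoint-250"
#
# checkpoint-best is skipped because its suffix is not an integer step count.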
def start_vllm_server(
vllm_model: str,
max_model_len: int,
port: int = 8001,
log_path: str = "/tmp/vllm_server.log",
) -> subprocess.Popen:
"""Launch a ``trl vllm-serve`` subprocess on GPU 1 and wait until it is healthy.
Strips PyTorch distributed env vars from the subprocess environment so
vLLM's own ``dist.init_process_group`` does not conflict with the training
TCPStore running at MASTER_PORT.
Args:
vllm_model: HuggingFace model ID or local path for vLLM to serve.
max_model_len: Maximum token sequence length for the KV cache.
port: HTTP port the vLLM server listens on.
log_path: File path for combined vLLM stdout/stderr.
Returns:
The running ``subprocess.Popen`` handle for the vLLM server.
Raises:
RuntimeError: If the process exits early or fails to start within 360 s.
"""
trl_bin = str(Path(sys.executable).parent / "trl")
trl_ver = subprocess.run(
[sys.executable, "-c", "import trl; print(trl.__version__)"],
capture_output=True,
text=True,
)
print(f"[VeriRL] Starting vLLM server on GPU 1, port {port} ...")
print(f"[VeriRL] trl binary: {trl_bin} version: {trl_ver.stdout.strip()}")
_DIST_KEYS = {
"RANK", "LOCAL_RANK", "WORLD_SIZE",
"MASTER_ADDR", "MASTER_PORT",
"TORCHELASTIC_RESTART_COUNT", "TORCHELASTIC_MAX_RESTARTS",
}
vllm_env = {k: v for k, v in os.environ.items() if k not in _DIST_KEYS}
vllm_env.update({"CUDA_VISIBLE_DEVICES": "1", "PYTHONUNBUFFERED": "1"})
vllm_log = open(log_path, "w")
proc = subprocess.Popen(
[
trl_bin, "vllm-serve",
"--model", vllm_model,
"--port", str(port),
"--gpu-memory-utilization", "0.9",
"--max-model-len", str(max_model_len),
],
env=vllm_env,
stdout=vllm_log,
stderr=subprocess.STDOUT,
)
for i in range(180): # up to 360 s — first run downloads the model
if proc.poll() is not None:
vllm_log.flush()
tail = open(log_path).read()[-3000:]
raise RuntimeError(
f"vLLM server exited early (code {proc.returncode}):\n{tail}"
)
try:
if requests.get(f"http://localhost:{port}/health", timeout=2).status_code == 200:
print("[VeriRL] vLLM server ready.")
return proc
except Exception:
pass
if i % 30 == 29:
vllm_log.flush()
print(f"[VeriRL] vLLM still starting ({(i + 1) * 2}s) ...")
time.sleep(2)
proc.kill()
tail = open(log_path).read()[-3000:]
raise RuntimeError(f"vLLM server failed to start within 360s. Log:\n{tail}")
def build_vllm_kwargs(
gpu_count: int,
vllm_model: str,
max_model_len: int,
vllm_port: int = 8001,
) -> dict:
"""Build the vLLM configuration kwargs dict for GRPOConfig.
Chooses *server mode* when two or more GPUs are available (vLLM on GPU 1,
training on GPU 0) and *colocate mode* otherwise. In colocate mode the
context window is capped at 8192 to avoid OOM on a single card.
Args:
gpu_count: Number of available CUDA devices (``torch.cuda.device_count()``).
vllm_model: HuggingFace model ID served by vLLM (not referenced in the returned kwargs).
max_model_len: Maximum sequence length from the training config.
vllm_port: Port the vLLM server listens on (server mode only).
Returns:
Dict ready to unpack as ``GRPOConfig(**vllm_kwargs)``.
"""
if gpu_count >= 2:
return {
"use_vllm": True,
"vllm_mode": "server",
"vllm_server_host": "localhost",
"vllm_server_port": vllm_port,
"vllm_gpu_memory_utilization": 0.9,
"vllm_max_model_length": max_model_len,
}
return {
"use_vllm": True,
"vllm_mode": "colocate",
"vllm_gpu_memory_utilization": 0.5,
"vllm_max_model_length": min(max_model_len, 8192),
}
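# Wiring sketch (illustrative; the model id and output dir are assumptions, and the
# exact GRPOConfig field names depend on the pinned trl version):
#
#     import torch
#     from trl import GRPOConfig
#
#     vllm_kwargs = build_vllm_kwargs(
#         gpu_count=torch.cuda.device_count(),
#         vllm_model="Qwen/Qwen2.5-7B-Instruct",
#         max_model_len=16384,
#     )
#     config = GRPOConfig(output_dir="outputs", **vllm_kwargs)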
def resolve_resume_checkpoint(
output_dir: str | Path,
hub_repo_id: str,
hf_token: str,
) -> str | None:
"""Resolve the VERIRL_RESUME_FROM_CHECKPOINT env var to a local checkpoint path.
Resolution order:
1. Env var unset → return ``None`` (fresh start).
2. Env var is an explicit path (anything other than ``'latest'`` or ``'last-checkpoint'``) → return it directly.
3. Search ``output_dir`` for the highest-numbered checkpoint.
4. Download from ``hub_repo_id`` and search the downloaded snapshot.
Args:
output_dir: Local directory where checkpoints are written.
hub_repo_id: HuggingFace Hub repo to download from as a fallback.
hf_token: HuggingFace token for authenticated Hub downloads.
Returns:
Path to the checkpoint directory, or ``None`` for a fresh start.
Raises:
RuntimeError: If the env var requests the latest checkpoint but none is found locally or on the Hub.
"""
from huggingface_hub import snapshot_download
requested = os.environ.get("VERIRL_RESUME_FROM_CHECKPOINT", "").strip()
if not requested:
return None
if requested not in {"latest", "last-checkpoint"}:
print(f"[VeriRL] Resuming GRPO from explicit checkpoint: {requested}")
return requested
local_latest = latest_checkpoint(output_dir)
if local_latest:
print(f"[VeriRL] Resuming GRPO from local checkpoint: {local_latest}")
return local_latest
resume_dir = Path(output_dir) / "hub_resume"
print(f"[VeriRL] Downloading checkpoints from {hub_repo_id} ...")
snapshot_download(
repo_id=hub_repo_id,
token=hf_token,
local_dir=resume_dir,
allow_patterns=["last-checkpoint/**", "checkpoint-*/**"],
)
last_checkpoint = resume_dir / "last-checkpoint"
if last_checkpoint.is_dir():
print(f"[VeriRL] Resuming GRPO from Hub checkpoint: {last_checkpoint}")
return str(last_checkpoint)
hub_latest = latest_checkpoint(resume_dir)
if hub_latest:
print(f"[VeriRL] Resuming GRPO from Hub checkpoint: {hub_latest}")
return hub_latest
raise RuntimeError(
f"VERIRL_RESUME_FROM_CHECKPOINT={requested!r}, but no checkpoint was found "
f"locally in {output_dir} or on Hub at {hub_repo_id}"
)
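# End-to-end adapter sketch (illustrative only; the env URL, model id, Hub repo id and
# token lookup below are assumptions, not values defined in this module):
#
#     set_single_node_dist_env()                        # before any CUDA work
#     wait_for_env_server("http://localhost:8000")
#     gpu_count = torch.cuda.device_count()
#     vllm_proc = None
#     if gpu_count >= 2:
#         vllm_proc = start_vllm_server("Qwen/Qwen2.5-7B-Instruct", max_model_len=16384)
#     vllm_kwargs = build_vllm_kwargs(gpu_count, "Qwen/Qwen2.5-7B-Instruct", 16384)
#     resume = resolve_resume_checkpoint("outputs", "org/verirl-checkpoints", os.environ["HF_TOKEN"])
#     ...  # build GRPOConfig/GRPOTrainer, then trainer.train(resume_from_checkpoint=resume)
#     if vllm_proc is not None:
#         vllm_proc.terminate()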