# codi-trace: code/eval/eval_cruxeval_sft.py
# Snapshot aedd6ab ("add code/ loader snapshot") by sirui6011, 5.05 kB.
"""Stage 1 baseline eval: CRUXEval-O output prediction via full-trace generation.
Feed the training prompt (seeds frame 0), let the SFT model generate the trace,
take main()'s last return value as the predicted output, score by execution.
With greedy decoding there is one prediction per problem, so pass@1 is the fraction of
problems whose prediction passes the execution check. Reuses cwm_andre eval logic.
"""
import argparse
import json
import os
import subprocess
import sys
from datetime import timedelta
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer
from data.dataset import _prompt_str
from data.sources import load_cruxeval
from tokens import add_trace_tokens, token_ids
# Trace-format separator tokens emitted by the SFT model.
ARG_SEP, FRAME_SEP, RETURN_SEP = "<|arg_sep|>", "<|frame_sep|>", "<|return_sep|>"


def extract_answer_trace_full(gen: str) -> object:
    """Return the value recorded in the last RETURN frame of a generated trace.

    Expected tail: ...<|return_sep|>...<|arg_sep|>"value"<|frame_sep|>. The text
    between the first ARG_SEP after the last RETURN_SEP and the following
    FRAME_SEP (or the end of ``gen`` if there is none) is the answer.

    Args:
        gen: decoded generation text (special tokens kept).

    Returns:
        The JSON-decoded value when the text parses as JSON (a ``str`` for a
        quoted value, otherwise a number/list/dict/bool), else the raw stripped
        text. ``None`` when there is no RETURN_SEP, no ARG_SEP after it, the
        value is empty or whitespace (before or after JSON decoding), or the
        value is a bare JSON ``null``.
    """
    # NOTE(review): the last RETURN frame is assumed to be main()'s; this does not
    # hold if the generation was truncated mid-trace (max_new_tokens reached).
    r = gen.rfind(RETURN_SEP)
    if r == -1:
        return None
    a = gen.find(ARG_SEP, r)
    if a == -1:
        return None
    rest = gen[a + len(ARG_SEP):]
    end = rest.find(FRAME_SEP)
    val = (rest[:end] if end != -1 else rest).strip()
    if not val:
        return None
    try:
        decoded = json.loads(val)
    except json.JSONDecodeError:
        # Not valid JSON: fall back to the raw text and let execution judge it.
        return val
    # A quoted empty value ("" / "  ") is no answer at all, same as the raw-empty case.
    if isinstance(decoded, str) and not decoded.strip():
        return None
    return decoded
def check_correct(code: str, expected: str, predicted: object, timeout: float = 3.0) -> bool:
    """Execute `code; assert expected == predicted` (CRUXEval semantics).

    The program runs in a fresh interpreter (``sys.executable -c``) so that
    crashes or infinite loops in the benchmark code cannot affect the evaluator.

    Args:
        code: benchmark source defining the function(s) under test.
        expected: Python expression for the reference output.
        predicted: model prediction; interpolated with ``str()`` into the assert.
        timeout: seconds before the child is killed and the sample scored wrong.

    Returns:
        True iff the child exits with status 0, i.e. the assert held.
    """
    test = f"{code}\nassert {expected} == {predicted}"
    try:
        proc = subprocess.run(
            [sys.executable, "-c", test],
            timeout=timeout,
            capture_output=True,
            # Don't let the child read (or block on) the evaluator's stdin.
            stdin=subprocess.DEVNULL,
        )
    except (subprocess.TimeoutExpired, OSError, ValueError):
        # Timeout, an argument the OS rejects (e.g. too long), or a NUL byte /
        # unencodable character in the program text: count as incorrect.
        return False
    return proc.returncode == 0
def main():
    """Run the CRUXEval-O output-prediction eval and report pass@1.

    Each rank (one per process under torchrun, otherwise a single process)
    greedily generates a trace for a disjoint round-robin shard of the
    benchmark, extracts the final return value with
    ``extract_answer_trace_full`` and scores it with ``check_correct``.
    Counts are summed across ranks. Rank 0 prints the metrics and, if
    ``--out`` is set, writes a JSON report whose ``results`` list follows the
    benchmark's row order regardless of world size.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--n_samples", type=int, default=-1)
    ap.add_argument("--max_new_tokens", type=int, default=8192)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--out", default="")
    args = ap.parse_args()

    # DDP-style data parallelism for inference: torchrun sets RANK/WORLD_SIZE/LOCAL_RANK.
    ddp = "RANK" in os.environ
    rank = int(os.environ.get("RANK", 0))
    world = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if ddp:
        dist.init_process_group("nccl", timeout=timedelta(hours=1))  # ranks finish at different times under long gens
        torch.cuda.set_device(local_rank)

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    add_trace_tokens(tok)  # idempotent; ensures trace tokens present
    # NOTE(review): the model's embeddings are not resized here; this presumes the
    # checkpoint's tokenizer already contains the trace tokens -- confirm.
    tok.padding_side = "left"  # left-pad so all generated tokens start at the same offset
    eot_id = token_ids(tok)["<|end_of_text|>"]
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.bfloat16).to(local_rank).eval()

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]  # disjoint round-robin split across ranks
    n_correct = n_fmt = 0
    results = []
    for bi, batch_start in enumerate(range(0, len(shard), args.batch_size)):
        batch = shard[batch_start: batch_start + args.batch_size]
        enc = tok([_prompt_str(r["code"], r["input"]) for r in batch],
                  return_tensors="pt", padding=True, add_special_tokens=False).to(local_rank)
        prompt_len = enc["input_ids"].shape[1]
        with torch.no_grad():
            out = model.generate(**enc, max_new_tokens=args.max_new_tokens, do_sample=False,
                                 eos_token_id=eot_id, pad_token_id=eot_id)
        for j, r in enumerate(batch):
            ids = out[j, prompt_len:].tolist()
            # A row stops at its first EOT; since pad_token_id == eot_id, anything after it
            # is batch padding. Keep that single EOT in the saved generation, and parse only
            # the text before it so padding never leaks into the extracted value.
            cut = ids.index(eot_id) if eot_id in ids else len(ids)
            gen = tok.decode(ids[: cut + 1], skip_special_tokens=False)
            pred = extract_answer_trace_full(tok.decode(ids[:cut], skip_special_tokens=False))
            ok = pred is not None and check_correct(r["code"], r["output"], pred)
            n_fmt += pred is not None
            n_correct += ok
            results.append({"id": r["id"], "expected": r["output"], "predicted": pred,
                            "correct": ok, "generation": gen})
        if rank == 0 and (bi + 1) % 5 == 0:
            done = batch_start + len(batch)
            print(f"  rank0 {done}/{len(shard)}  pass@1={n_correct/done:.4f}", flush=True)

    # Reduce metrics and gather per-row results across ranks.
    if ddp:
        t = torch.tensor([n_correct, n_fmt], device=local_rank)
        dist.all_reduce(t)
        n_correct, n_fmt = int(t[0]), int(t[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            # Rank k scored rows[k::world]; interleave back into benchmark order so the
            # report layout does not depend on the number of ranks.
            results = [None] * n
            for k, part in enumerate(gathered):
                results[k::world] = part

    if rank == 0:
        denom = max(n, 1)  # avoid ZeroDivisionError on an empty benchmark slice
        pass_at_1 = n_correct / denom
        valid_fmt = n_fmt / denom
        print(f"\nCRUXEval-O pass@1={pass_at_1:.4f}  "
              f"valid_format={valid_fmt:.4f}  (n={n}, greedy)")
        if args.out:
            with open(args.out, "w") as f:
                json.dump({"pass_at_1": pass_at_1, "valid_format": valid_fmt,
                           "n": n, "results": results}, f, indent=2)
    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()