File size: 4,818 Bytes
"""CRUXEval-O latent eval: the CODI student generates the trace, but at every
<|line_sep|> the frame's $LOCALS is replaced by a latent block (latent_start +
latent_steps recurrent latents + latent_end), mirroring training _student.
"""
import argparse
import json
import os
from datetime import timedelta
import torch
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from data.dataset import _prompt_str
from data.sources import load_cruxeval
from eval.eval_cruxeval_sft import check_correct, extract_answer_trace_full
from tokens import add_trace_tokens, token_ids
from train.train_codi import CodiModel
def load_codi(m, latent_steps, dev):
    """Load the tokenizer, trace-token ids and a CodiModel from directory ``m``.

    Two on-disk layouts are handled:

    * ``{m}/pytorch_model.bin`` exists: it is loaded as the state dict of the
      whole ``CodiModel`` (epoch checkpoint).
    * otherwise: the backbone is loaded with ``from_pretrained(m)`` and the
      projector state dict is read from ``{m}/thought_projector.pt``
      (final export).

    Args:
        m: model path passed to the ``from_pretrained`` loaders and used as
            the directory prefix for the checkpoint files.
        latent_steps: forwarded to ``CodiModel`` as ``latent_steps``.
        dev: device the model is moved to before being returned.

    Returns:
        ``(tokenizer, ids, model)`` where ``ids`` is the result of
        ``token_ids(tokenizer)`` and ``model`` is on ``dev`` in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(m, use_fast=True)
    add_trace_tokens(tokenizer)
    special_ids = token_ids(tokenizer)

    # Build the backbone from config only (no weights yet); weights are
    # filled in by one of the two branches below.
    config = AutoConfig.from_pretrained(m)
    backbone = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
    codi = CodiModel(
        backbone,
        latent_start_id=special_ids["<|latent_start|>"],
        latent_end_id=special_ids["<|latent_end|>"],
        latent_steps=latent_steps,
    )

    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints (or consider weights_only=True once confirmed compatible).
    full_checkpoint = f"{m}/pytorch_model.bin"
    if not os.path.exists(full_checkpoint):
        # final export: backbone safetensors + separate projector
        codi.model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16)
        projector_state = torch.load(f"{m}/thought_projector.pt", map_location="cpu")
        codi.prj.load_state_dict(projector_state)
    else:
        # epoch checkpoint: full CodiModel
        full_state = torch.load(full_checkpoint, map_location="cpu")
        codi.load_state_dict(full_state)

    return tokenizer, special_ids, codi.to(dev).eval()
@torch.no_grad()
def gen_latent(model, prompt_ids, ls_id, eot, max_new):
    """Greedy-decode from ``prompt_ids``, running a latent block after each ``ls_id``.

    The prompt is run through ``model.model`` once to prime the KV cache, then
    tokens are picked by argmax one at a time. Every emitted token is fed back
    through ``model.model``; when that token equals ``ls_id`` the cache is
    additionally passed to ``model._latent_block`` and the logits it returns
    are used to choose the next token.

    Args:
        model: object exposing ``model(input_ids=..., past_key_values=...,
            use_cache=...)`` and ``_latent_block(cache)``.
        prompt_ids: 1-D tensor of prompt token ids (a batch axis is added here).
        ls_id: token id that triggers the latent block.
        eot: token id that stops generation; it is not included in the output.
        max_new: maximum number of tokens to emit.

    Returns:
        List of generated token ids (ints). Only discrete tokens are recorded;
        latent steps do not appear in the list.
    """
    device = prompt_ids.device

    # Prime the cache with the whole prompt.
    step = model.model(input_ids=prompt_ids[None], use_cache=True)
    kv_cache = step.past_key_values
    next_logits = step.logits[:, -1]

    generated = []
    for _ in range(max_new):
        token = int(next_logits.argmax(-1))
        if token == eot:
            break
        generated.append(token)

        token_input = torch.tensor([[token]], device=device)
        step = model.model(input_ids=token_input, past_key_values=kv_cache, use_cache=True)
        kv_cache = step.past_key_values

        if token != ls_id:
            next_logits = step.logits[:, -1]
            continue
        # drop $LOCALS, insert latent block; its logits predict <|action_sep|>
        kv_cache, next_logits = model._latent_block(kv_cache)
    return generated
def main():
    """Run the CRUXEval-O latent evaluation and report pass@1.

    Command-line flags:
        --model: model/checkpoint path handed to ``load_codi`` (required).
        --n_samples: if > 0, evaluate only the first ``n_samples`` rows.
        --max_new_tokens: per-example generation budget.
        --latent_steps: forwarded to ``load_codi``.
        --out: optional path for a JSON dump of metrics and per-row results.

    Distributed mode is enabled when ``RANK`` is set in the environment;
    ``WORLD_SIZE`` and ``LOCAL_RANK`` are then read as well. Rows are split
    across ranks with a stride (``rows[rank::world]``); counts are summed with
    ``all_reduce`` and per-row results are gathered onto rank 0, which prints
    the summary and writes ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--n_samples", type=int, default=-1)
    ap.add_argument("--max_new_tokens", type=int, default=8192)
    ap.add_argument("--latent_steps", type=int, default=1)
    ap.add_argument("--out", default="")
    args = ap.parse_args()

    ddp = "RANK" in os.environ
    rank = int(os.environ.get("RANK", 0))
    world = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if ddp:
        dist.init_process_group("nccl", timeout=timedelta(hours=1))  # ranks finish at different times under long gens
        torch.cuda.set_device(local_rank)
    # Fall back to CPU so a single-process run still works without CUDA.
    if torch.cuda.is_available():
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    tok, ids, model = load_codi(args.model, args.latent_steps, device)
    ls_id, eot = ids["<|line_sep|>"], ids["<|end_of_text|>"]

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]

    n_correct = n_fmt = 0
    results = []
    for i, r in enumerate(shard):
        enc = tok(_prompt_str(r["code"], r["input"]), return_tensors="pt",
                  add_special_tokens=False).to(device)
        gen = tok.decode(gen_latent(model, enc["input_ids"][0], ls_id, eot, args.max_new_tokens),
                         skip_special_tokens=False)
        pred = extract_answer_trace_full(gen)
        ok = pred is not None and check_correct(r["code"], r["output"], pred)
        n_fmt += pred is not None
        n_correct += ok
        results.append({"id": r["id"], "expected": r["output"], "predicted": pred, "correct": ok, "generation": gen})
        if rank == 0 and (i + 1) % 20 == 0:
            print(f"  rank0 {i+1}/{len(shard)}  pass@1={n_correct/(i+1):.4f}", flush=True)

    if ddp:
        t = torch.tensor([n_correct, n_fmt], device=device)
        dist.all_reduce(t)
        n_correct, n_fmt = int(t[0]), int(t[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            # Shards are strided (rank r holds rows r, r+world, ...), so
            # interleave them to restore dataset order; this makes the saved
            # results independent of the number of ranks.
            longest = max((len(part) for part in gathered), default=0)
            results = [part[j] for j in range(longest) for part in gathered if j < len(part)]

    if rank == 0:
        # Guard the ratios so an empty dataset/selection reports 0 instead of
        # raising ZeroDivisionError.
        denom = max(n, 1)
        pass_at_1 = n_correct / denom
        valid_format = n_fmt / denom
        print(f"\nCRUXEval-O latent pass@1={pass_at_1:.4f}  "
              f"valid_format={valid_format:.4f}  (n={n}, greedy)")
        if args.out:
            with open(args.out, "w", encoding="utf-8") as f:
                json.dump({"pass_at_1": pass_at_1, "valid_format": valid_format,
                           "n": n, "results": results}, f, indent=2)
    if ddp:
        dist.destroy_process_group()
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|