codi-trace / code /data /dataset.py
sirui6011's picture
add code/ loader snapshot
aedd6ab verified
Raw
History Blame Contribute Delete
7.27 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
"""
CRUXEval-O dataset: deterministic train/val split + ground-truth execution
traces -> (input_ids, labels) for teacher-forcing CODI.
Neutral data layer shared by training (``cwm.training.data``) and eval
(``evals.cruxeval.run_eval_codi``); depends on nothing in either, so the
split and trace format never drift. Thin HuggingFace-tokenizer wrapper over
the verbatim Table 9 trace generator (``.ground_truth`` / ``.trace_format``):
build the seeded prompt, tokenize ``prompt + render_frames_to_generation(frames)``,
and mask the prompt out of the labels (teacher-forced, so labels == input_ids
with the prompt prefix set to ``-100``).
"""
from __future__ import annotations
from .ground_truth import ground_truth_trace, make_trace_context
from .trace_format import (
ACTION_SEP,
LINE_SEP,
TraceEvent,
render_frames_to_generation,
)
IGNORE_INDEX = -100
def _prompt_str(code: str, input_str: str) -> str:
    """Return the prompt text that is tokenized ahead of the rendered trace.

    The trace context produced by ``make_trace_context(code, input_str)`` is
    placed after ``<|trace_context_start|>`` and closed with ``<|frame_sep|>``;
    a fixed ``<|call_sep|>{}<|action_sep|>def main():`` frame is appended.
    """
    context = make_trace_context(code, input_str)
    head = f"<|trace_context_start|>{context}<|frame_sep|>"
    # Constant seed frame; plain (non-f) string, so the braces are literal.
    seed_frame = "<|call_sep|>{}<|action_sep|>def main():\n<|frame_sep|>"
    return head + seed_frame
def _tokenize_trace(code, input_str, tokenizer, *, max_seq_len, max_frames):
    """Tokenize one ``(code, input_str)`` pair into ``(prompt_ids, trace_ids, spans)``.

    Returns ``None`` (skip the example) when any of the following holds:

    * the ground-truth trace is empty or reported ``"frames_exceeded"``;
    * the last frame is neither a RETURN nor an EXCEPTION event;
    * prompt plus trace is longer than ``max_seq_len`` tokens;
    * no ``<|line_sep|>`` ... ``<|action_sep|>`` span is found in the trace.

    Each span ``(i, j)`` indexes into ``trace_ids``: ``trace_ids[i]`` is the
    ``<|line_sep|>`` token and ``trace_ids[j]`` the next ``<|action_sep|>``, so
    ``trace_ids[i+1:j]`` holds the locals a CODI student swaps for a latent
    block. Both the SFT baseline and CODI builders go through this function,
    which keeps their example membership identical.
    """
    frames, error = ground_truth_trace(
        code, input_str, align_to_prompt=True, max_frames=max_frames
    )
    if not frames or error == "frames_exceeded":
        return None
    terminal_events = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    if frames[-1].event not in terminal_events:
        return None

    # Qwen has no BOS (bos_token_id is None); CWM did. Prepend only if present.
    bos_id = tokenizer.bos_token_id
    prefix = [] if bos_id is None else [bos_id]
    prompt_ids = prefix + tokenizer.encode(
        _prompt_str(code, input_str), add_special_tokens=False
    )
    trace_ids = tokenizer.encode(
        render_frames_to_generation(frames), add_special_tokens=False
    )
    if len(prompt_ids) + len(trace_ids) > max_seq_len:
        return None

    line_sep_id = tokenizer.convert_tokens_to_ids(LINE_SEP)
    action_sep_id = tokenizer.convert_tokens_to_ids(ACTION_SEP)

    spans = []
    cursor = 0
    while True:
        try:
            start = trace_ids.index(line_sep_id, cursor)
            # The closing separator is searched strictly after the opener.
            end = trace_ids.index(action_sep_id, start + 1)
        except ValueError:
            # No further opener, or an opener with no closing separator:
            # either way the scan is finished.
            break
        spans.append((start, end))
        cursor = end + 1

    if not spans:
        return None
    return prompt_ids, trace_ids, spans
def build_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one SFT ``(input_ids, labels)`` pair, or ``None`` to skip.

    ``input_ids`` is the prompt followed by the trace. ``labels`` has the same
    length, with every prompt position set to ``IGNORE_INDEX`` and the trace
    positions equal to the trace token ids.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, _spans = tokenized
    input_ids = prompt_ids + trace_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + trace_ids
    return input_ids, labels
def build_codi_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one multi-span CODI example, or ``None`` to skip.

    The result is a dict with keys ``prompt_ids``, ``trace_ids`` and ``spans``,
    holding the three values returned by ``_tokenize_trace``.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is not None:
        prompt_ids, trace_ids, spans = tokenized
        return {"prompt_ids": prompt_ids, "trace_ids": trace_ids, "spans": spans}
    return None
def _load_cache(cache_dir, n_samples):
    """Load precomputed tokenized examples (precompute.py) from ``cache_dir``.

    The on-disk dataset is materialized into a list. When ``n_samples`` is
    positive only the first ``n_samples`` entries are returned; otherwise the
    whole list is returned.
    """
    from datasets import load_from_disk

    examples = list(load_from_disk(cache_dir))
    if n_samples > 0:
        return examples[:n_samples]
    return examples
def build_codi_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Multi-span CODI examples over ``sources``, or from a precomputed cache.

    Each returned item is a dict with keys ``prompt_ids``, ``trace_ids`` and
    ``spans`` (see ``build_codi_example``).

    Args:
        tokenizer: tokenizer forwarded to ``build_codi_example``.
        sources: names passed to ``rows_for_sources``; ignored when
            ``cache_dir`` is given.
        n_samples: if positive, only the first ``n_samples`` rows (or cached
            examples) are considered, before any filtering.
        max_seq_len: examples whose prompt plus trace exceed this many tokens
            are dropped.
        max_frames: forwarded to ``build_codi_example``; unused on the cache path.
        cache_dir: when truthy, load examples from this directory instead of
            building them.

    Returns:
        The examples that were built successfully, in row order.

    Building is best-effort: a row whose construction raises is skipped rather
    than aborting the whole dataset. Each such failure is logged at DEBUG level
    (with traceback) and a single WARNING reports the total, so dropped rows
    are no longer invisible.
    """
    # Function-scope import, matching the other lazy imports in this module.
    import logging

    logger = logging.getLogger(__name__)

    if cache_dir:
        ex = _load_cache(cache_dir, n_samples)
        return [e for e in ex if len(e["prompt_ids"]) + len(e["trace_ids"]) <= max_seq_len]

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    out = []
    n_failed = 0
    for r in rows:
        try:
            example = build_codi_example(
                r["code"], r["input"], tokenizer,
                max_seq_len=max_seq_len, max_frames=max_frames,
            )
        except Exception:
            # Deliberate best-effort skip, but leave a trace of what was dropped.
            n_failed += 1
            logger.debug(
                "build_codi_dataset: skipping row id=%r (example construction raised)",
                r.get("id") if isinstance(r, dict) else None,
                exc_info=True,
            )
            continue
        # None means the example was filtered out (not an error).
        if example is not None:
            out.append(example)

    if n_failed:
        logger.warning(
            "build_codi_dataset: skipped %d of %d rows because example construction raised",
            n_failed, len(rows),
        )
    return out
def build_codi_single_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Single-block CODI examples derived from the multi-span ones.

    Every example from ``build_codi_dataset`` is cut at the last
    ``<|return_sep|>`` token of its trace into
    ``{prompt_ids, reasoning_ids, answer_ids}``: ``reasoning_ids`` is the trace
    before that token and ``answer_ids`` runs from that token to the end.
    Traces with no ``<|return_sep|>``, or whose only/last one sits at index 0
    (which would leave the reasoning empty), are dropped. All arguments are
    forwarded to ``build_codi_dataset``; no separate cache is needed.
    """
    return_sep_id = tokenizer.convert_tokens_to_ids("<|return_sep|>")
    singles = []
    for example in build_codi_dataset(
        tokenizer, sources=sources, n_samples=n_samples,
        max_seq_len=max_seq_len, max_frames=max_frames, cache_dir=cache_dir,
    ):
        trace = example["trace_ids"]
        # Track the position of the final return separator; -1 means "absent".
        cut = -1
        for pos, token in enumerate(trace):
            if token == return_sep_id:
                cut = pos
        if cut <= 0:
            continue
        singles.append({
            "prompt_ids": example["prompt_ids"],
            "reasoning_ids": trace[:cut],
            "answer_ids": trace[cut:],
        })
    return singles
def rows_for_sources(sources):
    """Collect validated ``{id, code, input, output}`` rows across ``sources``.

    All rows of every source are merged in order (train vs test is split by
    dataset, e.g. cruxeval is held out for eval). Each row is copied into a
    fresh dict and its ``id`` is converted to ``str``.

    Raises:
        ValueError: a row lacks one of ``id``, ``code``, ``input``, ``output``.
        TypeError: ``code``, ``input`` or ``output`` of a row is not a ``str``.
    """
    from . import sources as _src

    required_keys = ("id", "code", "input", "output")
    text_keys = ("code", "input", "output")

    merged = []
    for name in sources:
        for i, row in enumerate(_src.load_one(name)):
            missing = [key for key in required_keys if key not in row]
            if missing:
                raise ValueError(f"{name} row {i} missing keys: {missing}")
            if any(not isinstance(row[key], str) for key in text_keys):
                raise TypeError(f"{name} row {i} must use string code/input/output")
            # Copy so the loader's row object is never mutated.
            record = dict(row)
            record["id"] = str(record["id"])
            merged.append(record)
    return merged
def build_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 8192, max_frames: int = -1, cache_dir: str | None = None
) -> list[tuple[list[int], list[int]]]:
    """SFT ``(input_ids, labels)`` pairs over ``sources``, or from a cache.

    With ``cache_dir`` set, cached examples are loaded (limited to the first
    ``n_samples`` when positive) and those longer than ``max_seq_len`` tokens
    are dropped. Otherwise rows from ``rows_for_sources(sources)`` are limited
    the same way and passed through ``build_example``; rows for which it
    returns ``None`` are omitted.
    """
    if cache_dir:
        pairs = []
        for cached in _load_cache(cache_dir, n_samples):
            if len(cached["input_ids"]) <= max_seq_len:
                pairs.append((cached["input_ids"], cached["labels"]))
        return pairs

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    built = []
    for row in rows:
        example = build_example(
            row["code"], row["input"], tokenizer,
            max_seq_len=max_seq_len, max_frames=max_frames,
        )
        if example is not None:
            built.append(example)
    return built