| import torch |
| import torch.optim as optim |
| from transformers import AutoTokenizer |
| from tqdm import tqdm |
|
|
| from src.config import ModelConfig, TrainConfig |
| from src.models.autoencoder import SphericalAutoencoder |
| from src.models.dit import PatchedFlowDiT |
| from src.trainer import Trainer |
| from src.utils.data_utils import prepare_data |
| from src.utils.sandbox import SafeSandbox |
| from src.search import DiffuMCTS |
|
|
| def inference(ae, flow, src_ids, src_mask, device, steps=10): |
| ae.eval(); flow.eval() |
| with torch.no_grad(): |
| |
| z_curr = ae.encode(src_ids, src_mask) |
| z_cond = z_curr.clone() |
| |
| dt = 1.0 / steps |
| for i in range(steps): |
| t = torch.ones(z_curr.shape[0], device=device) * (i / steps) |
| v = flow(z_curr, t, condition=z_cond).float() |
| z_curr = z_curr + v * dt |
| |
| z_curr = torch.nn.functional.normalize(z_curr, p=2, dim=-1) |
| logits = ae.decode(z_curr) |
| return torch.argmax(logits, dim=-1) |
|
|
def evaluate_on_humaneval(ae, flow, tokenizer, device, num_samples=20):
    """Evaluate pass@1 on HumanEvalPack with real sandboxed execution.

    For each of at most `num_samples` test batches: generate a candidate
    program with `inference`, decode it, run it against the task's test
    code in the sandbox, and accumulate the pass rate.
    """
    print("\n>>> Starting Evaluation on HumanEvalPack (Real Execution)...")
    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox()

    passed = 0
    total = 0

    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        src = batch['src_ids'].to(device)
        mask = batch['src_mask'].to(device)
        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        out_ids = inference(ae, flow, src, mask, device)
        gen_code = tokenizer.decode(out_ids[0], skip_special_tokens=True)

        is_pass, msg = sandbox.run(gen_code, test_code, entry_point)

        if is_pass:
            passed += 1
        total += 1

        # Dump the first case for quick qualitative inspection.
        if i == 0:
            print(f"\n[Case 0] Pass: {is_pass}")
            print(f"Error: {msg}")
            print(f"Generated:\n{gen_code[:200]}...")

    print(f"\n=== Eval Result ===")
    if total == 0:
        # Guard against ZeroDivisionError when the loader yields no batches
        # (empty split or num_samples <= 0).
        print("Pass@1: no samples evaluated")
    else:
        print(f"Pass@1: {passed}/{total} = {passed/total*100:.2f}%")
|
|
def evaluate_with_mcts(ae, flow, tokenizer, device, num_samples=20):
    """Evaluate pass@1 on HumanEvalPack using Diffu-MCTS search.

    Instead of a single forward generation, each buggy sample is repaired
    via MCTS over the flow model's latent space, with sandboxed test
    execution providing the reward signal.
    """
    print(f"\n>>> Starting Diffu-MCTS Evaluation (samples={num_samples})...")

    loader = prepare_data("humanevalpack", tokenizer, 512, 1, split="test")
    sandbox = SafeSandbox(timeout=2.0)

    mcts = DiffuMCTS(ae, flow, tokenizer, sandbox, device, config=None)
    mcts.num_branches = 8  # search width K per expansion

    passed = 0
    total = 0

    for i, batch in enumerate(tqdm(loader, total=num_samples)):
        if i >= num_samples:
            break

        src_ids = batch['src_ids'].to(device)
        buggy_code = tokenizer.decode(src_ids[0], skip_special_tokens=True)

        test_code = batch['test_code'][0]
        entry_point = batch['entry_point'][0]

        fixed_code, is_success = mcts.solve(buggy_code, test_code, entry_point)

        if is_success:
            passed += 1
        total += 1

        # Dump the first case for quick qualitative inspection.
        if i == 0:
            print(f"\n[Case 0]")
            print(f"Buggy:\n{buggy_code[:100]}...")
            print(f"Fixed:\n{fixed_code[:100]}...")
            print(f"Result: {'✅ PASS' if is_success else '❌ FAIL'}")

    print(f"\n=== MCTS Results ===")
    if total == 0:
        # Guard against ZeroDivisionError when the loader yields no batches
        # (empty split or num_samples <= 0).
        print("Pass@1: no samples evaluated")
    else:
        print(f"Pass@1 (with Search K={mcts.num_branches}): {passed}/{total} = {passed/total*100:.2f}%")
|
|
def main():
    """End-to-end pipeline: train AE and flow model on CodeXGLUE, then
    evaluate on HumanEvalPack, first with plain inference and then with
    MCTS-augmented search."""
    model_cfg = ModelConfig()
    train_cfg = TrainConfig(batch_size=8, grad_accum_steps=4)

    tokenizer = AutoTokenizer.from_pretrained(model_cfg.encoder_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_loader = prepare_data("codexglue", tokenizer, model_cfg.max_seq_len, train_cfg.batch_size, split="train")

    ae = SphericalAutoencoder(model_cfg).to(train_cfg.device).float()
    # Ensure the encoder has a pad token id, falling back to the tokenizer's.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id

    flow = PatchedFlowDiT(model_cfg).to(train_cfg.device).float()
    trainer = Trainer(ae, flow, train_cfg, train_loader)

    # Stage 1: train the autoencoder (only parameters left trainable).
    print("\n>>> Training AE on CodeXGLUE...")
    trainable_params = (p for p in ae.parameters() if p.requires_grad)
    ae_optimizer = optim.AdamW(trainable_params, lr=train_cfg.lr_ae)
    for epoch in range(train_cfg.num_epochs_ae):
        epoch_loss = trainer.train_ae(ae_optimizer)
        print(f"AE Epoch {epoch}: Loss {epoch_loss:.4f}")

    # Stage 2: train the flow-matching model.
    print("\n>>> Training Flow Matching on CodeXGLUE...")
    flow_optimizer = optim.AdamW(flow.parameters(), lr=train_cfg.lr_flow)
    for epoch in range(train_cfg.num_epochs_flow):
        epoch_loss = trainer.train_flow(flow_optimizer)
        print(f"Flow Epoch {epoch}: Loss {epoch_loss:.4f}")

    # Stage 3: evaluation.
    evaluate_on_humaneval(ae, flow, tokenizer, train_cfg.device)
    evaluate_with_mcts(ae, flow, tokenizer, train_cfg.device, num_samples=50)
|
|
| if __name__ == "__main__": |
| main() |