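# commitguard-env / scripts/train_heavy.py
# Hugging Face file-viewer metadata captured with this file (kept as comments so the module parses):
# Nitishkumar-ai's picture
# Deployment Build (Final): Professional Structure + Blog
# 95cbc5b
# Raw
# History Blame Contribute Delete
# 2.43 kB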
import argparse
from unsloth import FastLanguageModel, PatchFastRL
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
import requests
import os
import sys
from pathlib import Path
# Patch TRL for Unsloth speedups.
# This runs at module import time, i.e. before main() constructs the
# GRPOTrainer, and is applied for the "GRPO" algorithm specifically.
PatchFastRL("GRPO", FastLanguageModel)
def _placeholder_reward(prompts, completions, **kwargs):
    """Return a constant reward of 0.5 for every completion.

    Place-holder for speed. A named function is used instead of a lambda so
    the reward shows up under a readable name rather than ``<lambda>``.

    NOTE(review): GRPO derives advantages from reward differences within a
    group of generations; a constant reward presumably yields no learning
    signal. Replace with the real CommitGuard reward before relying on the
    resulting weights.
    """
    return [0.5] * len(completions)


def _fetch_training_samples(env_url, num_samples, timeout=30.0):
    """Collect prompts by POSTing to the environment's ``/reset`` endpoint.

    Args:
        env_url: Base URL of the environment server (no trailing ``/reset``).
        num_samples: Number of ``/reset`` requests to issue.
        timeout: Per-request timeout in seconds, so a stalled server cannot
            hang the run indefinitely.

    Returns:
        A list of ``{"prompt": ..., "system": ...}`` dicts, one per request
        that answered with HTTP 200.

    Raises:
        RuntimeError: If no request produced a usable sample.
        requests.RequestException: On connection errors or timeouts.
    """
    # Project-local import kept function-scoped (as in the original), but
    # executed once instead of on every loop iteration.
    from commitguard_env.grpo_prompt import get_agent_prompt, SYSTEM_PROMPT

    reset_url = f"{env_url.rstrip('/')}/reset"
    samples = []
    failed = 0
    for _ in range(num_samples):
        response = requests.post(reset_url, timeout=timeout)
        if response.status_code != 200:
            # Best-effort: skip this sample, but keep count so the loss of
            # data is reported instead of silently ignored.
            failed += 1
            continue
        obs = response.json()["observation"]
        prompt = get_agent_prompt(obs["diff"], obs["available_files"], 0)
        samples.append({"prompt": prompt, "system": SYSTEM_PROMPT})

    if failed:
        print(
            f"Warning: {failed} of {num_samples} reset requests to "
            f"{reset_url} did not return HTTP 200 and were skipped.",
            file=sys.stderr,
        )
    if not samples:
        raise RuntimeError(
            f"No training samples could be fetched from {reset_url}; "
            "is the environment server running?"
        )
    return samples


def main():
    """Run the "heavy" GRPO fine-tuning job and save the LoRA adapter.

    Command-line options:
        --steps        Maximum number of training steps (default 500).
        --num-samples  Number of prompts to fetch from the env (default 500).
        --env-url      Base URL of the environment server
                       (default http://localhost:8000).

    Side effects: loads a 4-bit model through Unsloth, issues HTTP requests
    to the environment server, writes checkpoints under
    ``outputs/commitguard-heavy`` and the final adapter under
    ``outputs/commitguard-heavy/final``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument(
        "--num-samples",
        type=int,
        default=500,
        help="number of prompts to fetch from the environment server",
    )
    parser.add_argument(
        "--env-url",
        default="http://localhost:8000",
        help="base URL of the environment server exposing /reset",
    )
    args = parser.parse_args()

    # Optimized for L4 Speed
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit",
        max_seq_length=1024,
        load_in_4bit=True,
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=16,  # Increased rank for "harder" learning
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=32,
        use_gradient_checkpointing="unsloth",
    )

    # The attribute is created here when the wrapped model does not already
    # expose it. Presumably the trainer expects it to exist - TODO confirm.
    if not hasattr(model, "warnings_issued"):
        model.warnings_issued = {}

    print(f"Fetching {args.num_samples} hard samples...")
    train_samples = _fetch_training_samples(args.env_url, args.num_samples)
    dataset = Dataset.from_list(train_samples)

    # HEAVY TRAINING CONFIG
    training_args = GRPOConfig(
        output_dir="outputs/commitguard-heavy",
        num_generations=4,
        max_completion_length=256,
        per_device_train_batch_size=2,
        # NOTE(review): 2 x 8 = 16 per device; the earlier comment claimed an
        # effective batch size of 64 - confirm how num_generations factors in
        # for the installed TRL version.
        gradient_accumulation_steps=8,
        learning_rate=2e-5,  # Higher LR for fast weight shift
        max_steps=args.steps,
        bf16=True,
        logging_steps=1,
        save_steps=100,
        report_to="none",
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[_placeholder_reward],  # Place-holder for speed
        args=training_args,
        train_dataset=dataset,
    )

    print(f"Starting HEAVY training for {args.steps} steps...")
    trainer.train()
    model.save_pretrained_merged(
        "outputs/commitguard-heavy/final", tokenizer, save_method="lora"
    )
# Run the training job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()