|
|
| """
|
Prepare GD level data for modded-nanogpt training.

Tokenizes raw level records (JSON/JSONL) with a SentencePiece model and writes
them as train/val .bin shards compatible with the modded-nanogpt data loader.

Uses multiprocessing to parallelize tokenization.
|
|
|
| Usage (local files):
|
| python prepare_gd_data.py --input data/gd_raw --output data/gd_levels --tokenizer tokenizer.model
|
|
|
| Usage (HuggingFace dataset):
|
| python prepare_gd_data.py --hf-repo tldne/gd-levels --hf-file levels_deduped.jsonl --output data/gd_levels --tokenizer tokenizer.model
|
| """
|
|
|
| import argparse
|
| import json
|
| import numpy as np
|
| from pathlib import Path
|
| import sentencepiece as spm
|
| from tqdm import tqdm
|
| from huggingface_hub import hf_hub_download
|
| from multiprocessing import Pool, cpu_count
|
| from functools import partial
|
| import os
|
|
|
|
|
# Metadata fields emitted, in this order, ahead of the level string.
# format_training_sample skips the 'level_id' entry here and instead derives
# the id from the record's 'source_file'.
INCLUDE_FIELDS = [
    'level_id', 'level_name', 'level_type',
    'binary_version', 'description_decoded', 'song_id',
]
|
|
|
|
|
def format_training_sample(record: dict, *, include_fields=None) -> str:
    """
    Format a level record as a training string.

    The result is a concatenation of ``<field>value`` segments:
    ``<level_id>`` (derived from ``source_file``), then every non-empty
    metadata field from ``include_fields`` in order, then ``<level_string>``.

    Must match train_tokenizer.py exactly!
    NOTE(review): level_id extraction previously used ``lstrip``/``rstrip``,
    which strip *character sets* (e.g. "Level_122.gmd2" -> "1"). If
    train_tokenizer.py still does that, apply the same fix there.

    Args:
        record: Parsed level record.
        include_fields: Metadata keys to emit; defaults to ``INCLUDE_FIELDS``.
            A 'level_id' entry is ignored because the id always comes from
            ``source_file``.

    Returns:
        The formatted training string (empty if nothing applies).
    """
    if include_fields is None:
        include_fields = INCLUDE_FIELDS

    parts = []

    # level_id is taken from the source file name, e.g. "Level_122.gmd2" -> "122".
    source_file = record.get('source_file', '')
    if source_file:
        level_id = source_file.removeprefix("Level_").removesuffix(".gmd2")
        parts.append(f"<level_id>{level_id}")

    for key in include_fields:
        if key == 'level_id':
            continue  # already emitted above from source_file
        value = record.get(key)
        # Keep falsy-but-meaningful values such as 0; drop only missing/empty.
        if value is not None and value != "":
            parts.append(f"<{key}>{value}")

    if record.get('level_string'):
        parts.append(f"<level_string>{record['level_string']}")

    return "".join(parts)
|
|
|
|
|
def write_bin_file(tokens: np.ndarray, output_path: Path):
    """Write tokens to a .bin file with the modded-nanogpt header format.

    File layout: a 256-entry int32 header (``[0]`` = 20240520, ``[1]`` = 1,
    ``[2]`` = token count, rest zero) followed by the tokens as uint16, both
    in native byte order.

    Args:
        tokens: 1-D sequence of token ids; any integer dtype is accepted.
        output_path: Destination file; overwritten if it already exists.

    Returns:
        The number of tokens written.

    Raises:
        ValueError: If a token id is outside the uint16 range, or the token
            count does not fit in the int32 header field. Validation runs
            before the file is opened, so no partial file is left behind.
    """
    tokens = np.asarray(tokens)
    num_tokens = len(tokens)

    if num_tokens > np.iinfo(np.int32).max:
        raise ValueError(f"{num_tokens:,} tokens do not fit in the int32 header count field")
    # astype(np.uint16) wraps out-of-range values silently, which would corrupt ids.
    if num_tokens and (tokens.min() < 0 or tokens.max() > np.iinfo(np.uint16).max):
        raise ValueError("token ids must lie in [0, 65535] to be stored as uint16")

    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240520  # magic number expected by the modded-nanogpt loader
    header[1] = 1         # format version
    header[2] = num_tokens

    with output_path.open("wb") as f:
        f.write(header.tobytes())
        f.write(tokens.astype(np.uint16).tobytes())

    return num_tokens
|
|
|
|
|
|
|
# Per-process cache of loaded tokenizers, keyed by model path. Each Pool worker
# loads the model once, on its first record.
_TOKENIZER_CACHE: dict = {}


def _get_tokenizer(tokenizer_path: str) -> "spm.SentencePieceProcessor":
    """Load (once per process) and validate the SentencePiece model at ``tokenizer_path``.

    Raises:
        ValueError: If the model cannot be stored losslessly as uint16 ids
            (vocab larger than 65536), or BOS/EOS is disabled (negative id),
            which would otherwise corrupt or drop every sample.
    """
    sp = _TOKENIZER_CACHE.get(tokenizer_path)
    if sp is None:
        sp = spm.SentencePieceProcessor()
        sp.load(tokenizer_path)
        if sp.get_piece_size() > np.iinfo(np.uint16).max + 1:
            raise ValueError(
                f"Tokenizer vocab size {sp.get_piece_size()} does not fit in uint16"
            )
        if sp.bos_id() < 0 or sp.eos_id() < 0:
            raise ValueError(
                f"Tokenizer must define BOS and EOS (got bos={sp.bos_id()}, eos={sp.eos_id()})"
            )
        _TOKENIZER_CACHE[tokenizer_path] = sp
    return sp


def tokenize_record(record_json: str, tokenizer_path: str) -> np.ndarray | None:
    """Tokenize a single JSON-encoded level record.

    The sample is ``[BOS] + encode(format_training_sample(record)) + [EOS]``.

    Returns:
        The token ids as a uint16 array, or None if the record should be
        skipped: malformed JSON, not a JSON object, no ``level_string``, a
        formatted text over 5,000,000 characters, more than 10,000,000
        tokens, or any other per-record processing error.

    Raises:
        Tokenizer load/validation errors propagate (they are configuration
        problems, not bad records) instead of silently skipping everything.
    """
    sp = _get_tokenizer(tokenizer_path)

    try:
        record = json.loads(record_json)
        if not isinstance(record, dict) or not record.get('level_string'):
            return None

        text = format_training_sample(record)
        if not text or len(text) > 5_000_000:
            return None

        tokens = [sp.bos_id()] + sp.encode(text) + [sp.eos_id()]
    except Exception:
        # Deliberate best-effort per record: skip it rather than abort the run.
        # (A bare `except:` here would also swallow KeyboardInterrupt/SystemExit.)
        return None

    if len(tokens) > 10_000_000:
        return None

    return np.array(tokens, dtype=np.uint16)
|
|
|
|
|
def _find_level_files(input_path: Path) -> list[Path]:
    """Return the JSON/JSONL files to read from ``input_path``, sorted for reproducibility."""
    if input_path.is_file():
        # --input is documented as "directory/file"; globbing a file finds nothing.
        return [input_path]

    patterns = ("*.jsonl", "*.json")
    files = [p for pattern in patterns for p in input_path.glob(pattern)]
    if not files:
        # Fall back to a recursive search when the top level has no data files.
        files = [p for pattern in patterns for p in input_path.glob(f"**/{pattern}")]

    # Glob order is filesystem-dependent; sorting keeps record order, and hence
    # the seeded train/val split, identical across machines and runs.
    return sorted(files)


def _load_records(level_files: list[Path]) -> list[str]:
    """Read all records as JSON strings: one per non-blank JSONL line, one per element of a JSON list."""
    records: list[str] = []
    for lf in tqdm(level_files, desc="Reading files"):
        try:
            with open(lf, "r", encoding="utf-8") as f:
                if lf.suffix == ".jsonl":
                    records.extend(line for line in map(str.strip, f) if line)
                else:
                    data = json.load(f)
                    items = data if isinstance(data, list) else [data]
                    records.extend(json.dumps(item) for item in items)
        except Exception as e:
            # Best-effort: report unreadable files and keep going.
            print(f"Error reading {lf}: {e}")
    return records


def _write_split_shards(tokens: np.ndarray, output_dir: Path, prefix: str, shard_size: int) -> None:
    """Write ``tokens`` as ``{prefix}_NNNN.bin`` shards split evenly by ``shard_size``."""
    # Equal-sized shards rather than fixed-size chunks: every shard gets at least
    # shard_size tokens (when available) and the remainder goes to the last one,
    # so there is never a tiny trailing shard.
    num_shards = max(1, len(tokens) // shard_size)
    per_shard = len(tokens) // num_shards

    print(f"Writing {num_shards} {prefix} shards...")
    for i in tqdm(range(num_shards), desc=f"Writing {prefix} shards"):
        start = i * per_shard
        end = start + per_shard if i < num_shards - 1 else len(tokens)
        write_bin_file(tokens[start:end], output_dir / f"{prefix}_{i:04d}.bin")


def process_levels_parallel(
    input_dir: Path,
    output_dir: Path,
    tokenizer_path: Path,
    shard_size: int = 100_000_000,
    val_ratio: float = 0.01,
    num_workers: int | None = None,
):
    """Process all levels using multiprocessing and create train/val shards.

    Args:
        input_dir: Directory of ``*.jsonl``/``*.json`` files (searched
            recursively if the top level has none), or a single such file.
        output_dir: Where ``train_NNNN.bin`` / ``val_NNNN.bin`` are written.
        tokenizer_path: SentencePiece model file.
        shard_size: Target tokens per shard (see ``_write_split_shards``).
        val_ratio: Fraction of levels used for validation; at least one level
            goes to each split.
        num_workers: Tokenization processes (default: ``cpu_count() - 1``).

    Raises:
        ValueError: On a non-positive ``shard_size``, a tokenizer that cannot
            be stored as uint16 ids, or fewer than 2 tokenized levels.
    """
    if shard_size <= 0:
        raise ValueError(f"shard_size must be positive, got {shard_size}")

    if num_workers is None:
        num_workers = max(1, cpu_count() - 1)
    print(f"Using {num_workers} workers for tokenization")

    # Load once in the parent to fail fast (before the slow read/tokenize phase).
    print(f"Loading tokenizer from {tokenizer_path}")
    sp = spm.SentencePieceProcessor()
    sp.load(str(tokenizer_path))
    print(f"BOS ID: {sp.bos_id()}, EOS ID: {sp.eos_id()}")
    if sp.get_piece_size() > np.iinfo(np.uint16).max + 1:
        raise ValueError(f"Tokenizer vocab size {sp.get_piece_size()} does not fit in uint16")
    if sp.bos_id() < 0 or sp.eos_id() < 0:
        raise ValueError("Tokenizer must define BOS and EOS ids")

    level_files = _find_level_files(Path(input_dir))
    print(f"Found {len(level_files)} input files")

    print("Loading records...")
    all_records = _load_records(level_files)
    print(f"Loaded {len(all_records):,} records")

    print(f"Tokenizing with {num_workers} workers...")
    tokenize_fn = partial(tokenize_record, tokenizer_path=str(tokenizer_path))

    all_levels = []
    total_tokens = 0
    with Pool(num_workers) as pool:
        # Consume results as they stream in instead of building a second list.
        for tokens in tqdm(
            pool.imap(tokenize_fn, all_records, chunksize=100),
            total=len(all_records),
            desc="Tokenizing",
        ):
            if tokens is not None:
                all_levels.append(tokens)
                total_tokens += len(tokens)
    del all_records  # raw JSON no longer needed; lower peak memory

    print(f"Tokenized {len(all_levels):,} levels with {total_tokens:,} total tokens")

    if len(all_levels) < 2:
        raise ValueError(
            f"Need at least 2 tokenized levels to build train and val splits, got {len(all_levels)}"
        )

    # Same permutation as np.random.seed(42) + np.random.permutation(...),
    # without mutating NumPy's global RNG state.
    indices = np.random.RandomState(42).permutation(len(all_levels))

    # At least one val level, and always at least one train level.
    val_count = min(max(1, int(len(all_levels) * val_ratio)), len(all_levels) - 1)
    val_indices = indices[:val_count]
    train_indices = indices[val_count:]

    print(f"Train: {len(train_indices):,} levels, Val: {len(val_indices):,} levels")

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_tokens = np.concatenate([all_levels[i] for i in tqdm(train_indices, desc="Concatenating train")])
    _write_split_shards(train_tokens, output_dir, "train", shard_size)

    val_tokens = np.concatenate([all_levels[i] for i in tqdm(val_indices, desc="Concatenating val")])
    _write_split_shards(val_tokens, output_dir, "val", shard_size)

    print("\n=== Summary ===")
    print(f"Train tokens: {len(train_tokens):,}")
    print(f"Val tokens: {len(val_tokens):,}")
    print(f"Total tokens: {len(train_tokens) + len(val_tokens):,}")
    print(f"Output dir: {output_dir}")
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, locate the input, build the shards.

    Input comes either from a local path (``--input``) or from a single file
    downloaded from a HuggingFace dataset repo (``--hf-repo`` / ``--hf-file``).
    """
    import shutil
    import tempfile

    parser = argparse.ArgumentParser(description="Prepare GD levels for modded-nanogpt")

    # Exactly one input source is required.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input", "-i", type=Path, help="Input directory/file with level JSONs/JSONL")
    input_group.add_argument("--hf-repo", type=str, help="HuggingFace repo ID (e.g., tldne/gd-levels)")

    parser.add_argument("--hf-file", type=str, default="levels_deduped.jsonl", help="File to download from HF repo")
    parser.add_argument("--output", "-o", type=Path, required=True, help="Output directory for .bin files")
    parser.add_argument("--tokenizer", "-t", type=Path, required=True, help="Path to tokenizer.model")
    parser.add_argument("--shard-size", type=int, default=100_000_000, help="Tokens per shard")
    parser.add_argument("--val-ratio", type=float, default=0.01, help="Validation split ratio")
    parser.add_argument("--workers", "-w", type=int, default=None, help="Number of workers (default: cpu_count - 1)")

    args = parser.parse_args()

    def run(input_dir: Path) -> None:
        process_levels_parallel(
            input_dir=input_dir,
            output_dir=args.output,
            tokenizer_path=args.tokenizer,
            shard_size=args.shard_size,
            val_ratio=args.val_ratio,
            num_workers=args.workers,
        )

    if not args.hf_repo:
        run(args.input)
        return

    print(f"Downloading {args.hf_file} from {args.hf_repo}...")
    downloaded_path = hf_hub_download(
        repo_id=args.hf_repo,
        filename=args.hf_file,
        repo_type="dataset",
    )
    print(f"Downloaded to: {downloaded_path}")

    # Stage the downloaded file alone in a scratch directory so only it is
    # scanned. TemporaryDirectory removes it even if processing raises; the
    # previous manual rmtree was skipped on any error.
    with tempfile.TemporaryDirectory() as tmp:
        # --hf-file may point into a repo subfolder ("data/x.jsonl"); keep only
        # the basename so the staged file's parent directory exists.
        staged = Path(tmp) / Path(args.hf_file).name
        try:
            # Avoid duplicating a potentially multi-GB file when symlinks work.
            staged.symlink_to(downloaded_path)
        except OSError:
            shutil.copy(downloaded_path, staged)
        run(Path(tmp))
|
|
|
|
|
# Run the CLI only when executed as a script, so importing this module
# (e.g. from worker processes) does not re-run main().
if __name__ == "__main__":
    main()
|
|
|