"""Fix branches with wrong training_step + training_val_bpb in config.json.

For each branch listed in ``FIX``, download that revision's ``config.json``
from the Hub repo ``repo`` into a temporary directory, overwrite
``training_step`` and ``training_val_bpb`` with the correct values, and upload
the edited file back to the same branch. Branches whose config already holds
the target values are skipped, so the script can be re-run safely.
"""
import json
import os
import tempfile

from huggingface_hub import HfApi


# Branch name -> (correct training_step, correct training_val_bpb).
FIX = {
    "step-2000": (2000, 0.987364),
    "step-4000": (4000, 0.954657),
    "step-6000": (6000, 0.949316),
    "step-8000": (8000, 0.943905),
    "step-10000": (10000, 0.935804),
    "step-12000": (12000, 0.931519),
}
api = HfApi()
repo = "cognica/Cognica-PoE-v1.0-3B-base"


for branch, (step, val_bpb) in FIX.items():
    print(f"\n=== {branch}: step={step}, val_bpb={val_bpb} ===")
    with tempfile.TemporaryDirectory() as tmp:
        cfg_path = api.hf_hub_download(
            repo_id=repo, filename="config.json", revision=branch, local_dir=tmp,
        )
        # JSON is UTF-8 by spec; don't let the platform locale choose the codec.
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # Already-fixed branches need no new commit; this keeps re-runs cheap
        # and avoids pushing no-op uploads.
        if cfg.get("training_step") == step and cfg.get("training_val_bpb") == val_bpb:
            print(" already correct, skipped")
            continue

        cfg["training_step"] = step
        cfg["training_val_bpb"] = val_bpb
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)
        api.upload_file(
            path_or_fileobj=cfg_path, path_in_repo="config.json",
            repo_id=repo, revision=branch,
            commit_message=f"Fix training_step/val_bpb for {branch}",
        )
        print(" pushed")
|
|