banhmi-gemma4-e4b / scripts /training_logger.py
bradduy's picture
Add Unsloth training pipeline (train, evaluate, export, prepare_data, training_logger)
4942b80 verified
#!/usr/bin/env python3
"""
Training logger — structured logging for fine-tuning runs.
Provides a TrainerCallback that streams training metrics (loss, LR, memory,
throughput) to structured files during training. Works with SFTTrainer.
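Files produced in {output_dir}/ (the directory passed to TrainingLogger):
- training_log.csv — every logging step (loss, lr, grad norm, GPU memory, throughput;
  the tokens_per_sec column is computed from sample counts, i.e. samples/s)
- training_summary.json — final summary with config + results
- loss_curve.csv — simplified loss-only for plotting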
Usage as callback:
from training_logger import TrainingLogger
logger = TrainingLogger(output_dir="outputs/logs", experiment_name="exp09")
trainer = SFTTrainer(..., callbacks=[logger])
trainer.train()
logger.save_summary(trainer_stats, args_dict)
Usage standalone (parse existing trainer logs):
python scripts/training_logger.py --log-dir outputs/ --output outputs/logs/
"""
import argparse
import csv
import json
import os
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from typing import Optional
import torch
from transformers import TrainerCallback, TrainerState, TrainerControl, TrainingArguments
@dataclass
class StepMetrics:
    """One row of per-step training metrics.

    The first four fields are required; the remaining ones default to
    zero / empty so a row can be built from partial information.
    """

    # --- position in the run ---
    step: int
    epoch: float

    # --- optimisation signals ---
    loss: float
    learning_rate: float
    grad_norm: float = 0.0

    # --- resource usage ---
    gpu_memory_mb: float = 0.0
    gpu_memory_pct: float = 0.0

    # --- speed and timing ---
    tokens_per_sec: float = 0.0
    elapsed_sec: float = 0.0
    timestamp: str = ""
@dataclass
class TrainingSummary:
    """End-of-run summary record: run configuration plus aggregate results.

    Only ``experiment_name`` is required; every other field has a neutral
    default so the record can be filled in incrementally during a run.
    """

    # --- identification ---
    experiment_name: str
    model_name: str = ""
    dataset_name: str = ""
    dataset_size: int = 0

    # --- hyper-parameters ---
    lora_rank: int = 0
    learning_rate: float = 0.0
    num_epochs: float = 0.0
    total_steps: int = 0
    batch_size: int = 0
    grad_accum: int = 0
    weight_decay: float = 0.0
    warmup_steps: int = 0

    # --- results ---
    final_loss: float = 0.0
    # Starts at +inf so that the first observed loss always becomes the minimum.
    min_loss: float = float("inf")
    min_loss_step: int = 0
    total_runtime_sec: float = 0.0
    peak_gpu_memory_mb: float = 0.0
    avg_tokens_per_sec: float = 0.0
    total_tokens_trained: int = 0

    # --- bookkeeping ---
    timestamp: str = ""
    status: str = "running"  # running, completed, crashed

    # Loss at epoch boundaries; a fresh dict is created per instance.
    epoch_losses: dict = field(default_factory=dict)
class TrainingLogger(TrainerCallback):
    """
    HuggingFace TrainerCallback that logs structured training metrics.

    Attach to SFTTrainer via callbacks=[TrainingLogger(...)].

    While training runs, every log entry that carries a "loss" key is turned
    into a StepMetrics row, kept in memory and streamed to
    ``{output_dir}/training_log.csv``. After training, call
    ``save_summary()`` to write ``training_summary.json`` and
    ``loss_curve.csv`` into the same directory.

    Note: ``status`` is set to "completed" in ``on_train_end``; nothing in
    this class sets it to "crashed", so an interrupted run stays "running".
    """

    # Column order of training_log.csv; mirrors the StepMetrics fields.
    _CSV_FIELDS = [
        "step", "epoch", "loss", "learning_rate", "grad_norm",
        "gpu_memory_mb", "gpu_memory_pct", "tokens_per_sec",
        "elapsed_sec", "timestamp",
    ]

    def __init__(self, output_dir: str = "outputs/logs",
                 experiment_name: str = "experiment"):
        """
        Args:
            output_dir: Directory that receives all log files (created if missing).
            experiment_name: Label stored in the summary and printed in the report.
        """
        self.output_dir = output_dir
        self.experiment_name = experiment_name
        self.start_time = None
        self.step_metrics: list[StepMetrics] = []
        self.summary = TrainingSummary(experiment_name=experiment_name)
        self._csv_writer = None
        self._csv_file = None
        self._last_epoch = 0.0
        os.makedirs(output_dir, exist_ok=True)

    def _close_csv(self):
        """Close the streaming CSV (if open) and drop the writer bound to it."""
        if self._csv_file:
            self._csv_file.close()
        self._csv_file = None
        # The writer wraps the file handle; clearing it prevents writes to a
        # closed file if a log event arrives after training has ended.
        self._csv_writer = None

    @staticmethod
    def _gpu_memory_stats():
        """Return ``(peak_allocated_mb, percent_of_device_total)``.

        Both values are 0.0 when CUDA is unavailable. The peak and the total
        are read for the same (current) device so the percentage is coherent.
        """
        if not torch.cuda.is_available():
            return 0.0, 0.0
        device = torch.cuda.current_device()
        used_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
        # The device-properties attribute is `total_memory` (bytes).
        total_mb = torch.cuda.get_device_properties(device).total_memory / (1024 ** 2)
        pct = (used_mb / total_mb * 100) if total_mb > 0 else 0
        return used_mb, pct

    def on_train_begin(self, args: TrainingArguments, state: TrainerState,
                       control: TrainerControl, **kwargs):
        """Start the wall clock, copy hyper-parameters into the summary and open the CSV."""
        self.start_time = time.time()
        self.summary.timestamp = datetime.now().isoformat()
        self.summary.total_steps = state.max_steps
        self.summary.batch_size = args.per_device_train_batch_size
        self.summary.grad_accum = args.gradient_accumulation_steps
        self.summary.learning_rate = args.learning_rate
        self.summary.weight_decay = args.weight_decay
        self.summary.warmup_steps = args.warmup_steps
        self.summary.num_epochs = args.num_train_epochs

        # If training is started twice on the same logger, release the
        # previous handle before truncating and reopening the file.
        self._close_csv()

        # Open CSV for streaming writes
        csv_path = os.path.join(self.output_dir, "training_log.csv")
        self._csv_file = open(csv_path, "w", newline="", encoding="utf-8")
        self._csv_writer = csv.DictWriter(self._csv_file, fieldnames=self._CSV_FIELDS)
        self._csv_writer.writeheader()
        print(f"[TrainingLogger] Logging to {self.output_dir}/")

    def on_log(self, args: TrainingArguments, state: TrainerState,
               control: TrainerControl, logs: Optional[dict] = None, **kwargs):
        """Record one StepMetrics row for a log entry that contains "loss".

        Entries without a "loss" key (or a missing ``logs`` dict) are ignored.
        """
        if logs is None or "loss" not in logs:
            return

        elapsed = time.time() - self.start_time if self.start_time else 0
        epoch = logs.get("epoch", state.epoch or 0)

        # GPU memory
        gpu_mem_mb, gpu_mem_pct = self._gpu_memory_stats()

        # NOTE: despite the field name, no token count is used here. The value
        # is optimizer steps x per-device batch x grad accumulation divided by
        # elapsed seconds, i.e. a cumulative samples-per-second figure.
        throughput = (
            state.global_step * args.per_device_train_batch_size
            * args.gradient_accumulation_steps / max(elapsed, 1)
        )

        metrics = StepMetrics(
            step=state.global_step,
            epoch=round(epoch, 4),
            loss=round(logs.get("loss", 0), 6),
            learning_rate=logs.get("learning_rate", 0),
            grad_norm=logs.get("grad_norm", 0),
            gpu_memory_mb=round(gpu_mem_mb, 1),
            gpu_memory_pct=round(gpu_mem_pct, 1),
            tokens_per_sec=round(throughput, 1),
            elapsed_sec=round(elapsed, 1),
            timestamp=datetime.now().isoformat(),
        )
        self.step_metrics.append(metrics)

        # Stream to CSV; flush each row so the file is usable mid-run.
        if self._csv_writer:
            self._csv_writer.writerow(asdict(metrics))
            self._csv_file.flush()

        # Track min loss
        if metrics.loss < self.summary.min_loss:
            self.summary.min_loss = metrics.loss
            self.summary.min_loss_step = metrics.step

        # Track epoch boundary losses: when the integer part of the epoch
        # advances, store this loss under the 1-based number of the epoch
        # that followed the previously seen one.
        current_epoch_int = int(epoch)
        last_epoch_int = int(self._last_epoch)
        if current_epoch_int > last_epoch_int and current_epoch_int > 0:
            self.summary.epoch_losses[str(last_epoch_int + 1)] = metrics.loss
        self._last_epoch = epoch

        # Track peak GPU memory
        if gpu_mem_mb > self.summary.peak_gpu_memory_mb:
            self.summary.peak_gpu_memory_mb = gpu_mem_mb

    def on_train_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Close the streaming CSV and mark the run as completed."""
        self._close_csv()
        self.summary.status = "completed"

    def save_summary(self, trainer_stats=None, config: Optional[dict] = None):
        """
        Save final training summary. Call after trainer.train() completes.

        Writes training_summary.json and loss_curve.csv into ``output_dir``
        and prints a short report.

        Args:
            trainer_stats: Return value from trainer.train(); its ``metrics``
                mapping is read for "train_loss", "train_runtime" and
                "train_steps".
            config: Dict of training config; the keys "model_name",
                "dataset_name", "dataset_size" and "lora_rank" are used.
        """
        if trainer_stats:
            metrics = trainer_stats.metrics
            self.summary.final_loss = metrics.get("train_loss", 0)
            self.summary.total_runtime_sec = metrics.get("train_runtime", 0)
            self.summary.total_steps = metrics.get("train_steps", self.summary.total_steps)

        if config:
            self.summary.model_name = config.get("model_name", "")
            self.summary.dataset_name = config.get("dataset_name", "")
            self.summary.dataset_size = config.get("dataset_size", 0)
            self.summary.lora_rank = config.get("lora_rank", 0)

        if self.step_metrics:
            last = self.step_metrics[-1]
            if last.elapsed_sec > 0:
                # The per-step value is already a cumulative average, so the
                # last row is the average over the whole run.
                self.summary.avg_tokens_per_sec = round(last.tokens_per_sec, 1)

            # Final epoch loss. on_log keys epoch_losses by the 1-based epoch
            # number, so a run whose last log lands at e.g. epoch 2.97 belongs
            # to epoch 3: round a fractional epoch up instead of truncating.
            epoch_key = int(last.epoch)
            if last.epoch > epoch_key:
                epoch_key += 1
            if str(epoch_key) not in self.summary.epoch_losses:
                self.summary.epoch_losses[str(epoch_key)] = last.loss

        # Save JSON summary
        summary_path = os.path.join(self.output_dir, "training_summary.json")
        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(asdict(self.summary), f, indent=2, default=str)
        print(f"[TrainingLogger] Summary: {summary_path}")

        # Save simplified loss curve CSV
        loss_path = os.path.join(self.output_dir, "loss_curve.csv")
        with open(loss_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["step", "epoch", "loss"])
            writer.writerows([m.step, m.epoch, m.loss] for m in self.step_metrics)
        print(f"[TrainingLogger] Loss curve: {loss_path}")

        # Print summary
        print(f"\n{'=' * 60}")
        print(f"Training Summary: {self.experiment_name}")
        print(f"{'=' * 60}")
        print(f" Final loss: {self.summary.final_loss:.6f}")
        print(f" Min loss: {self.summary.min_loss:.6f} "
              f"(step {self.summary.min_loss_step})")
        print(f" Runtime: {self.summary.total_runtime_sec:.1f}s")
        print(f" Peak GPU: {self.summary.peak_gpu_memory_mb:.0f} MB")
        # The stored value is a samples-per-second figure (see on_log), so
        # label it as such rather than "steps/s".
        print(f" Avg throughput: {self.summary.avg_tokens_per_sec:.1f} samples/s")
        if self.summary.epoch_losses:
            print(f" Epoch losses: {self.summary.epoch_losses}")
        print(f" Status: {self.summary.status}")
def _checkpoint_step(state_path: str) -> int:
    """Return the step number of the ``checkpoint-<step>`` directory holding *state_path*.

    Returns -1 when the directory name does not end in a decimal number.
    """
    suffix = os.path.basename(os.path.dirname(state_path)).rpartition("-")[2]
    return int(suffix) if suffix.isdecimal() else -1


def parse_existing_logs(log_dir: str, output_dir: str):
    """Parse trainer_state.json from an existing training run.

    Looks for ``trainer_state.json`` directly in *log_dir*; if absent, falls
    back to the ``checkpoint-*`` subdirectory with the highest step number.
    Entries of its "log_history" list that contain a "loss" key are written
    to ``training_log.csv`` and ``loss_curve.csv`` in *output_dir*.

    Args:
        log_dir: Directory of the training run.
        output_dir: Directory for the generated CSV files (created if missing).

    Returns:
        None. Prints a message and returns early when no state file or no
        log history is found.
    """
    state_path = os.path.join(log_dir, "trainer_state.json")
    if not os.path.exists(state_path):
        # Try to find it in a checkpoint
        import glob
        candidates = glob.glob(os.path.join(log_dir, "checkpoint-*/trainer_state.json"))
        if not candidates:
            print(f"No trainer_state.json found in {log_dir}")
            return
        # Pick the latest checkpoint by numeric step. A plain string sort
        # would rank "checkpoint-999" after "checkpoint-1000"; the path is a
        # deterministic tie-breaker.
        state_path = max(candidates, key=lambda p: (_checkpoint_step(p), p))

    print(f"Parsing: {state_path}")
    with open(state_path, encoding="utf-8") as f:
        state = json.load(f)

    log_history = state.get("log_history", [])
    if not log_history:
        print("No log history found.")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Only entries carrying a "loss" key are exported; filter them once and
    # reuse the list for both output files.
    loss_entries = [entry for entry in log_history if "loss" in entry]

    # Write CSV
    csv_path = os.path.join(output_dir, "training_log.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["step", "epoch", "loss", "learning_rate", "grad_norm"])
        writer.writerows(
            [
                entry.get("step", 0),
                entry.get("epoch", 0),
                entry.get("loss", 0),
                entry.get("learning_rate", 0),
                entry.get("grad_norm", 0),
            ]
            for entry in loss_entries
        )

    # Write loss curve
    loss_path = os.path.join(output_dir, "loss_curve.csv")
    with open(loss_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["step", "epoch", "loss"])
        writer.writerows(
            [entry.get("step", 0), entry.get("epoch", 0), entry.get("loss", 0)]
            for entry in loss_entries
        )

    print(f"Parsed {len(log_history)} log entries -> {output_dir}/")
if __name__ == "__main__":
    # Command-line entry point: turn an existing run's trainer_state.json
    # into the structured CSV files.
    cli = argparse.ArgumentParser(
        description="Parse existing trainer logs into structured format",
    )
    cli.add_argument(
        "--log-dir",
        type=str,
        required=True,
        help="Directory containing trainer_state.json",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="outputs/logs",
        help="Output directory for parsed logs",
    )
    options = cli.parse_args()
    parse_existing_logs(options.log_dir, options.output)