WaveCut's picture
Upload folder using huggingface_hub
f1af99e verified
#!/usr/bin/env python
"""
Shrink FLUX.2 text encoder (Mistral3) down to a 7-layer student.
- Teacher: Mistral3ForConditionalGeneration from FLUX.2 text_encoder
(by default: black-forest-labs/FLUX.2-dev, subfolder="text_encoder")
- Student: same architecture family (Mistral3), but with:
* num_text_layers = 7 (from 40)
* vision depth kept intact (24 -> 24)
Weights are streamed shard-by-shard so the full teacher never lives in memory.
Text layers are mapped from strategically chosen teacher layers (spread across depth).
"""
import argparse
import gc
import json
import os
from copy import deepcopy
from typing import Dict, List, Optional, Set, Tuple
import torch
from huggingface_hub import HfFolder, file_exists, hf_hub_download
from safetensors.torch import safe_open
from transformers import AutoConfig, GenerationConfig, Mistral3ForConditionalGeneration
from transformers.modeling_utils import init_empty_weights
from transformers.utils.hub import get_checkpoint_shard_files
# Name of the produced artifact: used as the default --output-dir and written
# into the generated README and metadata file.
MODEL_NAME = "FLUX2-TE-Trimmed7L-Research"
# Some buffers are generated at load time and are absent from HF weight files,
# so they are excluded when reporting student tensors that were never loaded.
EXPECTED_MISSING = {
    "model.language_model.rotary_emb.inv_freq",
    "model.vision_tower.patch_positional_embedding.inv_freq",
}
def write_readme(
    output_dir: str,
    text_layers: int,
    vision_layers: int,
    model_name: str,
) -> None:
    """Emit a short ``README.md`` into ``output_dir`` describing the artifact.

    The file carries a research-only warning, the model name, the text and
    vision depths passed in, and a brief note on how the model was built.
    ``output_dir`` must already exist.
    """
    banner = [
        f"# {model_name}",
        "",
        "**WARNING: Experimental research artifact. Do NOT use in production.**",
        "",
    ]
    overview = [
        "## What this is",
        f"- Model name: {model_name}",
        f"- Text transformer depth: {text_layers} layers",
        f"- Vision tower depth: {vision_layers} layers (unchanged)",
        "",
    ]
    caveats = [
        "## Caveats",
        "- Intended for experimentation and further distillation, not end-user deployment.",
        "- Rotary and patch positional inv_freq buffers are created at load time (not stored).",
        "",
    ]
    recipe = [
        "## How it was built",
        f"- Shrank teacher text path to {text_layers} layers (evenly spaced) while keeping vision depth intact.",
        "- Stream-copied only needed tensors; other text layers were pruned.",
    ]
    target = os.path.join(output_dir, "README.md")
    with open(target, "w", encoding="utf-8") as handle:
        # Every row, including the last one, is terminated by a newline.
        handle.writelines(
            f"{row}\n" for row in (*banner, *overview, *caveats, *recipe)
        )
# ---------- helpers for layer selection ----------
def choose_layer_indices(total_layers: int, num_keep: int) -> List[int]:
    """
    Pick ``num_keep`` distinct layer indices out of ``range(total_layers)``.

    The picks are spread roughly evenly over the depth; for ``num_keep >= 2``
    both the first (0) and the last (``total_layers - 1``) layer are included.
    ``num_keep == 1`` keeps only the last layer, and ``num_keep == total_layers``
    keeps everything.

    Raises:
        ValueError: if ``num_keep`` is not positive or exceeds ``total_layers``.
    """
    if num_keep <= 0:
        raise ValueError("num_keep must be > 0")
    if num_keep > total_layers:
        raise ValueError(
            f"num_keep ({num_keep}) cannot exceed total_layers ({total_layers})"
        )
    if num_keep == total_layers:
        return list(range(total_layers))
    if num_keep == 1:
        return [total_layers - 1]

    last = total_layers - 1
    gaps = num_keep - 1
    # Evenly spaced positions over [0, last], rounded and clamped into range.
    picked = {
        max(0, min(last, int(round(slot * last / gaps))))
        for slot in range(num_keep)
    }
    # Should rounding ever merge two positions, top up with the lowest
    # indices that are not selected yet until the requested count is reached.
    unused = (idx for idx in range(total_layers) if idx not in picked)
    while len(picked) < num_keep:
        picked.add(next(unused))
    return sorted(picked)
def map_teacher_to_student_key(
    teacher_key: str, text_layer_map: Dict[int, int]
) -> Tuple[Optional[str], bool]:
    """
    Map a teacher tensor name to the corresponding student tensor name.

    Args:
        teacher_key: tensor name as stored in the teacher checkpoint.
        text_layer_map: ``{teacher_layer_index: student_layer_index}`` for the
            text layers that are kept.

    Returns:
        ``(target_key, pruned_text_layer)``
        - ``target_key`` is None when the teacher key belongs to a text layer
          that is not selected for the student, or when a text-layer key is
          malformed (no numeric layer index / no suffix).
        - ``pruned_text_layer`` is True only when the skip reason is
          text-layer pruning.
    """
    # Per-layer text tensors are renumbered through text_layer_map. The legacy
    # layout is "language_model.model.layers.N"; the student layout is
    # "model.language_model.layers.N". A teacher that is already stored in the
    # student layout must be renumbered as well -- passing such keys through
    # unchanged would copy teacher layers 0..K-1 by index instead of the
    # selected ones.
    for text_prefix in ("language_model.model.layers.", "model.language_model.layers."):
        if not teacher_key.startswith(text_prefix):
            continue
        layer_str, sep, suffix = teacher_key[len(text_prefix):].partition(".")
        if not sep:
            return None, False
        try:
            old_idx = int(layer_str)
        except ValueError:
            return None, False
        new_idx = text_layer_map.get(old_idx)
        if new_idx is None:
            return None, True
        return f"model.language_model.layers.{new_idx}.{suffix}", False

    # Other language_model.* tensors (embed_tokens, norm, rotary_emb, etc.).
    lm_prefix = "language_model.model."
    if teacher_key.startswith(lm_prefix):
        return f"model.language_model.{teacher_key[len(lm_prefix):]}", False

    # LM head key flattens to top-level "lm_head" in student.
    lm_head_prefix = "language_model.lm_head."
    if teacher_key.startswith(lm_head_prefix):
        return f"lm_head.{teacher_key[len(lm_head_prefix):]}", False

    # Vision tower and multi-modal projector tensors live under "model." in
    # the student.
    if teacher_key.startswith(("vision_tower.", "multi_modal_projector.")):
        return f"model.{teacher_key}", False

    # Anything else keeps its name; the caller decides whether it exists in
    # the student.
    return teacher_key, False
def resolve_checkpoint_files(
    teacher_repo: str, teacher_subfolder: str
) -> Tuple[List[str], Optional[str]]:
    """
    Find checkpoint shard files (or a single weight file) for a HF repo/subfolder.

    Prefers safetensors index, then bin index, then single-file weights.

    Args:
        teacher_repo: Hugging Face repo id to look in.
        teacher_subfolder: subfolder inside the repo holding the weights
            (may be empty).

    Returns:
        ``(local_paths, index_filename)`` where ``index_filename`` is the name
        of the index file the shard list came from, or None when a single
        weight file was used.

    Raises:
        RuntimeError: if no index or single-file weights could be resolved;
            the message lists every error collected along the way.
    """
    token = os.environ.get("HF_TOKEN") or HfFolder.get_token()
    path_prefix = f"{teacher_subfolder}/" if teacher_subfolder else ""
    index_candidates = [
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    ]
    fallback_errors: List[str] = []

    for cand in index_candidates:
        repo_path = f"{path_prefix}{cand}"
        try:
            if not file_exists(teacher_repo, repo_path, token=token):
                continue
        except Exception as e:
            fallback_errors.append(f"HEAD {repo_path}: {type(e).__name__}: {e}")
            continue

        # Prefer the transformers helper first.
        # NOTE(review): get_checkpoint_shard_files appears to expect a local
        # path to the index file rather than a bare filename, in which case
        # this attempt always fails and the hf_hub_download fallback below
        # does the real work -- confirm against the installed transformers.
        try:
            ckpt_files, _ = get_checkpoint_shard_files(
                teacher_repo,
                index_filename=cand,
                subfolder=teacher_subfolder,
                token=token,
            )
            return ckpt_files, cand
        except Exception as e:
            fallback_errors.append(
                f"{cand} via transformers: {type(e).__name__}: {e}"
            )

        # Fallback: download index via hf_hub and parse shard list ourselves.
        try:
            index_path = hf_hub_download(
                teacher_repo,
                filename=cand,
                subfolder=teacher_subfolder,
                token=token,
            )
            with open(index_path, "r", encoding="utf-8") as f:
                index_data = json.load(f)
            shard_filenames = sorted(set(index_data.get("weight_map", {}).values()))
            if not shard_filenames:
                raise RuntimeError("index file contained no weight_map entries")
            shard_paths = [
                hf_hub_download(
                    teacher_repo,
                    filename=shard,
                    subfolder=teacher_subfolder,
                    token=token,
                )
                for shard in shard_filenames
            ]
            return shard_paths, cand
        except Exception as e:  # pragma: no cover - best effort fallback
            fallback_errors.append(
                f"{cand} via hf_hub_download: {type(e).__name__}: {e}"
            )

    for cand in ("model.safetensors", "pytorch_model.bin"):
        repo_path = f"{path_prefix}{cand}"
        try:
            if not file_exists(teacher_repo, repo_path, token=token):
                continue
            local_path = hf_hub_download(
                teacher_repo,
                filename=cand,
                subfolder=teacher_subfolder,
                token=token,
            )
        except Exception as e:
            # Still best effort (move on to the next candidate), but keep the
            # reason so the final error does not claim "nothing found" when a
            # download or auth failure was the real cause.
            fallback_errors.append(f"{repo_path}: {type(e).__name__}: {e}")
            continue
        return [local_path], None

    err_detail = (
        "; ".join(fallback_errors)
        if fallback_errors
        else "no matching index or single-file weights found"
    )
    raise RuntimeError(
        "Could not locate checkpoint files in the specified repo/subfolder. "
        f"Details: {err_detail}"
    )
def stream_copy_teacher_to_student(
    checkpoint_files: List[str],
    student: Mistral3ForConditionalGeneration,
    text_indices: List[int],
    torch_dtype: Optional[torch.dtype],
) -> Tuple[Set[str], List[str], List[str], List[str]]:
    """
    Load teacher weights shard-by-shard and copy only what the student needs.

    Args:
        checkpoint_files: local paths of the teacher shards; ``.safetensors``
            files are read tensor-by-tensor, anything else is loaded whole
            with ``torch.load``.
        student: already-instantiated student model that receives the weights.
        text_indices: teacher text-layer indices to keep; position in the list
            is the student layer index.
        torch_dtype: dtype to cast copied tensors to. None means "auto": the
            student is converted once to the dtype of the first floating-point
            teacher tensor that gets copied.

    Returns:
        loaded_keys: names successfully copied into the student
        skipped_teacher_keys: teacher keys that exist but are not present in the student
        pruned_text_keys: teacher text-layer keys that were intentionally dropped
        shape_mismatches: keys that were found but skipped because of shape mismatch
    """
    text_layer_map = {old_idx: new_idx for new_idx, old_idx in enumerate(text_indices)}

    def _snapshot_targets() -> Dict[str, torch.Tensor]:
        # Name -> tensor for every parameter and buffer of the student.
        targets: Dict[str, torch.Tensor] = dict(student.named_parameters())
        targets.update(student.named_buffers())
        return targets

    target_tensors = _snapshot_targets()
    loaded_keys: Set[str] = set()
    skipped_teacher_keys: List[str] = []
    pruned_text_keys: List[str] = []
    shape_mismatches: List[str] = []
    auto_dtype = torch_dtype is None
    chosen_auto_dtype: Optional[torch.dtype] = None

    def _resolve_target(teacher_key: str) -> Optional[str]:
        # Student key for teacher_key, or None (with bookkeeping) when the
        # tensor must not be copied.
        target_key, pruned = map_teacher_to_student_key(teacher_key, text_layer_map)
        if pruned:
            pruned_text_keys.append(teacher_key)
            return None
        if target_key is None:
            return None
        if target_key not in target_tensors:
            skipped_teacher_keys.append(f"{teacher_key} -> {target_key}")
            return None
        return target_key

    def _copy_tensor(teacher_key: str, target_key: str, tensor: torch.Tensor) -> bool:
        # Copy one teacher tensor into the student; True when it was copied.
        nonlocal target_tensors, chosen_auto_dtype
        target_tensor = target_tensors[target_key]
        if target_tensor.shape != tensor.shape:
            shape_mismatches.append(
                f"{teacher_key} -> {target_key}: "
                f"teacher {tuple(tensor.shape)} vs student {tuple(target_tensor.shape)}"
            )
            return False
        # Module.to() only accepts floating/complex dtypes, so wait for the
        # first floating-point tensor before picking the auto dtype.
        if auto_dtype and chosen_auto_dtype is None and tensor.is_floating_point():
            chosen_auto_dtype = tensor.dtype
            student.to(dtype=chosen_auto_dtype)
            # Converting the module can rebind tensor objects (buffers are
            # reassigned when their dtype changes), so the name -> tensor map
            # built earlier may be stale; rebuild it before copying.
            target_tensors = _snapshot_targets()
            target_tensor = target_tensors[target_key]
        if torch_dtype is not None and tensor.dtype != torch_dtype:
            tensor = tensor.to(dtype=torch_dtype)
        with torch.no_grad():
            target_tensor.copy_(tensor)
        loaded_keys.add(target_key)
        return True

    for shard_path in checkpoint_files:
        shard_name = os.path.basename(shard_path)
        print(f"[info] Streaming shard '{shard_name}'...")
        loaded_in_shard = 0
        if shard_path.endswith(".safetensors"):
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                for teacher_key in f.keys():
                    target_key = _resolve_target(teacher_key)
                    if target_key is None:
                        continue
                    # Only tensors the student actually needs are read.
                    if _copy_tensor(teacher_key, target_key, f.get_tensor(teacher_key)):
                        loaded_in_shard += 1
        else:
            # Fallback for legacy .bin shards (loads the whole shard into memory).
            # SECURITY NOTE: torch.load unpickles the file; only use this path
            # with checkpoints from a trusted source.
            shard_state = torch.load(shard_path, map_location="cpu")
            for teacher_key, tensor in shard_state.items():
                target_key = _resolve_target(teacher_key)
                if target_key is None:
                    continue
                if _copy_tensor(teacher_key, target_key, tensor):
                    loaded_in_shard += 1
            del shard_state
        gc.collect()
        print(f"[info] Loaded {loaded_in_shard} tensors from '{shard_name}'.")

    return loaded_keys, skipped_teacher_keys, pruned_text_keys, shape_mismatches
def stream_copy_to_disk(
    checkpoint_files: List[str],
    text_indices: List[int],
    torch_dtype: Optional[torch.dtype],
    student_config: AutoConfig,
    output_dir: str,
    max_shard_size_bytes: int,
) -> Tuple[
    List[int],
    List[int],
    List[str],
    List[str],
    List[str],
    List[str],
    Optional[str],
    Dict[str, str],
    int,
    int,
]:
    """
    Stream teacher shards and write student shards directly to disk to reduce RAM usage.

    Args:
        checkpoint_files: local paths of the teacher shards.
        text_indices: teacher text-layer indices to keep; position in the list
            is the student layer index.
        torch_dtype: dtype to cast tensors to before writing; None keeps each
            tensor's stored dtype.
        student_config: config of the student; used only to derive the
            expected tensor names and shapes.
        output_dir: directory receiving ``model-shard-XXXXX.safetensors``
            files (created if missing).
        max_shard_size_bytes: an output shard is flushed before a tensor that
            would push it past this size (a single larger tensor still gets
            its own shard).

    Returns:
        A 10-tuple: ``(text_indices, vision_indices, missing_keys,
        skipped_teacher_keys, pruned_text_keys, shape_mismatches, None,
        weight_map, total_params, total_size_bytes)``. The first seven entries
        mirror the info parts of build_student_from_teacher (the seventh is
        always None here); the last three feed the generated index file.
    """
    from safetensors.torch import save_file

    # Build a meta-device student to know expected shapes without allocating real storage.
    with init_empty_weights():
        meta_student = Mistral3ForConditionalGeneration(student_config)
    student_shapes = {k: p.shape for k, p in meta_student.named_parameters()}
    student_shapes.update({k: b.shape for k, b in meta_student.named_buffers()})

    text_layer_map = {old_idx: new_idx for new_idx, old_idx in enumerate(text_indices)}
    loaded_keys: Set[str] = set()
    skipped_teacher_keys: List[str] = []
    pruned_text_keys: List[str] = []
    shape_mismatches: List[str] = []
    weight_map: Dict[str, str] = {}
    total_params = 0
    total_size_bytes = 0

    os.makedirs(output_dir, exist_ok=True)
    current_tensors: Dict[str, torch.Tensor] = {}
    current_size = 0
    shard_idx = 1
    current_shard_name = f"model-shard-{shard_idx:05d}.safetensors"

    def flush_current() -> None:
        # Write the pending tensors as one output shard and start a new one.
        nonlocal shard_idx, current_tensors, current_size, current_shard_name
        if not current_tensors:
            return
        out_path = os.path.join(output_dir, current_shard_name)
        print(
            f"[info] writing {len(current_tensors)} tensors to '{out_path}' (size ~{current_size / 1e9:.2f} GB)"
        )
        # Tag the file as PyTorch-format; presumably some transformers
        # versions reject safetensors files without this metadata entry.
        save_file(current_tensors, out_path, metadata={"format": "pt"})
        shard_idx += 1
        current_shard_name = f"model-shard-{shard_idx:05d}.safetensors"
        current_tensors = {}
        current_size = 0

    def _resolve_target(teacher_key: str) -> Optional[str]:
        # Student key for teacher_key, or None (with bookkeeping) when the
        # tensor must not be written.
        target_key, pruned = map_teacher_to_student_key(teacher_key, text_layer_map)
        if pruned:
            pruned_text_keys.append(teacher_key)
            return None
        if target_key is None:
            return None
        if target_key not in student_shapes:
            skipped_teacher_keys.append(
                f"{teacher_key} -> {target_key} (no student tensor)"
            )
            return None
        return target_key

    def _stage_tensor(teacher_key: str, target_key: str, tensor: torch.Tensor) -> bool:
        # Queue one tensor for the current output shard; True when queued.
        nonlocal current_size, total_params, total_size_bytes
        expected_shape = student_shapes[target_key]
        if tensor.shape != expected_shape:
            shape_mismatches.append(
                f"{teacher_key} -> {target_key}: teacher {tuple(tensor.shape)} "
                f"vs student {tuple(expected_shape)}"
            )
            return False
        if torch_dtype is not None and tensor.dtype != torch_dtype:
            tensor = tensor.to(dtype=torch_dtype)
        # save_file refuses non-contiguous tensors; this is a no-op otherwise.
        tensor = tensor.contiguous()
        t_size = tensor.numel() * tensor.element_size()
        if current_tensors and current_size + t_size > max_shard_size_bytes:
            flush_current()
        # Looked up after a possible flush so the tensor lands in the new shard.
        current_tensors[target_key] = tensor
        weight_map[target_key] = current_shard_name
        current_size += t_size
        loaded_keys.add(target_key)
        total_params += tensor.numel()
        total_size_bytes += t_size
        return True

    for shard_path in checkpoint_files:
        shard_name = os.path.basename(shard_path)
        print(f"[info] Streaming shard '{shard_name}' -> disk...")
        found_in_shard = 0
        if shard_path.endswith(".safetensors"):
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                for teacher_key in f.keys():
                    target_key = _resolve_target(teacher_key)
                    if target_key is None:
                        continue
                    if _stage_tensor(teacher_key, target_key, f.get_tensor(teacher_key)):
                        found_in_shard += 1
        else:
            # Legacy .bin shard: the whole file is loaded into memory.
            # SECURITY NOTE: torch.load unpickles the file; only use this path
            # with checkpoints from a trusted source.
            shard_state = torch.load(shard_path, map_location="cpu")
            for teacher_key, tensor in shard_state.items():
                target_key = _resolve_target(teacher_key)
                if target_key is None:
                    continue
                if _stage_tensor(teacher_key, target_key, tensor):
                    found_in_shard += 1
            del shard_state
        if found_in_shard == 0:
            print(
                f"[info] no student tensors in shard '{shard_name}', skipping write."
            )
        gc.collect()

    # Write whatever is still pending as the final output shard.
    flush_current()

    missing_keys = sorted((set(student_shapes.keys()) - loaded_keys) - EXPECTED_MISSING)
    vision_indices = (
        list(range(student_config.vision_config.num_hidden_layers))
        if getattr(student_config, "vision_config", None)
        else []
    )
    return (
        text_indices,
        vision_indices,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        None,
        weight_map,
        total_params,
        total_size_bytes,
    )
# ---------- main shrinking logic ----------
def build_student_from_teacher(
    teacher_repo: str,
    teacher_subfolder: str,
    num_text_layers: int,
    torch_dtype: Optional[torch.dtype],
) -> Tuple[
    Mistral3ForConditionalGeneration,
    List[int],
    List[int],
    List[str],
    List[str],
    List[str],
    List[str],
    Optional[str],
]:
    """
    Build the shrunk student in memory and fill it from the teacher's shards.

    The teacher config is loaded, its text depth is reduced to
    ``num_text_layers`` (vision depth is left as is), a student model is
    instantiated from that config, and teacher weights are stream-copied into
    it shard by shard so the full teacher model is never instantiated.

    Returns:
        ``(student, text_indices, vision_indices, missing_keys,
        skipped_teacher_keys, pruned_text_keys, shape_mismatches,
        checkpoint_index_filename)``.

    Raises:
        RuntimeError: if the teacher config has no ``text_config`` or carries
            a ``quantization_config``.
        ValueError: if ``num_text_layers`` exceeds the teacher's text depth.
    """
    teacher_config = AutoConfig.from_pretrained(
        teacher_repo, subfolder=teacher_subfolder
    )
    if getattr(teacher_config, "text_config", None) is None:
        raise RuntimeError(
            "Teacher config has no text_config — not a Mistral3 multimodal model?"
        )

    text_total = teacher_config.text_config.num_hidden_layers
    vision_cfg = getattr(teacher_config, "vision_config", None)
    vision_total = 0 if vision_cfg is None else vision_cfg.num_hidden_layers
    print(f"[info] Teacher text layers: {text_total}")
    print(f"[info] Teacher vision layers: {vision_total}")

    if num_text_layers > text_total:
        raise ValueError(
            f"Requested num_text_layers={num_text_layers} > teacher text_total={text_total}"
        )
    if getattr(teacher_config, "quantization_config", None) is not None:
        raise RuntimeError(
            "Teacher appears to be quantized (quantization_config present). "
            "Please use a dense teacher, e.g. 'black-forest-labs/FLUX.2-dev' "
            "with subfolder='text_encoder', not the NF4-quantized repo."
        )

    # Which teacher layers survive: a spread-out subset for text, all of vision.
    text_indices = choose_layer_indices(text_total, num_text_layers)
    print(f"[info] Selected teacher text layers -> student: {text_indices}")
    vision_indices = list(range(vision_total))

    # Student config: a copy of the teacher's with the reduced text depth and
    # the layer selection recorded on it.
    student_config = deepcopy(teacher_config)
    student_config.text_config.num_hidden_layers = num_text_layers
    if vision_total > 0:
        student_config.vision_config.num_hidden_layers = vision_total
    student_config.selected_text_layers = text_indices  # type: ignore[attr-defined]
    if vision_indices:
        student_config.selected_vision_layers = vision_indices  # type: ignore[attr-defined]

    print("[info] Initializing student model from shrunk config...")
    student = Mistral3ForConditionalGeneration(student_config)
    if torch_dtype is not None:
        student.to(dtype=torch_dtype)

    # Locate the teacher's weight files, then copy what the student needs.
    checkpoint_files, checkpoint_index_filename = resolve_checkpoint_files(
        teacher_repo, teacher_subfolder
    )
    index_note = (
        f" via index '{checkpoint_index_filename}'."
        if checkpoint_index_filename
        else "."
    )
    print(f"[info] Found {len(checkpoint_files)} file(s) to stream" + index_note)

    copy_report = stream_copy_teacher_to_student(
        checkpoint_files=checkpoint_files,
        student=student,
        text_indices=text_indices,
        torch_dtype=torch_dtype,
    )
    loaded_keys, skipped_teacher_keys, pruned_text_keys, shape_mismatches = copy_report

    # Student tensors that received nothing, minus those expected to be absent.
    student_names = {name for name, _ in student.named_parameters()}
    student_names.update(name for name, _ in student.named_buffers())
    missing_keys = sorted(student_names - loaded_keys - EXPECTED_MISSING)

    return (
        student,
        text_indices,
        vision_indices,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
    )
# ---------- CLI ----------
def parse_args() -> argparse.Namespace:
    """Define the command-line options of the script and parse ``sys.argv``."""
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = [
        (
            "--teacher-repo",
            {
                "type": str,
                "default": "black-forest-labs/FLUX.2-dev",
                "help": (
                    "Hugging Face repo id of FLUX.2. "
                    "Must contain a `text_encoder` subfolder. "
                    "Default: black-forest-labs/FLUX.2-dev"
                ),
            },
        ),
        (
            "--teacher-subfolder",
            {
                "type": str,
                "default": "text_encoder",
                "help": "Subfolder with Mistral3 text encoder weights inside teacher repo.",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": MODEL_NAME,
                "help": "Where to save the student HF model (default: model name).",
            },
        ),
        (
            "--num-text-layers",
            {
                "type": int,
                "default": 7,
                "help": "Number of transformer layers for text path in student.",
            },
        ),
        (
            "--dtype",
            {
                "type": str,
                "default": "auto",
                "choices": ["auto", "bfloat16", "float16", "float32"],
                "help": "Target dtype for student weights. 'auto' keeps teacher dtype.",
            },
        ),
        (
            "--max-shard-size-gb",
            {
                "type": float,
                "default": 5.0,
                "help": "Max shard size (GB) when writing safetensors (stream-to-disk).",
            },
        ),
        (
            "--stream-to-disk",
            {
                "action": "store_true",
                "help": (
                    "Copy tensors shard-by-shard straight to disk without holding the full student in RAM. "
                    "Produces safetensors shards and an index in output-dir."
                ),
            },
        ),
    ]
    cli = argparse.ArgumentParser(
        description="Shrink FLUX.2 Mistral3 text_encoder into a 7-layer student."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def _print_copy_summary(
    pruned_text_keys: List[str],
    skipped_teacher_keys: List[str],
    shape_mismatches: List[str],
    missing_keys: List[str],
) -> None:
    """Print one count line for every non-empty category of the copy report."""
    if pruned_text_keys:
        print(f"[info] Dropped {len(pruned_text_keys)} teacher text-layer tensors.")
    if skipped_teacher_keys:
        print(
            f"[warn] {len(skipped_teacher_keys)} teacher tensors were not used "
            "because no matching student key was found."
        )
    if shape_mismatches:
        print(
            f"[warn] {len(shape_mismatches)} tensors skipped due to shape mismatch."
        )
    if missing_keys:
        print(f"[warn] Student is still missing {len(missing_keys)} tensors.")


def _write_meta_and_readme(
    args: argparse.Namespace,
    text_idx: List[int],
    vision_idx: List[int],
    missing_keys: List[str],
    skipped_teacher_keys: List[str],
    pruned_text_keys: List[str],
    shape_mismatches: List[str],
    checkpoint_index_filename: Optional[str],
    stream_to_disk: bool,
    extra_meta: Optional[Dict[str, object]] = None,
) -> None:
    """Write ``distillation_meta.json`` and ``README.md`` into the output dir."""
    meta: Dict[str, object] = {
        "teacher_repo": args.teacher_repo,
        "teacher_subfolder": args.teacher_subfolder,
        "dtype": args.dtype,
        "model_name": MODEL_NAME,
        "num_text_layers_student": args.num_text_layers,
        "num_vision_layers_student": len(vision_idx),
        "selected_text_layers_from_teacher": text_idx,
        "selected_vision_layers_from_teacher": vision_idx,
        "missing_keys_on_load": missing_keys,
        "unexpected_keys_on_load": skipped_teacher_keys,
        "skipped_teacher_keys": skipped_teacher_keys,
        "dropped_teacher_text_keys": pruned_text_keys,
        "shape_mismatches": shape_mismatches,
        "checkpoint_index_filename": checkpoint_index_filename,
        "stream_to_disk": stream_to_disk,
    }
    if extra_meta:
        # Mode-specific fields go after the common ones.
        meta.update(extra_meta)
    meta_path = os.path.join(args.output_dir, "distillation_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    print(f"[info] Wrote metadata to {meta_path}")
    write_readme(
        output_dir=args.output_dir,
        text_layers=args.num_text_layers,
        vision_layers=len(vision_idx),
        model_name=MODEL_NAME,
    )


def _run_stream_to_disk(
    args: argparse.Namespace, torch_dtype: Optional[torch.dtype]
) -> None:
    """Stream-to-disk mode: write student shards without building the model."""
    # --- config prep (same checks as in build_student_from_teacher) ---
    teacher_config = AutoConfig.from_pretrained(
        args.teacher_repo, subfolder=args.teacher_subfolder
    )
    if (
        not hasattr(teacher_config, "text_config")
        or teacher_config.text_config is None
    ):
        raise RuntimeError(
            "Teacher config has no text_config — not a Mistral3 multimodal model?"
        )
    text_total = teacher_config.text_config.num_hidden_layers
    vision_total = (
        teacher_config.vision_config.num_hidden_layers
        if getattr(teacher_config, "vision_config", None) is not None
        else 0
    )
    if args.num_text_layers > text_total:
        raise ValueError(
            f"Requested num_text_layers={args.num_text_layers} > teacher text_total={text_total}"
        )
    if getattr(teacher_config, "quantization_config", None) is not None:
        raise RuntimeError(
            "Teacher appears to be quantized (quantization_config present). "
            "Use a dense teacher, e.g. black-forest-labs/FLUX.2-dev with subfolder='text_encoder'."
        )
    text_idx = choose_layer_indices(text_total, args.num_text_layers)
    vision_idx = list(range(vision_total)) if vision_total > 0 else []
    print(f"[info] Teacher text layers: {text_total}")
    print(f"[info] Teacher vision layers: {vision_total}")
    print(f"[info] Selected teacher text layers -> student: {text_idx}")

    student_config = deepcopy(teacher_config)
    student_config.text_config.num_hidden_layers = args.num_text_layers
    if vision_total > 0:
        student_config.vision_config.num_hidden_layers = vision_total
    student_config.selected_text_layers = text_idx  # type: ignore[attr-defined]
    if vision_idx:
        student_config.selected_vision_layers = vision_idx  # type: ignore[attr-defined]

    os.makedirs(args.output_dir, exist_ok=True)
    student_config.save_pretrained(args.output_dir)
    try:
        gen_cfg = GenerationConfig.from_pretrained(
            args.teacher_repo, subfolder=args.teacher_subfolder
        )
        gen_cfg.save_pretrained(args.output_dir)
    except Exception as e:  # pragma: no cover - optional nicety
        print(f"[warn] Could not save generation_config: {e}")

    checkpoint_files, checkpoint_index_filename = resolve_checkpoint_files(
        args.teacher_repo, args.teacher_subfolder
    )
    print(
        f"[info] Found {len(checkpoint_files)} file(s) to stream"
        + (
            f" via index '{checkpoint_index_filename}'."
            if checkpoint_index_filename
            else "."
        )
    )
    (
        _text_idx,
        _vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        _ckpt_idx,
        weight_map,
        total_params,
        total_size_bytes,
    ) = stream_copy_to_disk(
        checkpoint_files=checkpoint_files,
        text_indices=text_idx,
        torch_dtype=torch_dtype,
        student_config=student_config,
        output_dir=args.output_dir,
        # Here "GB" is interpreted as 1024**3 bytes.
        max_shard_size_bytes=int(args.max_shard_size_gb * (1024**3)),
    )

    # Index file tying every student tensor name to the shard that holds it.
    index_path = os.path.join(args.output_dir, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "metadata": {
                    "total_parameters": total_params,
                    "total_size": total_size_bytes,
                    "dtype": "auto (teacher)"
                    if torch_dtype is None
                    else str(torch_dtype),
                },
                "weight_map": weight_map,
            },
            f,
            indent=2,
            ensure_ascii=False,
        )
    print(f"[info] Wrote safetensors index to {index_path}")

    _print_copy_summary(
        pruned_text_keys, skipped_teacher_keys, shape_mismatches, missing_keys
    )
    _write_meta_and_readme(
        args,
        text_idx,
        vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
        stream_to_disk=True,
        extra_meta={
            "weight_map": weight_map,
            "total_parameters": total_params,
            "total_size_bytes": total_size_bytes,
        },
    )
    print(
        "[done] Student weights streamed to disk. You can now load from output-dir without ever "
        "holding the full student in RAM during conversion."
    )


def _run_in_memory(
    args: argparse.Namespace, torch_dtype: Optional[torch.dtype]
) -> None:
    """In-memory mode: build the student model, then save it to the output dir."""
    (
        student,
        text_idx,
        vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
    ) = build_student_from_teacher(
        teacher_repo=args.teacher_repo,
        teacher_subfolder=args.teacher_subfolder,
        num_text_layers=args.num_text_layers,
        torch_dtype=torch_dtype,
    )
    _print_copy_summary(
        pruned_text_keys, skipped_teacher_keys, shape_mismatches, missing_keys
    )
    print(f"[info] Saving student model to '{args.output_dir}'...")
    os.makedirs(args.output_dir, exist_ok=True)
    # NOTE(review): the stream-to-disk mode treats the same option as GiB
    # (1024**3 bytes) while this passes a "GB" string to save_pretrained, so
    # the two modes may shard at slightly different sizes -- confirm intent.
    student.save_pretrained(
        args.output_dir,
        safe_serialization=True,
        max_shard_size=f"{args.max_shard_size_gb}GB",
    )
    # Dump some metadata so it's easy to inspect the mapping later.
    _write_meta_and_readme(
        args,
        text_idx,
        vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
        stream_to_disk=False,
    )
    print(
        "[done] Student text_encoder is ready. Next step: distillation loss & training loop."
    )


def main():
    """CLI entry point: parse options and run the selected conversion mode."""
    args = parse_args()
    # "auto" has no entry and therefore maps to None (keep the teacher dtype).
    torch_dtype: Optional[torch.dtype] = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }.get(args.dtype)
    print(
        f"[info] Streaming teacher Mistral3 text encoder from "
        f"{args.teacher_repo} (subfolder='{args.teacher_subfolder}') "
        f"directly into the student..."
    )
    if args.stream_to_disk:
        _run_stream_to_disk(args, torch_dtype)
    else:
        _run_in_memory(args, torch_dtype)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()