| |
| """ |
| Shrink FLUX.2 text encoder (Mistral3) down to a 7-layer student. |
| |
| - Teacher: Mistral3ForConditionalGeneration from FLUX.2 text_encoder |
| (by default: black-forest-labs/FLUX.2-dev, subfolder="text_encoder") |
| |
| - Student: same architecture family (Mistral3), but with: |
| * num_text_layers = 7 (from 40) |
| * vision depth kept intact (24 -> 24) |
| |
| Weights are streamed shard-by-shard so the full teacher never lives in memory. |
| Text layers are mapped from strategically chosen teacher layers (spread across depth). |
| """ |
|
|
| import argparse |
| import gc |
| import json |
| import os |
| from copy import deepcopy |
| from typing import Dict, List, Optional, Set, Tuple |
|
|
| import torch |
| from huggingface_hub import HfFolder, file_exists, hf_hub_download |
| from safetensors.torch import safe_open |
| from transformers import AutoConfig, GenerationConfig, Mistral3ForConditionalGeneration |
| from transformers.modeling_utils import init_empty_weights |
| from transformers.utils.hub import get_checkpoint_shard_files |
|
|
| |
# Name of the shipped artifact: default --output-dir, README title and the
# "model_name" field of distillation_meta.json.
MODEL_NAME = "FLUX2-TE-Trimmed7L-Research"

# Student tensor names excluded from the "missing keys" report. These are the
# text rotary-embedding and vision patch positional-embedding inv_freq buffers,
# which the README describes as created at load time rather than stored.
EXPECTED_MISSING = set(
    (
        "model.language_model.rotary_emb.inv_freq",
        "model.vision_tower.patch_positional_embedding.inv_freq",
    )
)
|
|
|
|
def write_readme(
    output_dir: str,
    text_layers: int,
    vision_layers: int,
    model_name: str,
) -> None:
    """
    Create (or overwrite) ``README.md`` in *output_dir* describing the student.

    The file is written as UTF-8, one line per entry, with a trailing newline.
    The directory must already exist; this function does not create it.

    Args:
        output_dir: directory that receives ``README.md``.
        text_layers: number of text transformer layers in the student.
        vision_layers: number of vision tower layers (reported as unchanged).
        model_name: title and model name shown in the README.
    """
    header = [
        f"# {model_name}",
        "",
        "**WARNING: Experimental research artifact. Do NOT use in production.**",
        "",
    ]
    what_it_is = [
        "## What this is",
        f"- Model name: {model_name}",
        f"- Text transformer depth: {text_layers} layers",
        f"- Vision tower depth: {vision_layers} layers (unchanged)",
        "",
    ]
    caveats = [
        "## Caveats",
        "- Intended for experimentation and further distillation, not end-user deployment.",
        "- Rotary and patch positional inv_freq buffers are created at load time (not stored).",
        "",
    ]
    provenance = [
        "## How it was built",
        f"- Shrank teacher text path to {text_layers} layers (evenly spaced) while keeping vision depth intact.",
        "- Stream-copied only needed tensors; other text layers were pruned.",
    ]

    destination = os.path.join(output_dir, "README.md")
    with open(destination, "w", encoding="utf-8") as handle:
        # Every line (including the last) is newline-terminated.
        handle.writelines(
            f"{line}\n"
            for section in (header, what_it_is, caveats, provenance)
            for line in section
        )
|
|
|
|
| |
|
|
|
|
def choose_layer_indices(total_layers: int, num_keep: int) -> List[int]:
    """
    Pick ``num_keep`` distinct layer indices out of ``range(total_layers)``.

    Step ``i`` targets ``round(i * (total_layers - 1) / (num_keep - 1))`` using
    Python's ``round`` (ties go to the even integer), which spreads the picks
    roughly evenly over depth. With ``num_keep >= 2`` the first (0) and last
    (``total_layers - 1``) layers are always included; with ``num_keep == 1``
    only the last layer is returned.

    Returns:
        Sorted list of unique indices.

    Raises:
        ValueError: if ``num_keep`` is not positive or exceeds ``total_layers``.
    """
    if num_keep <= 0:
        raise ValueError("num_keep must be > 0")
    if num_keep > total_layers:
        raise ValueError(
            f"num_keep ({num_keep}) cannot exceed total_layers ({total_layers})"
        )
    if num_keep == total_layers:
        return list(range(total_layers))
    if num_keep == 1:
        return [total_layers - 1]

    last = total_layers - 1
    gaps = num_keep - 1
    picked = {
        min(last, max(0, int(round(step * last / gaps))))
        for step in range(num_keep)
    }

    # Safety net in case rounding ever collapses two picks: top up with the
    # lowest indices not yet chosen.
    unused = (idx for idx in range(total_layers) if idx not in picked)
    while len(picked) < num_keep:
        picked.add(next(unused))

    return sorted(picked)
|
|
|
|
def map_teacher_to_student_key(
    teacher_key: str, text_layer_map: Dict[int, int]
) -> Tuple[Optional[str], bool]:
    """
    Map a teacher tensor name to the corresponding student tensor name.

    Two checkpoint layouts are understood:

    * the legacy Mistral3 layout (``language_model.model.*``,
      ``language_model.lm_head.*``, ``vision_tower.*``,
      ``multi_modal_projector.*``), rewritten to the student's
      ``model.*`` / ``lm_head.*`` names;
    * the current layout (``model.language_model.*`` etc.), whose names already
      match the student and pass through unchanged.

    In both layouts, text-layer keys are renumbered through ``text_layer_map``.

    Args:
        teacher_key: tensor name as stored in the teacher checkpoint.
        text_layer_map: teacher text-layer index -> student text-layer index.

    Returns:
        (target_key, pruned_text_layer)
        - target_key is None when the teacher key belongs to a text layer that
          is not selected for the student (or whose index cannot be parsed).
        - pruned_text_layer is True when the skip reason is text-layer pruning.
    """
    # Renumber in BOTH layouts: letting current-layout layer keys fall through
    # unchanged would copy teacher layer k into student layer k.
    for text_prefix in ("language_model.model.layers.", "model.language_model.layers."):
        if not teacher_key.startswith(text_prefix):
            continue
        layer_str, dot, suffix = teacher_key[len(text_prefix):].partition(".")
        if not dot:
            return None, False
        try:
            old_idx = int(layer_str)
        except ValueError:
            return None, False
        if old_idx not in text_layer_map:
            return None, True
        return f"model.language_model.layers.{text_layer_map[old_idx]}.{suffix}", False

    # Legacy prefix -> student prefix, for everything outside the text layers.
    legacy_renames = (
        ("language_model.model.", "model.language_model."),
        ("language_model.lm_head.", "lm_head."),
        ("vision_tower.", "model.vision_tower."),
        ("multi_modal_projector.", "model.multi_modal_projector."),
    )
    for old_prefix, new_prefix in legacy_renames:
        if teacher_key.startswith(old_prefix):
            return new_prefix + teacher_key[len(old_prefix):], False

    # Already in the student's naming scheme (or unknown): keep as-is.
    return teacher_key, False
|
|
|
|
def _shard_filenames_from_index(index_path: str) -> List[str]:
    """Return the sorted, de-duplicated shard file names listed in a HF weight index."""
    with open(index_path, "r", encoding="utf-8") as f:
        index_data = json.load(f)
    shard_filenames = sorted(set(index_data.get("weight_map", {}).values()))
    if not shard_filenames:
        raise RuntimeError("index file contained no weight_map entries")
    return shard_filenames


def resolve_checkpoint_files(
    teacher_repo: str, teacher_subfolder: str
) -> Tuple[List[str], Optional[str]]:
    """
    Find checkpoint shard files (or a single weight file) for a repo/subfolder.

    Lookup order: safetensors index, then bin index, then single-file weights
    (``model.safetensors``, then ``pytorch_model.bin``).

    If ``teacher_repo`` (joined with ``teacher_subfolder``) is an existing local
    directory, files are resolved on disk without contacting the Hub. This
    mirrors ``AutoConfig.from_pretrained``, which also accepts local paths.
    Otherwise ``teacher_repo`` is treated as a Hub repo id and the files are
    downloaded with ``hf_hub_download``.

    Returns:
        (local paths of the weight files, name of the index file used, or None
        when a single weights file was found)

    Raises:
        RuntimeError: when no usable index or weights file is found; the
        message lists every error collected along the way.
    """
    index_candidates = [
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    ]
    single_candidates = ["model.safetensors", "pytorch_model.bin"]

    # ---- Local directory (e.g. a pre-downloaded snapshot) -----------------
    local_dir = (
        os.path.join(teacher_repo, teacher_subfolder) if teacher_subfolder else teacher_repo
    )
    if os.path.isdir(local_dir):
        for cand in index_candidates:
            index_path = os.path.join(local_dir, cand)
            if not os.path.isfile(index_path):
                continue
            shard_paths = [
                os.path.join(local_dir, name)
                for name in _shard_filenames_from_index(index_path)
            ]
            absent = [p for p in shard_paths if not os.path.isfile(p)]
            if absent:
                raise RuntimeError(
                    f"Index '{index_path}' references missing shard file(s): {absent}"
                )
            return shard_paths, cand
        for cand in single_candidates:
            weights_path = os.path.join(local_dir, cand)
            if os.path.isfile(weights_path):
                return [weights_path], None
        raise RuntimeError(
            "Could not locate checkpoint files in the specified repo/subfolder. "
            f"Details: no matching index or single-file weights found in local "
            f"directory '{local_dir}'"
        )

    # ---- Hugging Face Hub ---------------------------------------------------
    token = os.environ.get("HF_TOKEN") or HfFolder.get_token()
    path_prefix = f"{teacher_subfolder}/" if teacher_subfolder else ""
    fallback_errors: List[str] = []

    for cand in index_candidates:
        repo_path = f"{path_prefix}{cand}"
        try:
            if not file_exists(teacher_repo, repo_path, token=token):
                continue
        except Exception as e:
            fallback_errors.append(f"HEAD {repo_path}: {type(e).__name__}: {e}")
            continue

        try:
            index_path = hf_hub_download(
                teacher_repo,
                filename=cand,
                subfolder=teacher_subfolder,
                token=token,
            )
        except Exception as e:
            fallback_errors.append(f"{cand} via hf_hub_download: {type(e).__name__}: {e}")
            continue

        # get_checkpoint_shard_files open()s ``index_filename`` directly, so it
        # must receive the downloaded local path; a bare file name would be
        # resolved against the current working directory.
        try:
            ckpt_files, _ = get_checkpoint_shard_files(
                teacher_repo,
                index_filename=index_path,
                subfolder=teacher_subfolder,
                token=token,
            )
            return ckpt_files, cand
        except Exception as e:
            fallback_errors.append(f"{cand} via transformers: {type(e).__name__}: {e}")

        # Fallback: download each shard listed in the index ourselves.
        try:
            shard_paths = [
                hf_hub_download(
                    teacher_repo,
                    filename=shard,
                    subfolder=teacher_subfolder,
                    token=token,
                )
                for shard in _shard_filenames_from_index(index_path)
            ]
            return shard_paths, cand
        except Exception as e:
            fallback_errors.append(f"{cand} via hf_hub_download: {type(e).__name__}: {e}")

    for cand in single_candidates:
        repo_path = f"{path_prefix}{cand}"
        try:
            if file_exists(teacher_repo, repo_path, token=token):
                local_path = hf_hub_download(
                    teacher_repo,
                    filename=cand,
                    subfolder=teacher_subfolder,
                    token=token,
                )
                return [local_path], None
        except Exception as e:
            # Keep going, but remember why this candidate failed.
            fallback_errors.append(f"{repo_path}: {type(e).__name__}: {e}")

    err_detail = (
        "; ".join(fallback_errors)
        if fallback_errors
        else "no matching index or single-file weights found"
    )
    raise RuntimeError(
        "Could not locate checkpoint files in the specified repo/subfolder. "
        f"Details: {err_detail}"
    )
|
|
|
|
def stream_copy_teacher_to_student(
    checkpoint_files: List[str],
    student: Mistral3ForConditionalGeneration,
    text_indices: List[int],
    torch_dtype: Optional[torch.dtype],
) -> Tuple[Set[str], List[str], List[str], List[str]]:
    """
    Load teacher weights shard-by-shard and copy only what the student needs.

    Args:
        checkpoint_files: local teacher weight files. ``*.safetensors`` files are
            read lazily tensor-by-tensor; anything else is loaded whole with
            ``torch.load`` (one shard at a time).
        student: instantiated student model; its tensors are overwritten in place.
        text_indices: selected teacher text-layer indices; ``text_indices[i]`` is
            copied into student text layer ``i``.
        torch_dtype: dtype that floating-point teacher tensors are cast to.
            ``None`` ("auto") converts the whole student once, to the dtype of
            the first floating-point tensor that is copied.

    Returns:
        loaded_keys: names successfully copied into the student
        skipped_teacher_keys: teacher keys that exist but are not present in the student
        pruned_text_keys: teacher text-layer keys that were intentionally dropped
        shape_mismatches: keys that were found but skipped because of shape mismatch
    """
    text_layer_map = {old_idx: new_idx for new_idx, old_idx in enumerate(text_indices)}

    def collect_targets() -> Dict[str, torch.Tensor]:
        # nn.Module.to() swaps in NEW buffer tensor objects (parameters keep
        # their identity by default), so this map must be rebuilt after any
        # dtype conversion or copies into buffers would be lost.
        targets: Dict[str, torch.Tensor] = dict(student.named_parameters())
        targets.update(student.named_buffers())
        return targets

    target_tensors = collect_targets()
    loaded_keys: Set[str] = set()
    skipped_teacher_keys: List[str] = []
    pruned_text_keys: List[str] = []
    shape_mismatches: List[str] = []
    chosen_auto_dtype: Optional[torch.dtype] = None

    def copy_tensors(teacher_keys, fetch) -> int:
        """Copy the needed tensors among ``teacher_keys`` (``fetch(key)`` loads one); return the count."""
        nonlocal target_tensors, chosen_auto_dtype
        copied = 0
        for teacher_key in teacher_keys:
            target_key, pruned = map_teacher_to_student_key(teacher_key, text_layer_map)
            if pruned:
                pruned_text_keys.append(teacher_key)
                continue
            if target_key is None:
                continue
            if target_key not in target_tensors:
                skipped_teacher_keys.append(f"{teacher_key} -> {target_key}")
                continue

            # Load lazily: only tensors the student actually needs are read.
            tensor = fetch(teacher_key)
            target_tensor = target_tensors[target_key]
            if target_tensor.shape != tensor.shape:
                shape_mismatches.append(
                    f"{teacher_key} -> {target_key}: "
                    f"teacher {tuple(tensor.shape)} vs student {tuple(target_tensor.shape)}"
                )
                continue

            if (
                torch_dtype is None
                and chosen_auto_dtype is None
                and tensor.is_floating_point()
            ):
                # "auto": adopt the first floating-point teacher dtype for the
                # whole student (Module.to rejects integer dtypes).
                chosen_auto_dtype = tensor.dtype
                student.to(dtype=chosen_auto_dtype)
                target_tensors = collect_targets()
                target_tensor = target_tensors[target_key]
            if (
                torch_dtype is not None
                and tensor.is_floating_point()
                and tensor.dtype != torch_dtype
            ):
                # Integer tensors are never routed through a float dtype.
                tensor = tensor.to(dtype=torch_dtype)

            with torch.no_grad():
                target_tensor.copy_(tensor)
            loaded_keys.add(target_key)
            copied += 1
        return copied

    for shard_path in checkpoint_files:
        shard_name = os.path.basename(shard_path)
        print(f"[info] Streaming shard '{shard_name}'...")

        if shard_path.endswith(".safetensors"):
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                loaded_in_shard = copy_tensors(f.keys(), f.get_tensor)
        else:
            # weights_only=True restricts unpickling to tensors/containers, so a
            # tampered downloaded .bin cannot execute code while loading.
            shard_state = torch.load(shard_path, map_location="cpu", weights_only=True)
            loaded_in_shard = copy_tensors(list(shard_state), shard_state.__getitem__)
            del shard_state

        gc.collect()
        print(f"[info] Loaded {loaded_in_shard} tensors from '{shard_name}'.")

    return loaded_keys, skipped_teacher_keys, pruned_text_keys, shape_mismatches
|
|
|
|
def stream_copy_to_disk(
    checkpoint_files: List[str],
    text_indices: List[int],
    torch_dtype: Optional[torch.dtype],
    student_config: AutoConfig,
    output_dir: str,
    max_shard_size_bytes: int,
) -> Tuple[
    List[int],
    List[int],
    List[str],
    List[str],
    List[str],
    List[str],
    Optional[str],
    Dict[str, str],
    int,
    int,
]:
    """
    Stream teacher shards and write student shards directly to disk to reduce RAM usage.

    The student is only built under ``init_empty_weights``, to learn the expected
    tensor names and shapes. Selected teacher tensors are buffered until adding
    the next one would exceed ``max_shard_size_bytes``; the buffer is then
    written as ``model-shard-NNNNN.safetensors`` in ``output_dir``.

    Args:
        checkpoint_files: local teacher weight files. ``*.safetensors`` files are
            read lazily; anything else is loaded whole with ``torch.load``.
        text_indices: selected teacher text-layer indices; ``text_indices[i]``
            becomes student text layer ``i``.
        torch_dtype: dtype for floating-point tensors; ``None`` keeps the
            teacher's dtypes.
        student_config: config of the shrunk student.
        output_dir: destination directory (created if needed).
        max_shard_size_bytes: soft cap per output shard. A single tensor larger
            than the cap still gets a shard of its own.

    Returns:
        (text_indices, vision_indices, missing_keys, skipped_teacher_keys,
        pruned_text_keys, shape_mismatches, None, weight_map, total_params,
        total_size_bytes). The first seven mirror the info fields of
        build_student_from_teacher (the index-filename slot is always None).
        The last three feed the generated safetensors index.
    """
    from safetensors.torch import save_file

    with init_empty_weights():
        meta_student = Mistral3ForConditionalGeneration(student_config)
    student_shapes = {k: p.shape for k, p in meta_student.named_parameters()}
    student_shapes.update({k: b.shape for k, b in meta_student.named_buffers()})
    del meta_student

    text_layer_map = {old_idx: new_idx for new_idx, old_idx in enumerate(text_indices)}

    loaded_keys: Set[str] = set()
    skipped_teacher_keys: List[str] = []
    pruned_text_keys: List[str] = []
    shape_mismatches: List[str] = []
    weight_map: Dict[str, str] = {}
    total_params = 0
    total_size_bytes = 0

    os.makedirs(output_dir, exist_ok=True)

    current_tensors: Dict[str, torch.Tensor] = {}
    current_size = 0
    shard_idx = 1
    current_shard_name = f"model-shard-{shard_idx:05d}.safetensors"

    def flush_current() -> None:
        """Write the buffered tensors to the current output shard and start the next one."""
        nonlocal shard_idx, current_tensors, current_size, current_shard_name
        if not current_tensors:
            return
        out_path = os.path.join(output_dir, current_shard_name)
        print(
            f"[info] writing {len(current_tensors)} tensors to '{out_path}' (size ~{current_size / 1e9:.2f} GB)"
        )
        # transformers tags its own shards with {"format": "pt"} and checks this
        # field when loading.
        save_file(current_tensors, out_path, metadata={"format": "pt"})
        shard_idx += 1
        current_shard_name = f"model-shard-{shard_idx:05d}.safetensors"
        current_tensors = {}
        current_size = 0

    def take(teacher_keys, fetch) -> int:
        """Buffer the needed tensors among ``teacher_keys`` (``fetch(key)`` loads one); return the count."""
        nonlocal current_size, total_params, total_size_bytes
        found = 0
        for teacher_key in teacher_keys:
            target_key, pruned = map_teacher_to_student_key(teacher_key, text_layer_map)
            if pruned:
                pruned_text_keys.append(teacher_key)
                continue
            if target_key is None:
                continue
            expected_shape = student_shapes.get(target_key)
            if expected_shape is None:
                skipped_teacher_keys.append(
                    f"{teacher_key} -> {target_key} (no student tensor)"
                )
                continue

            tensor = fetch(teacher_key)
            if tensor.shape != expected_shape:
                shape_mismatches.append(
                    f"{teacher_key} -> {target_key}: teacher {tuple(tensor.shape)} "
                    f"vs student {tuple(expected_shape)}"
                )
                continue

            # Only floating-point tensors follow --dtype; integer tensors keep their values.
            if (
                torch_dtype is not None
                and tensor.is_floating_point()
                and tensor.dtype != torch_dtype
            ):
                tensor = tensor.to(dtype=torch_dtype)
            # safetensors refuses non-contiguous tensors (possible from torch.load).
            tensor = tensor.contiguous()

            t_size = tensor.numel() * tensor.element_size()
            if current_size + t_size > max_shard_size_bytes and current_tensors:
                flush_current()
            current_tensors[target_key] = tensor
            weight_map[target_key] = current_shard_name
            current_size += t_size
            loaded_keys.add(target_key)
            total_params += tensor.numel()
            total_size_bytes += t_size
            found += 1
        return found

    for shard_path in checkpoint_files:
        shard_name = os.path.basename(shard_path)
        print(f"[info] Streaming shard '{shard_name}' -> disk...")

        if shard_path.endswith(".safetensors"):
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                found_in_shard = take(f.keys(), f.get_tensor)
        else:
            # weights_only=True restricts unpickling to tensors/containers, so a
            # tampered downloaded .bin cannot execute code while loading.
            shard_state = torch.load(shard_path, map_location="cpu", weights_only=True)
            found_in_shard = take(list(shard_state), shard_state.__getitem__)
            del shard_state

        if found_in_shard == 0:
            print(
                f"[info] no student tensors in shard '{shard_name}', skipping write."
            )
        gc.collect()

    # Write whatever is still buffered.
    flush_current()

    missing_keys = sorted((set(student_shapes) - loaded_keys) - EXPECTED_MISSING)

    vision_indices = (
        list(range(student_config.vision_config.num_hidden_layers))
        if getattr(student_config, "vision_config", None)
        else []
    )

    return (
        text_indices,
        vision_indices,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        None,
        weight_map,
        total_params,
        total_size_bytes,
    )
|
|
|
|
| |
|
|
|
|
def build_student_from_teacher(
    teacher_repo: str,
    teacher_subfolder: str,
    num_text_layers: int,
    torch_dtype: Optional[torch.dtype],
) -> Tuple[
    Mistral3ForConditionalGeneration,
    List[int],
    List[int],
    List[str],
    List[str],
    List[str],
    List[str],
    Optional[str],
]:
    """
    Load teacher config, create a smaller student, and stream-copy weights
    so we never materialize the full teacher model in memory.
    Vision depth is kept intact; only text layers are reduced.

    Args:
        teacher_repo: Hub repo id (or path) of the FLUX.2 checkpoint.
        teacher_subfolder: subfolder holding the Mistral3 text encoder.
        num_text_layers: text layers to keep in the student.
        torch_dtype: target dtype, or None to adopt the teacher's dtype.

    Returns:
        (student, text_indices, vision_indices, missing_keys,
         skipped_teacher_keys, pruned_text_keys, shape_mismatches,
         checkpoint_index_filename)

    Raises:
        RuntimeError: the teacher config has no ``text_config`` or carries a
            ``quantization_config``.
        ValueError: ``num_text_layers`` exceeds the teacher's text depth (or is
            not positive, via choose_layer_indices).
    """
    teacher_config = AutoConfig.from_pretrained(
        teacher_repo, subfolder=teacher_subfolder
    )

    if not hasattr(teacher_config, "text_config") or teacher_config.text_config is None:
        raise RuntimeError(
            "Teacher config has no text_config — not a Mistral3 multimodal model?"
        )

    text_total = teacher_config.text_config.num_hidden_layers
    vision_total = (
        teacher_config.vision_config.num_hidden_layers
        if getattr(teacher_config, "vision_config", None) is not None
        else 0
    )

    print(f"[info] Teacher text layers: {text_total}")
    print(f"[info] Teacher vision layers: {vision_total}")

    if num_text_layers > text_total:
        raise ValueError(
            f"Requested num_text_layers={num_text_layers} > teacher text_total={text_total}"
        )

    if getattr(teacher_config, "quantization_config", None) is not None:
        raise RuntimeError(
            "Teacher appears to be quantized (quantization_config present). "
            "Please use a dense teacher, e.g. 'black-forest-labs/FLUX.2-dev' "
            "with subfolder='text_encoder', not the NF4-quantized repo."
        )

    text_indices = choose_layer_indices(text_total, num_text_layers)
    print(f"[info] Selected teacher text layers -> student: {text_indices}")

    # Vision tower is copied in full.
    vision_indices = list(range(vision_total)) if vision_total > 0 else []

    student_config = deepcopy(teacher_config)
    student_config.text_config.num_hidden_layers = num_text_layers
    if vision_total > 0:
        student_config.vision_config.num_hidden_layers = vision_total

    # Record provenance of the kept layers in the saved config.
    student_config.selected_text_layers = text_indices
    if vision_indices:
        student_config.selected_vision_layers = vision_indices

    print("[info] Initializing student model from shrunk config...")
    student = Mistral3ForConditionalGeneration(student_config)
    if torch_dtype is not None:
        student.to(dtype=torch_dtype)

    # Ship the teacher's generation defaults, as the --stream-to-disk path does;
    # otherwise the student keeps whatever defaults transformers created for a
    # freshly constructed model.
    try:
        student.generation_config = GenerationConfig.from_pretrained(
            teacher_repo, subfolder=teacher_subfolder
        )
    except Exception as e:
        print(f"[warn] Could not load teacher generation_config: {e}")

    checkpoint_files, checkpoint_index_filename = resolve_checkpoint_files(
        teacher_repo, teacher_subfolder
    )
    print(
        f"[info] Found {len(checkpoint_files)} file(s) to stream"
        + (
            f" via index '{checkpoint_index_filename}'."
            if checkpoint_index_filename
            else "."
        )
    )

    (
        loaded_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
    ) = stream_copy_teacher_to_student(
        checkpoint_files=checkpoint_files,
        student=student,
        text_indices=text_indices,
        torch_dtype=torch_dtype,
    )

    student_all_keys = {k for k, _ in student.named_parameters()} | {
        k for k, _ in student.named_buffers()
    }
    missing_keys = sorted((student_all_keys - loaded_keys) - EXPECTED_MISSING)

    return (
        student,
        text_indices,
        vision_indices,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
    )
|
|
|
|
| |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options of the shrinking script.

    Exits with a usage error (via ``parser.error``) when ``--num-text-layers``
    or ``--max-shard-size-gb`` is not positive, rather than failing later.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Shrink FLUX.2 Mistral3 text_encoder into a smaller student "
            "(default: 7 text layers)."
        )
    )

    parser.add_argument(
        "--teacher-repo",
        type=str,
        default="black-forest-labs/FLUX.2-dev",
        help=(
            "Hugging Face repo id of FLUX.2. "
            "Must contain the subfolder given by --teacher-subfolder. "
            "Default: black-forest-labs/FLUX.2-dev"
        ),
    )
    parser.add_argument(
        "--teacher-subfolder",
        type=str,
        default="text_encoder",
        help="Subfolder with Mistral3 text encoder weights inside teacher repo.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=MODEL_NAME,
        help="Where to save the student HF model (default: model name).",
    )
    parser.add_argument(
        "--num-text-layers",
        type=int,
        default=7,
        help="Number of transformer layers for text path in student.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "bfloat16", "float16", "float32"],
        help="Target dtype for student weights. 'auto' keeps teacher dtype.",
    )
    parser.add_argument(
        "--max-shard-size-gb",
        type=float,
        default=5.0,
        help="Max size (GB) of each saved safetensors shard (used by both save paths).",
    )
    parser.add_argument(
        "--stream-to-disk",
        action="store_true",
        help=(
            "Copy tensors shard-by-shard straight to disk without holding the full student in RAM. "
            "Produces safetensors shards and an index in output-dir."
        ),
    )

    args = parser.parse_args()
    if args.num_text_layers <= 0:
        parser.error("--num-text-layers must be a positive integer")
    if args.max_shard_size_gb <= 0:
        parser.error("--max-shard-size-gb must be positive")
    return args
|
|
|
|
def _print_copy_summary(
    pruned_text_keys: List[str],
    skipped_teacher_keys: List[str],
    shape_mismatches: List[str],
    missing_keys: List[str],
) -> None:
    """Print the tensor-copy summary shared by both conversion paths."""
    if pruned_text_keys:
        print(f"[info] Dropped {len(pruned_text_keys)} teacher text-layer tensors.")
    if skipped_teacher_keys:
        print(
            f"[warn] {len(skipped_teacher_keys)} teacher tensors were not used "
            "because no matching student key was found."
        )
    if shape_mismatches:
        print(
            f"[warn] {len(shape_mismatches)} tensors skipped due to shape mismatch."
        )
    if missing_keys:
        print(f"[warn] Student is still missing {len(missing_keys)} tensors.")


def _distillation_meta(
    args: argparse.Namespace,
    text_idx: List[int],
    vision_idx: List[int],
    missing_keys: List[str],
    skipped_teacher_keys: List[str],
    pruned_text_keys: List[str],
    shape_mismatches: List[str],
    checkpoint_index_filename: Optional[str],
    stream_to_disk: bool,
) -> Dict[str, object]:
    """Assemble the fields of ``distillation_meta.json`` common to both paths."""
    return {
        "teacher_repo": args.teacher_repo,
        "teacher_subfolder": args.teacher_subfolder,
        "dtype": args.dtype,
        "model_name": MODEL_NAME,
        "num_text_layers_student": args.num_text_layers,
        "num_vision_layers_student": len(vision_idx),
        "selected_text_layers_from_teacher": text_idx,
        "selected_vision_layers_from_teacher": vision_idx,
        "missing_keys_on_load": missing_keys,
        # Kept for backward compatibility with earlier meta files; same list
        # as "skipped_teacher_keys".
        "unexpected_keys_on_load": skipped_teacher_keys,
        "skipped_teacher_keys": skipped_teacher_keys,
        "dropped_teacher_text_keys": pruned_text_keys,
        "shape_mismatches": shape_mismatches,
        "checkpoint_index_filename": checkpoint_index_filename,
        "stream_to_disk": stream_to_disk,
    }


def _write_meta_and_readme(
    args: argparse.Namespace, meta: Dict[str, object], num_vision_layers: int
) -> None:
    """Write ``distillation_meta.json`` and ``README.md`` into ``args.output_dir``."""
    meta_path = os.path.join(args.output_dir, "distillation_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    print(f"[info] Wrote metadata to {meta_path}")

    write_readme(
        output_dir=args.output_dir,
        text_layers=args.num_text_layers,
        vision_layers=num_vision_layers,
        model_name=MODEL_NAME,
    )


def _run_stream_to_disk(
    args: argparse.Namespace,
    torch_dtype: Optional[torch.dtype],
    max_shard_size_bytes: int,
) -> None:
    """Write the student config, then stream teacher tensors straight into safetensors shards."""
    teacher_config = AutoConfig.from_pretrained(
        args.teacher_repo, subfolder=args.teacher_subfolder
    )

    if not hasattr(teacher_config, "text_config") or teacher_config.text_config is None:
        raise RuntimeError(
            "Teacher config has no text_config — not a Mistral3 multimodal model?"
        )

    text_total = teacher_config.text_config.num_hidden_layers
    vision_total = (
        teacher_config.vision_config.num_hidden_layers
        if getattr(teacher_config, "vision_config", None) is not None
        else 0
    )

    if args.num_text_layers > text_total:
        raise ValueError(
            f"Requested num_text_layers={args.num_text_layers} > teacher text_total={text_total}"
        )

    if getattr(teacher_config, "quantization_config", None) is not None:
        raise RuntimeError(
            "Teacher appears to be quantized (quantization_config present). "
            "Use a dense teacher, e.g. black-forest-labs/FLUX.2-dev with subfolder='text_encoder'."
        )

    text_idx = choose_layer_indices(text_total, args.num_text_layers)
    vision_idx = list(range(vision_total)) if vision_total > 0 else []

    print(f"[info] Teacher text layers: {text_total}")
    print(f"[info] Teacher vision layers: {vision_total}")
    print(f"[info] Selected teacher text layers -> student: {text_idx}")

    student_config = deepcopy(teacher_config)
    student_config.text_config.num_hidden_layers = args.num_text_layers
    if vision_total > 0:
        student_config.vision_config.num_hidden_layers = vision_total

    student_config.selected_text_layers = text_idx
    if vision_idx:
        student_config.selected_vision_layers = vision_idx

    if torch_dtype is not None:
        # Record the dtype actually written, as save_pretrained does on the
        # in-memory path. Otherwise config.json keeps the teacher's dtype and
        # dtype="auto" loads would cast the weights back to it. Newer
        # transformers name this field `dtype`, older ones `torch_dtype`.
        dtype_attr = "dtype" if hasattr(student_config, "dtype") else "torch_dtype"
        setattr(student_config, dtype_attr, str(torch_dtype).split(".")[-1])

    os.makedirs(args.output_dir, exist_ok=True)
    student_config.save_pretrained(args.output_dir)
    try:
        gen_cfg = GenerationConfig.from_pretrained(
            args.teacher_repo, subfolder=args.teacher_subfolder
        )
        gen_cfg.save_pretrained(args.output_dir)
    except Exception as e:
        print(f"[warn] Could not save generation_config: {e}")

    checkpoint_files, checkpoint_index_filename = resolve_checkpoint_files(
        args.teacher_repo, args.teacher_subfolder
    )
    print(
        f"[info] Found {len(checkpoint_files)} file(s) to stream"
        + (
            f" via index '{checkpoint_index_filename}'."
            if checkpoint_index_filename
            else "."
        )
    )

    (
        _text_idx,
        _vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        _ckpt_idx,
        weight_map,
        total_params,
        total_size_bytes,
    ) = stream_copy_to_disk(
        checkpoint_files=checkpoint_files,
        text_indices=text_idx,
        torch_dtype=torch_dtype,
        student_config=student_config,
        output_dir=args.output_dir,
        max_shard_size_bytes=max_shard_size_bytes,
    )

    index_path = os.path.join(args.output_dir, "model.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "metadata": {
                    "total_parameters": total_params,
                    "total_size": total_size_bytes,
                    "dtype": "auto (teacher)"
                    if torch_dtype is None
                    else str(torch_dtype),
                },
                "weight_map": weight_map,
            },
            f,
            indent=2,
            ensure_ascii=False,
        )
    print(f"[info] Wrote safetensors index to {index_path}")

    _print_copy_summary(
        pruned_text_keys, skipped_teacher_keys, shape_mismatches, missing_keys
    )

    meta = _distillation_meta(
        args,
        text_idx,
        vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
        stream_to_disk=True,
    )
    meta.update(
        {
            "weight_map": weight_map,
            "total_parameters": total_params,
            "total_size_bytes": total_size_bytes,
        }
    )
    _write_meta_and_readme(args, meta, len(vision_idx))

    print(
        "[done] Student weights streamed to disk. You can now load from output-dir without ever "
        "holding the full student in RAM during conversion."
    )


def _run_in_memory(
    args: argparse.Namespace,
    torch_dtype: Optional[torch.dtype],
    max_shard_size_bytes: int,
) -> None:
    """Build the full student in RAM, then save it with ``save_pretrained``."""
    (
        student,
        text_idx,
        vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
    ) = build_student_from_teacher(
        teacher_repo=args.teacher_repo,
        teacher_subfolder=args.teacher_subfolder,
        num_text_layers=args.num_text_layers,
        torch_dtype=torch_dtype,
    )

    _print_copy_summary(
        pruned_text_keys, skipped_teacher_keys, shape_mismatches, missing_keys
    )

    print(f"[info] Saving student model to '{args.output_dir}'...")
    os.makedirs(args.output_dir, exist_ok=True)
    student.save_pretrained(
        args.output_dir,
        safe_serialization=True,
        max_shard_size=max_shard_size_bytes,
    )

    meta = _distillation_meta(
        args,
        text_idx,
        vision_idx,
        missing_keys,
        skipped_teacher_keys,
        pruned_text_keys,
        shape_mismatches,
        checkpoint_index_filename,
        stream_to_disk=False,
    )
    _write_meta_and_readme(args, meta, len(vision_idx))

    print(
        "[done] Student text_encoder is ready. Next step: distillation loss & training loop."
    )


def main():
    """Parse CLI options and run either the stream-to-disk or the in-memory conversion."""
    args = parse_args()

    # "auto" (or anything unmapped) -> None: keep the teacher's dtype.
    torch_dtype: Optional[torch.dtype] = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }.get(args.dtype)

    print(
        f"[info] Streaming teacher Mistral3 text encoder from "
        f"{args.teacher_repo} (subfolder='{args.teacher_subfolder}') "
        f"directly into the student..."
    )

    # One byte count (GiB-based) for both paths, so --max-shard-size-gb means
    # the same thing whether or not --stream-to-disk is used.
    max_shard_size_bytes = int(args.max_shard_size_gb * (1024**3))

    if args.stream_to_disk:
        _run_stream_to_disk(args, torch_dtype, max_shard_size_bytes)
    else:
        _run_in_memory(args, torch_dtype, max_shard_size_bytes)
|
|
|
|
# Script entry point: run the conversion only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|