Instructions to use aydin99/FLUX.2-klein-4B-int8 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use aydin99/FLUX.2-klein-4B-int8 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("aydin99/FLUX.2-klein-4B-int8", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
- Local Apps
- Draw Things
- DiffusionBee
| # Copyright 2024 The HuggingFace Team. All rights reserved. | |
| # Adapted for FLUX.2-klein by adding Flux2Transformer2DModel and Qwen3 support | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, List, Optional, Union | |
| from huggingface_hub import ModelHubMixin, snapshot_download | |
| from optimum.quanto import freeze, qtype, quantization_map, quantize, requantize, Optimizer | |
| from optimum.quanto.models import is_diffusers_available | |
| from diffusers.models.model_loading_utils import load_state_dict | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.utils import ( | |
| CONFIG_NAME, | |
| SAFE_WEIGHTS_INDEX_NAME, | |
| SAFETENSORS_WEIGHTS_NAME, | |
| _get_checkpoint_shard_files, | |
| is_accelerate_available, | |
| ) | |
| from optimum.quanto.models.shared_dict import ShardedStateDict | |
| class QuantizedDiffusersModel(ModelHubMixin): | |
| """Base class for quantized diffusers models.""" | |
| BASE_NAME = "quanto" | |
| base_class = None | |
| def __init__(self, model: ModelMixin): | |
| if not isinstance(model, ModelMixin) or len(quantization_map(model)) == 0: | |
| raise ValueError("The source model must be a quantized diffusers model.") | |
| self._wrapped = model | |
| def __getattr__(self, name: str) -> Any: | |
| """If an attribute is not found in this class, look in the wrapped module.""" | |
| try: | |
| return super().__getattr__(name) | |
| except AttributeError: | |
| wrapped = self.__dict__["_wrapped"] | |
| return getattr(wrapped, name) | |
| def forward(self, *args, **kwargs): | |
| return self._wrapped.forward(*args, **kwargs) | |
| def __call__(self, *args, **kwargs): | |
| return self._wrapped.forward(*args, **kwargs) | |
| def _qmap_name(): | |
| return f"{QuantizedDiffusersModel.BASE_NAME}_qmap.json" | |
| def quantize( | |
| cls, | |
| model: ModelMixin, | |
| weights: Optional[Union[str, qtype]] = None, | |
| activations: Optional[Union[str, qtype]] = None, | |
| optimizer: Optional[Optimizer] = None, | |
| include: Optional[Union[str, List[str]]] = None, | |
| exclude: Optional[Union[str, List[str]]] = None, | |
| ): | |
| """Quantize the specified model.""" | |
| if not isinstance(model, ModelMixin): | |
| raise ValueError("The source model must be a diffusers model.") | |
| quantize( | |
| model, weights=weights, activations=activations, optimizer=optimizer, include=include, exclude=exclude | |
| ) | |
| freeze(model) | |
| return cls(model) | |
| def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): | |
| if cls.base_class is None: | |
| raise ValueError("The `base_class` attribute needs to be configured.") | |
| if not is_accelerate_available(): | |
| raise ValueError("Reloading a quantized diffusers model requires the accelerate library.") | |
| from accelerate import init_empty_weights | |
| if os.path.isdir(pretrained_model_name_or_path): | |
| working_dir = pretrained_model_name_or_path | |
| else: | |
| working_dir = snapshot_download(pretrained_model_name_or_path, **kwargs) | |
| # Look for a quantization map | |
| qmap_path = os.path.join(working_dir, cls._qmap_name()) | |
| if not os.path.exists(qmap_path): | |
| raise ValueError( | |
| f"No quantization map found in {pretrained_model_name_or_path}: is this a quantized model ?" | |
| ) | |
| # Look for original model config file. | |
| model_config_path = os.path.join(working_dir, CONFIG_NAME) | |
| if not os.path.exists(model_config_path): | |
| raise ValueError(f"{CONFIG_NAME} not found in {pretrained_model_name_or_path}.") | |
| with open(qmap_path, "r", encoding="utf-8") as f: | |
| qmap = json.load(f) | |
| with open(model_config_path, "r", encoding="utf-8") as f: | |
| original_model_cls_name = json.load(f)["_class_name"] | |
| configured_cls_name = cls.base_class.__name__ | |
| if configured_cls_name != original_model_cls_name: | |
| raise ValueError( | |
| f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})." | |
| ) | |
| # Create an empty model | |
| config = cls.base_class.load_config(pretrained_model_name_or_path, **kwargs) | |
| with init_empty_weights(): | |
| model = cls.base_class.from_config(config) | |
| # Look for the index of a sharded checkpoint | |
| checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_INDEX_NAME) | |
| if os.path.exists(checkpoint_file): | |
| # Convert the checkpoint path to a list of shards | |
| _, sharded_metadata = _get_checkpoint_shard_files(working_dir, checkpoint_file) | |
| # Create a mapping for the sharded safetensor files | |
| state_dict = ShardedStateDict(working_dir, sharded_metadata["weight_map"]) | |
| else: | |
| # Look for a single checkpoint file | |
| checkpoint_file = os.path.join(working_dir, SAFETENSORS_WEIGHTS_NAME) | |
| if not os.path.exists(checkpoint_file): | |
| raise ValueError(f"No safetensor weights found in {pretrained_model_name_or_path}.") | |
| # Get state_dict from model checkpoint | |
| state_dict = load_state_dict(checkpoint_file) | |
| # Requantize and load quantized weights from state_dict | |
| requantize(model, state_dict=state_dict, quantization_map=qmap) | |
| model.eval() | |
| return cls(model) | |
| def _save_pretrained(self, save_directory: Path) -> None: | |
| self._wrapped.save_pretrained(save_directory) | |
| # Save quantization map to be able to reload the model | |
| qmap_name = os.path.join(save_directory, self._qmap_name()) | |
| qmap = quantization_map(self._wrapped) | |
| with open(qmap_name, "w", encoding="utf8") as f: | |
| json.dump(qmap, f, indent=4) | |
| # Import Flux2Transformer2DModel | |
| from diffusers.models.transformers.transformer_flux2 import Flux2Transformer2DModel | |
| class QuantizedFlux2Transformer2DModel(QuantizedDiffusersModel): | |
| """Quantized FLUX.2 Transformer model.""" | |
| base_class = Flux2Transformer2DModel | |