Instructions to use wikeeyang/Hunyuan-Image-30-Qint4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use wikeeyang/Hunyuan-Image-30-Qint4 with Transformers:
# Load model directly from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("wikeeyang/Hunyuan-Image-30-Qint4", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| import warnings | |
| import random | |
| from typing import List, Optional, Union, Dict, Any | |
| from collections import defaultdict | |
| from copy import deepcopy | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from diffusers.utils import BaseOutput | |
| def default(value, default_value): | |
| return value if value is not None else default_value | |
| def ensure_list(value): | |
| if value is None: | |
| return [] | |
| if isinstance(value, (list, tuple)): | |
| return list(value) | |
| return [value] | |
| class Resolution(object): | |
| def __init__(self, size, *args): | |
| if isinstance(size, str): | |
| if 'x' in size: | |
| size = size.split('x') | |
| size = (int(size[0]), int(size[1])) | |
| else: | |
| size = int(size) | |
| if len(args) > 0: | |
| size = (size, args[0]) | |
| if isinstance(size, int): | |
| size = (size, size) | |
| self.h = self.height = size[0] | |
| self.w = self.width = size[1] | |
| self.r = self.ratio = self.height / self.width | |
| def __getitem__(self, idx): | |
| if idx == 0: | |
| return self.h | |
| elif idx == 1: | |
| return self.w | |
| else: | |
| raise IndexError(f'Index {idx} out of range') | |
| def __str__(self): | |
| return f'{self.h}x{self.w}' | |
| class ResolutionGroup(object): | |
| def __init__(self, base_size=None, step=None, align=1): | |
| self.align = align | |
| self.base_size = base_size | |
| assert base_size % align == 0, f'base_size {base_size} is not divisible by align {align}' | |
| if base_size is not None and not isinstance(base_size, int): | |
| raise ValueError(f'base_size must be None or int, but got {type(base_size)}') | |
| if step is None: | |
| step = base_size // 16 | |
| if step is not None and step > base_size // 2: | |
| raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}') | |
| self.step = step | |
| self.data = self._calc_by_step() | |
| self.ratio = np.array([x.ratio for x in self.data]) | |
| self.attr = ['' for _ in range(len(self.data))] | |
| self.prefix_space = 0 | |
| def __len__(self): | |
| return len(self.data) | |
| def __getitem__(self, idx): | |
| return self.data[idx] | |
| def __repr__(self): | |
| prefix = self.prefix_space * ' ' | |
| prefix_close = (self.prefix_space - 4) * ' ' | |
| res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data=' | |
| attr_maxlen = max([len(x) for x in self.attr] + [5]) | |
| res_str += \ | |
| f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}' | |
| res_str += \ | |
| ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} ' | |
| f'({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}' | |
| for i, x in enumerate(self.data)]) | |
| res_str += f'\n{prefix_close})' | |
| return res_str | |
| def _calc_by_step(self): | |
| assert self.align <= self.step, f'align {self.align} must be smaller than step {self.step}' | |
| min_height = self.base_size // 2 | |
| min_width = self.base_size // 2 | |
| max_height = self.base_size * 2 | |
| max_width = self.base_size * 2 | |
| resolutions = [Resolution(self.base_size, self.base_size)] | |
| cur_height, cur_width = self.base_size, self.base_size | |
| while True: | |
| if cur_height >= max_height and cur_width <= min_width: | |
| break | |
| cur_height = min(cur_height + self.step, max_height) | |
| cur_width = max(cur_width - self.step, min_width) | |
| resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) | |
| cur_height, cur_width = self.base_size, self.base_size | |
| while True: | |
| if cur_height <= min_height and cur_width >= max_width: | |
| break | |
| cur_height = max(cur_height - self.step, min_height) | |
| cur_width = min(cur_width + self.step, max_width) | |
| resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) | |
| resolutions = sorted(resolutions, key=lambda x: x.ratio) | |
| return resolutions | |
| def get_target_size(self, width, height): | |
| ratio = height / width | |
| idx = np.argmin(np.abs(self.ratio - ratio)) | |
| reso = self.data[idx] | |
| return reso.w, reso.h | |
| def get_base_size_and_ratio_index(self, width, height): | |
| ratio = height / width | |
| idx = np.argmin(np.abs(self.ratio - ratio)) | |
| return self.base_size, idx | |
| class ImageInfo: | |
| """ Class to store image information for processing and generation. """ | |
| def __init__( | |
| self, | |
| image_type: str = None, | |
| image_tensor: torch.Tensor = None, | |
| image_width: int = None, | |
| image_height: int = None, | |
| token_width: int = None, | |
| token_height: int = None, | |
| image_token_length: int = None, | |
| base_size: int = None, | |
| ratio_index: int = None, | |
| **kwargs, | |
| ): | |
| self.image_type = image_type | |
| self.image_tensor = image_tensor | |
| self.image_width = image_width | |
| self.w = image_width | |
| self.image_height = image_height | |
| self.h = image_height | |
| self.token_width = token_width | |
| self.tk_w = token_width | |
| self.token_height = token_height | |
| self.tk_h = token_height | |
| self.image_token_length = default( | |
| image_token_length, | |
| token_width * token_height if token_width is not None and token_height is not None else None | |
| ) | |
| self.base_size = base_size | |
| self.ratio_index = ratio_index | |
| self.add_timestep_token = kwargs.get("add_timestep_token", True) | |
| self.add_guidance_token = kwargs.get("add_guidance_token", False) | |
| self.use_front_boi_token = kwargs.get("use_front_boi_token", True) | |
| self.add_image_shape_token = kwargs.get("add_image_shape_token", True) | |
| def __getitem__(self, key: str) -> Any: | |
| """Allow dictionary-like access to attributes.""" | |
| if hasattr(self, key): | |
| return getattr(self, key) | |
| raise KeyError(f"Key '{key}' not found in ImageInfo") | |
| def __setitem__(self, key: str, value: Any) -> None: | |
| """Allow dictionary-like assignment to attributes.""" | |
| if hasattr(self, key): | |
| setattr(self, key, value) | |
| else: | |
| raise KeyError(f"Key '{key}' not found in ImageInfo") | |
| def __contains__(self, key: str) -> bool: | |
| """Check if the key exists in the ImageInfo object.""" | |
| return hasattr(self, key) | |
| def __repr__(self): | |
| return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, " | |
| f"image_width={self.image_width}, image_height={self.image_height}, " | |
| f"token_width={self.token_width}, token_height={self.token_height}, " | |
| f"image_token_length={self.image_token_length}, " | |
| f"base_size={self.base_size}, ratio_index={self.ratio_index}") | |
| def meta_info(self): | |
| # Used for image sections of tkwrapper.encode_general() | |
| if self.image_type in ["vae", "gen_image"]: | |
| return dict( | |
| token_length=self.image_token_length, | |
| add_timestep_token=self.add_timestep_token, | |
| add_guidance_token=self.add_guidance_token, | |
| use_front_boi_token=self.use_front_boi_token, | |
| add_image_shape_token=self.add_image_shape_token, | |
| base_size=self.base_size, | |
| ratio_idx=self.ratio_index, | |
| # for rope 2d | |
| token_height=self.token_height, | |
| token_width=self.token_width, | |
| # for bc | |
| image_height=self.image_height, | |
| image_width=self.image_width, | |
| ) | |
| elif self.image_type in ["vit"]: | |
| return dict( | |
| token_length=self.image_token_length, | |
| use_front_boi_token=self.use_front_boi_token, | |
| add_image_shape_token=self.add_image_shape_token, | |
| # for rope 2d | |
| token_height=self.token_height, | |
| token_width=self.token_width, | |
| # for bc | |
| image_height=self.image_height, | |
| image_width=self.image_width, | |
| ) | |
| else: | |
| raise ValueError(f"Unknown image type '{self.image_type}'") | |
| def num_special_tokens(self): | |
| if self.args is None: | |
| raise ValueError("meta_info requires `args` attribute to be set.") | |
| if self.image_type in ["vae", "src_image", "gen_image"]: | |
| count = ( | |
| 2 + # <boi> + <eoi> or <src_boi> + <src_eoi> | |
| (1 if self.add_timestep_token else 0) + | |
| (1 if self.add_guidance_token else 0) + | |
| (2 if self.add_image_shape_token else 0) | |
| ) | |
| else: | |
| raise ValueError(f"Unknown image_type: {self.image_type}") | |
| return count | |
| def copy(self, copy_image_tensor=True): | |
| if copy_image_tensor and self.image_tensor is None: | |
| raise ValueError("image_tensor is None, cannot copy") | |
| return ImageInfo( | |
| image_type=self.image_type, | |
| image_tensor=self.image_tensor.clone() if copy_image_tensor else None, | |
| image_width=self.image_width, | |
| image_height=self.image_height, | |
| token_width=self.token_width, | |
| token_height=self.token_height, | |
| image_token_length=self.image_token_length, | |
| base_size=self.base_size, | |
| ratio_index=self.ratio_index, | |
| ) | |
| def zeros_(self): | |
| self.image_tensor = torch.zeros_like(self.image_tensor) | |
| class ImageTensor(torch.Tensor): | |
| # This class is just for type hinting purposes. Attribute `i` should be defined | |
| # as an instance attribute of the torch.Tensor instance, like: tensor.i = ImageInfo(...) | |
| i: ImageInfo | |
| vision_encoder_kwargs: dict | |
| class JointImageInfo(object): | |
| def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None): | |
| self.vae_image_info = vae_image_info | |
| self.vision_image_info = vision_image_info | |
| self.vision_encoder_kwargs = vision_encoder_kwargs | |
| # Define key attributes to align with ImageInfo for uniformity | |
| self.image_type = "joint_image" | |
| self.image_token_length = vae_image_info.image_token_length + vision_image_info.image_token_length | |
| self.add_timestep_token = vae_image_info.add_timestep_token | |
| self.use_front_boi_token = vae_image_info.use_front_boi_token | |
| self.add_image_shape_token = vae_image_info.add_image_shape_token | |
| def __repr__(self): | |
| return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})" | |
| def meta_info(self): | |
| # Used for image sections of tkwrapper.encode_general() | |
| return dict( | |
| token_length=[self.vae_image_info.image_token_length, self.vision_image_info.image_token_length], | |
| add_timestep_token=self.add_timestep_token, | |
| use_front_boi_token=self.use_front_boi_token, | |
| add_image_shape_token=self.add_image_shape_token, | |
| base_size=self.vae_image_info.base_size, | |
| ratio_idx=self.vae_image_info.ratio_index, | |
| # for rope 2d | |
| token_height=[self.vae_image_info.token_height, self.vision_image_info.token_height], | |
| token_width=[self.vae_image_info.token_width, self.vision_image_info.token_width], | |
| # for bc | |
| image_height=[self.vae_image_info.image_height, self.vision_image_info.image_height], | |
| image_width=[self.vae_image_info.image_width, self.vision_image_info.image_width], | |
| ) | |
| def num_special_tokens(self): | |
| return ( | |
| 2 + # <boi> + <eoi> | |
| (1 if self.add_timestep_token else 0) + | |
| (2 if self.add_image_shape_token else 0) + | |
| 1 # <joint_image_sep> | |
| ) | |
| def copy(self, copy_image_tensor=True): | |
| if copy_image_tensor and ( | |
| self.vae_image_info.image_tensor is None or self.vision_image_info.image_tensor is None): | |
| raise ValueError("image_tensor is None, cannot copy") | |
| return JointImageInfo( | |
| self.vae_image_info.copy(copy_image_tensor), | |
| self.vision_image_info.copy(copy_image_tensor), | |
| self.vision_encoder_kwargs, | |
| ) | |
| def zeros_(self): | |
| self.vae_image_info.zeros_() | |
| self.vision_image_info.zeros_() | |
| class JointImage(object): | |
| def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor): | |
| self.vae_image = vae_image | |
| self.vision_image = vision_image | |
| self.i = JointImageInfo(vae_image.i, vision_image.i) | |
| class TokenizerEncodeOutput(BaseOutput): | |
| tokens: torch.Tensor = None | |
| timestep_scatter_index: Optional[torch.Tensor] = None | |
| guidance_scatter_index: Optional[torch.Tensor] = None | |
| text_slices: Optional[List[slice]] = None | |
| gen_image_slices: Optional[List[slice]] = None | |
| joint_image_slices: Optional[List[slice]] = None | |
| cond_vae_image_slices: Optional[List[slice]] = None | |
| cond_vit_image_slices: Optional[List[slice]] = None | |
| text_mask: Optional[torch.Tensor] = None | |
| gen_image_mask: Optional[torch.Tensor] = None | |
| cond_vae_image_mask: Optional[torch.Tensor] = None | |
| cond_vit_image_mask: Optional[torch.Tensor] = None | |
| real_pos: Optional[torch.Tensor] = None | |
| all_image_slices: Optional[List[slice]] = None | |
| cond_timestep_scatter_index: Optional[torch.Tensor] = None | |
| gen_timestep_scatter_index: Optional[torch.Tensor] = None | |
| class Conversation: | |
| roles: List[str] = ["User", "Assistant"] | |
| sep: str = "\n\n" | |
| class TokenizerWrapper(object): | |
| def __init__(self, tokenizer): | |
| if isinstance(tokenizer, str): | |
| self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) | |
| else: | |
| self.tokenizer = tokenizer | |
| # Define short names | |
| self.bos_token_id = self.tokenizer.bos_token_id | |
| self.eos_token_id = self.tokenizer.eos_token_id | |
| self.pad_token_id = self.tokenizer.pad_token_id | |
| self.boi_token_id = self.tokenizer.convert_tokens_to_ids("<boi>") | |
| self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("<eoi>") | |
| self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>") | |
| self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("<cfg>") | |
| self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("</answer>") | |
| self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("</recaption>") | |
| self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("<img_ratio_0>") | |
| self.special_token_map = self.tokenizer.added_tokens_encoder | |
| def pad(self, tensor_list, dim=0, pad_val=None): | |
| if pad_val is None: | |
| pad_val = self.pad_token_id | |
| max_len = max([t.shape[dim] for t in tensor_list]) | |
| padded_tensor_list = [] | |
| for t in tensor_list: | |
| if t.shape[dim] < max_len: | |
| assert pad_val is not False, "Not allowed pad." | |
| t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val) | |
| padded_tensor_list.append(t) | |
| return padded_tensor_list | |
| def encode(self, *args, **kwargs): | |
| return self.tokenizer.encode(*args, **kwargs) | |
| def decode(self, *args, **kwargs): | |
| return self.tokenizer.decode(*args, **kwargs) | |
| def encode_text( | |
| self, | |
| *texts, | |
| uncond_enabled: Optional[Union[bool, List[bool]]] = None, | |
| uncond_p: Optional[float] = None, | |
| max_length: Optional[int] = None, | |
| pad: Optional[str] = None, | |
| return_lengths: bool = False, | |
| ): | |
| """ | |
| Encode text and image for AR-like model training of the text-to-image/instruction tuning tasks. | |
| Support encode multiple texts at once. Each text can be separately conditioned or unconditioned | |
| based on the uncond_flags and a uniform uncond_p. | |
| **<bos> token is always prepended to the text tokens.** | |
| Parameters | |
| ---------- | |
| texts: str or List[str] | |
| List of texts to be encoded. | |
| uncond_enabled: bool or List[bool] | |
| List of flags to indicate whether the text should be unconditioned. | |
| If False, the text will never be unconditioned. | |
| If True, the text will be unconditioned with uncond_p. | |
| uncond_p: float | |
| Probability to the unconditional text. Only works when uncond_enabled is True. | |
| max_length: int | |
| Maximum length of the encoded text. | |
| pad: Optional[str] | |
| Padding method. Can be 'left' or 'right'. | |
| return_lengths: bool | |
| Whether to return the length of each encoded text. | |
| """ | |
| if pad is not None: | |
| assert max_length is not None, "max_length should be provided when pad is not None." | |
| if uncond_enabled is None: | |
| uncond_enabled = [True] * len(texts) | |
| elif isinstance(uncond_enabled, bool): | |
| uncond_enabled = [uncond_enabled] * len(texts) | |
| if len(uncond_enabled) != len(texts): | |
| print(uncond_enabled, texts) | |
| assert len(uncond_enabled) == len(texts), ( | |
| f"Length of uncond_flags should be equal to the number of texts, " | |
| f"but got {len(uncond_enabled)} and {len(texts)}." | |
| ) | |
| # Prepare text/uncond tokens | |
| # TODO: If len(texts) > 1, such as instruction + prompt in inpainting, we need to determine how to do uncond. | |
| # Now all texts will be cond or uncond at the same time. | |
| do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p) | |
| text_tokens, lengths = [], [] | |
| cum_length = 0 | |
| for text, uncond_flag in zip(texts, uncond_enabled): | |
| # If reach the max_length and there still have unencoded texts, give a warning message and break the loop. | |
| if max_length is not None and cum_length >= max_length: | |
| warnings.warn( | |
| f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: " | |
| f"{text[:80]}..." | |
| ) | |
| break | |
| # Set add_special_tokens=False to avoid adding <bos> token in some LLMs. | |
| if isinstance(text, str): | |
| text_token = self.tokenizer.encode(text, add_special_tokens=False) | |
| else: | |
| text_token = text | |
| if uncond_flag and do_uncond_drop: | |
| text_token = [self.cfg_token_id] * len(text_token) | |
| # Cutoff the text by max_length if necessary | |
| if max_length is not None and (cum_length + len(text_token)) > max_length: | |
| text_token = text_token[:max_length - cum_length] | |
| text_tokens.extend(text_token) | |
| lengths.append(len(text_token)) | |
| cum_length += len(text_token) | |
| # Prepend/Append <pad> tokens if applicable | |
| if pad is not None and (pad_length := max_length - len(text_tokens)) > 0: | |
| if pad == 'left': | |
| text_tokens = [self.pad_token_id] * pad_length + text_tokens | |
| elif pad == 'right': | |
| text_tokens = text_tokens + [self.pad_token_id] * pad_length | |
| else: | |
| raise ValueError(f"Unsupported padding method: {pad}.") | |
| if return_lengths: | |
| return text_tokens, lengths | |
| return text_tokens | |
| def _check_key_number_matched(keys, data): | |
| # Assert keys and token_source are matched | |
| assert set(keys) == set(data.keys()), ( | |
| f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}." | |
| ) | |
| key_counts = {k: 0 for k in keys} | |
| for key in keys: | |
| key_counts[key] += 1 | |
| for key, count in key_counts.items(): | |
| assert len(data[key]) == count, ( | |
| f"Number of `{key}` in the token source should be matched with the template, but got " | |
| f"{data[key]}({len(data[key])}) and {count}." | |
| ) | |
| def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False, | |
| add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None, | |
| add_guidance_token=False): | |
| if add_image_shape_token: | |
| token_seq.extend([ | |
| self.special_token_map[f"<img_size_{base_size}>"], | |
| self.special_token_map[f"<img_ratio_{ratio_idx}>"] | |
| ]) | |
| token_count += 2 | |
| if add_timestep_token: | |
| token_seq.extend([self.special_token_map["<timestep>"]]) | |
| extra_token_pos['timestep'].append(token_count) | |
| if image_type is not None: | |
| if image_type == "gen_image": | |
| extra_token_pos['gen_timestep'].append(token_count) | |
| elif image_type in ["joint_image"]: | |
| extra_token_pos['cond_timestep'].append(token_count) | |
| else: | |
| raise ValueError(f"Unsupported image type: {image_type}.") | |
| token_count += 1 | |
| if add_guidance_token: | |
| token_seq.extend([self.special_token_map["<guidance>"]]) | |
| extra_token_pos['guidance'].append(token_count) | |
| token_count += 1 | |
| return token_count | |
| def _shorten_text(text): | |
| import re | |
| text = re.sub(r"(<img>)+", lambda m: f"[<img>]{{{len(m.group(0)) // 5}}}", text) | |
| text = re.sub(r"(<pad>)+", lambda m: f"[<pad>]{{{len(m.group(0)) // 5}}}", text) | |
| return text | |
| def encode_sequence( | |
| self, | |
| template: str, | |
| token_source: Dict[str, List], | |
| total_length=None, | |
| add_timestep_token=False, | |
| add_guidance_token=False, | |
| last_key_only_prefix=False, | |
| add_eos=True, | |
| use_front_boi_token=True, | |
| add_pad=True, | |
| add_bos=True, | |
| drop_last: Union[str, bool] = 'auto', | |
| add_image_shape_token=False, | |
| ): | |
| """ | |
| Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning) | |
| and token source. | |
| Parameters | |
| ---------- | |
| template: str | |
| Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image. | |
| "text-text-gen_image" means the sequence is composed of two sections of text and an image. | |
| token_source: Dict[str, List] | |
| Token source for each key in the template, in order. | |
| - text: List[Dict]. | |
| - gen_image: List[Dict]. | |
| - joint_image: List[Dict]. | |
| total_length: int | |
| Total length of the encoded sequence, include padding tokens. | |
| add_timestep_token: bool | |
| Whether to add timestep token before the image tokens. | |
| (Right after the <img_ratio_*><img_size_*> tokens) | |
| add_guidance_token: bool | |
| Whether to add guidance token before the image tokens. | |
| last_key_only_prefix: bool | |
| Whether to only use the modal prefix in the last key. | |
| add_eos: bool or 'auto' | |
| Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto', | |
| add eos token only when the total_length is not reached and the last token is not <eos>. | |
| use_front_boi_token: bool: | |
| Whether to put the <boi> token at the front of iw, ih and timestep tokens. | |
| add_pad: bool or 'auto' | |
| Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens. | |
| add_bos: bool | |
| Whether to add bos token at the beginning of the sequence. | |
| drop_last: bool or 'auto' | |
| - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is | |
| in the middle of the image tokens, an error will raised. | |
| - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens, | |
| all the successive image tokens will be dropped. | |
| - If False, keep the last tokens exceeding the total_length, even if the total_length is reached. | |
| add_image_shape_token: bool | |
| Whether to add image shape token before the image tokens. (Right before the <timestep> token) | |
| Returns | |
| ------- | |
| token_seq: list | |
| Encoded token sequence. | |
| extra_token_pos: dict | |
| Positions of extra tokens. | |
| """ | |
| if last_key_only_prefix: | |
| assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True." | |
| if drop_last is True and total_length is None: | |
| raise ValueError("total_length should be provided when drop_last is True.") | |
| keys = template.split('-') | |
| modal_length = len(keys) | |
| index_indicator = {k: 0 for k in token_source} | |
| for k, v in token_source.items(): | |
| assert isinstance(v, (list, tuple)), ( | |
| f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}." | |
| ) | |
| self._check_key_number_matched(keys, token_source) | |
| token_seq = [] | |
| token_count = 0 | |
| extra_token_pos = defaultdict(list) | |
| if add_bos: | |
| token_seq.append(self.bos_token_id) | |
| token_count += 1 | |
| # If drop_last is True, we check the token_count on the fly and exit the loop if the total_length is reached. | |
| # This check is only applied to the block tokens. Block tokens mean the tokens that are unsplittable, like | |
| # image tokens. Text tokens are splittable, so we don't need to check the token_count for text. | |
| # If the loop is broken by drop_last, we don't add the eos token at the end because the sequence is not | |
| # complete. | |
| drop_last_break = False | |
| for i, key in enumerate(keys): | |
| source = token_source[key][index_indicator[key]] | |
| if key == "text": | |
| token_seq.extend(source) # text token sequence | |
| extra_token_pos["<text>_start"].append(token_count) | |
| token_count += len(source) | |
| extra_token_pos["<text>_end"].append(token_count - 1) | |
| elif key == "gen_image": | |
| if isinstance(source, int): | |
| source = {'length': source} | |
| extra_count = 2 + ( | |
| 1 if source.get('timestep', add_timestep_token) else 0) + ( | |
| 1 if source.get('guidance', add_guidance_token) else 0) + ( | |
| 2 if source.get('image_shape', add_image_shape_token) else 0 | |
| ) | |
| if drop_last is True and token_count + extra_count + source['length'] > total_length: | |
| drop_last_break = True | |
| break | |
| if source.get('front_boi', use_front_boi_token): | |
| token_seq.append(self.boi_token_id) | |
| extra_token_pos["boi"].append(token_count) | |
| token_count += 1 | |
| token_count = self._add_image_meta_info_token( | |
| token_seq=token_seq, | |
| token_count=token_count, | |
| extra_token_pos=extra_token_pos, | |
| add_timestep_token=source.get('timestep', add_timestep_token), | |
| add_guidance_token=source.get('guidance', add_guidance_token), | |
| add_image_shape_token=source.get('image_shape', add_image_shape_token), | |
| base_size=source.get('base_size'), | |
| ratio_idx=source.get('ratio_idx'), | |
| image_type=key, | |
| ) | |
| if not source.get('front_boi', use_front_boi_token): | |
| token_seq.append(self.boi_token_id) | |
| extra_token_pos["boi"].append(token_count) | |
| token_count += 1 | |
| if last_key_only_prefix and i == modal_length - 1: | |
| pass # for AR inference | |
| else: | |
| token_seq.extend( | |
| [self.img_token_id] * source['length'] + # token number | |
| [self.eoi_token_id] | |
| ) | |
| extra_token_pos["<img>_start"].append(token_count) | |
| extra_token_pos["<all_img>_start"].append(token_count) | |
| token_count += source['length'] | |
| extra_token_pos["<img>_end"].append(token_count - 1) | |
| extra_token_pos["<all_img>_end"].append(token_count - 1) | |
| extra_token_pos["eoi"].append(token_count) | |
| token_count += 1 # <eoi> | |
| elif key == "joint_image": | |
| assert isinstance(source['length'], list) and len( | |
| source['length']) == 2, "joint_image length should be a list of two integers" | |
| extra_count = 2 + 1 + ( # boi, eoi, joint_img_sep | |
| 1 if source.get('timestep', add_timestep_token) else 0) + ( | |
| 2 if source.get('image_shape', add_image_shape_token) else 0 | |
| ) | |
| if drop_last is True and token_count + extra_count + sum(source['length']) > total_length: | |
| drop_last_break = True | |
| break | |
| if source.get('front_boi', use_front_boi_token): | |
| token_seq.append(self.boi_token_id) # Use patched boi for Janus, otherwise useing default <boi> | |
| extra_token_pos["boi"].append(token_count) | |
| token_count += 1 | |
| token_count = self._add_image_meta_info_token( | |
| token_seq=token_seq, | |
| token_count=token_count, | |
| extra_token_pos=extra_token_pos, | |
| add_timestep_token=source.get('timestep', add_timestep_token), | |
| add_image_shape_token=source.get('image_shape', add_image_shape_token), | |
| base_size=source.get('base_size'), | |
| ratio_idx=source.get('ratio_idx'), | |
| image_type=key, | |
| ) | |
| if not source.get('front_boi', use_front_boi_token): | |
| token_seq.append(self.boi_token_id) | |
| extra_token_pos["boi"].append(token_count) | |
| token_count += 1 | |
| if last_key_only_prefix and i == modal_length - 1: | |
| pass # for AR inference | |
| else: | |
| token_seq.extend( | |
| [self.img_token_id] * source['length'][0] | |
| ) | |
| extra_token_pos["<vae_img>_start"].append(token_count) | |
| extra_token_pos["<joint_img>_start"].append(token_count) | |
| extra_token_pos["<all_img>_start"].append(token_count) | |
| token_count += source['length'][0] | |
| extra_token_pos["<vae_img>_end"].append(token_count - 1) | |
| extra_token_pos["<all_img>_end"].append(token_count - 1) | |
| token_seq.extend( | |
| [self.special_token_map["<joint_img_sep>"]] | |
| ) | |
| extra_token_pos["joint_img_sep"].append(token_count) | |
| token_count += 1 | |
| token_seq.extend( | |
| [self.img_token_id] * source['length'][1] | |
| ) | |
| extra_token_pos["<vit_img>_start"].append(token_count) | |
| extra_token_pos["<all_img>_start"].append(token_count) | |
| token_count += source['length'][1] | |
| extra_token_pos["<vit_img>_end"].append(token_count - 1) | |
| extra_token_pos["<joint_img>_end"].append(token_count - 1) | |
| extra_token_pos["<all_img>_end"].append(token_count - 1) | |
| token_seq.extend( | |
| [self.eoi_token_id] | |
| ) | |
| extra_token_pos["eoi"].append(token_count) | |
| token_count += 1 # <eoi> | |
| else: | |
| raise ValueError(f"Not supported key: {key}") | |
| index_indicator[key] += 1 | |
| if add_eos is True and not drop_last_break: | |
| # Typically used for t2i task. | |
| token_seq.append(self.eos_token_id) | |
| extra_token_pos["eos"].append(token_count) | |
| token_count += 1 | |
| elif add_eos == 'auto' and not drop_last_break: | |
| # Typically used for lm and mmu task. | |
| if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length): | |
| token_seq.append(self.eos_token_id) | |
| extra_token_pos["eos"].append(token_count) | |
| token_count += 1 | |
| if total_length: | |
| # Check token count and clip sequence if necessary | |
| if token_count > total_length and drop_last: | |
| # Assert clip position is not in the middle of the block-wise tokens (gen_image, joint_image) | |
| for start_key, end_key in [ | |
| ("<img>_start", "<img>_end"), ("<joint_img>_start", "<joint_img>_end"), | |
| ("<vae_img>_start", "<vae_img>_end"), ("<vit_img>_start", "<vit_img>_end"), | |
| ]: | |
| if start_key in extra_token_pos and end_key in extra_token_pos: | |
| assert all( | |
| (start > total_length or end + 1 < total_length) | |
| for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key]) | |
| ), ("Clip position should not be in the middle of the image tokens.\n" | |
| f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}") | |
| token_seq = token_seq[:total_length] | |
| # Pad the sequence if necessary | |
| pad_num = max(0, total_length - len(token_seq)) | |
| if add_pad and pad_num: | |
| token_seq.extend([self.pad_token_id] * pad_num) | |
| extra_token_pos["first_pad"].append(token_count) | |
| return token_seq, extra_token_pos | |
| def batch_gen_infer( | |
| self, | |
| infer_fn, | |
| prompt_list: list, | |
| negative_prompt_list: list = None, | |
| infer_fn_kwargs_list: List[Dict[str, int]] = None, | |
| do_classifier_free_guidance=False, | |
| condition_repeat_times: int = 1, | |
| uncondition_repeat_times: int = 1, | |
| ): | |
| """ | |
| Batch inference for the AR-like model training of the text-to-image/instruction tuning tasks. | |
| Parameters | |
| ---------- | |
| infer_fn: callable | |
| Inference function to encode the prompt. | |
| prompt_list: list | |
| List of prompts. Each element can be a single prompt or a list of prompts passed to the infer_fn. | |
| negative_prompt_list: list | |
| List of negative prompts. Only used when do_classifier_free_guidance is True. If None, will use <cfg> | |
| token sequence as negative prompt. | |
| infer_fn_kwargs_list: List[Dict[str, int]] | |
| List of keyword arguments for the infer_fn. | |
| do_classifier_free_guidance: bool | |
| Whether to do classifier-free guidance. | |
| condition_repeat_times: int | |
| Support multi-condition. | |
| uncondition_repeat_times: int | |
| Support multi-uncondition. | |
| """ | |
| if infer_fn_kwargs_list is None: | |
| infer_fn_kwargs_list = [{} for _ in prompt_list] | |
| # [n_output, bsz] | |
| cond_results_list = None | |
| uncond_results_list = None | |
| output_type_list = [] | |
| for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)): | |
| if not isinstance(prompt, (list, tuple)): | |
| prompt = [prompt] | |
| cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {} | |
| results = infer_fn( | |
| *prompt, | |
| **infer_fn_kwargs, | |
| **cond_kwargs, | |
| ) | |
| output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1)) | |
| if isinstance(results, dict): | |
| raise ValueError("Make batch on dict is not supported. Please return list or tuple for infer_fn.") | |
| if not isinstance(results, (list, tuple)): | |
| results = (results,) | |
| if cond_results_list is None: | |
| cond_results_list = [[] for _ in results] | |
| uncond_results_list = [[] for _ in results] | |
| for i, result in enumerate(results): | |
| cond_results_list[i].append(result) | |
| if do_classifier_free_guidance: | |
| if negative_prompt_list is None: | |
| uncond_kwargs = {"uncond_p": 1.0} | |
| uncond_results = infer_fn( | |
| *prompt, | |
| **infer_fn_kwargs, | |
| **uncond_kwargs, | |
| ) | |
| else: | |
| negative_prompt = negative_prompt_list[prompt_idx] | |
| if not isinstance(negative_prompt, (list, tuple)): | |
| negative_prompt = [negative_prompt] | |
| uncond_results = infer_fn( | |
| *negative_prompt, | |
| **infer_fn_kwargs, | |
| ) | |
| if isinstance(uncond_results, TokenizerEncodeOutput): | |
| uncond_results_list.append(uncond_results) | |
| else: | |
| for i, result in enumerate(uncond_results): | |
| uncond_results_list[i].append(result) | |
| assert all(output_type_list[0] == n for n in output_type_list), \ | |
| f"Number of outputs should be equal for all samples, but got {output_type_list}." | |
| output_type, output_num = output_type_list[0] | |
| def make_batch(batch_cond_item, batch_uncond_item): | |
| # Process each output item to make batch | |
| first = batch_cond_item[0] # The first element in the batch | |
| if isinstance(first, torch.Tensor): | |
| stacked_item = torch.stack(self.pad( | |
| batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times, | |
| )) | |
| elif first is None: | |
| assert all(item is None for item in batch_cond_item + batch_uncond_item), \ | |
| (f"The first cond item is None, but some items are not None:\n\n" | |
| f"condition: {batch_cond_item}\n\n" | |
| f"uncondition: {batch_uncond_item}") | |
| stacked_item = None | |
| elif isinstance(first, (list, tuple)): | |
| # If the output item is a list or tuple, we treat it as a whole, and won't make nested batch any more. | |
| stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times | |
| elif isinstance(first, TokenizerEncodeOutput): | |
| stacked_item = {} | |
| # Traverse not-None attributes | |
| for key in list(first.keys()): | |
| merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \ | |
| [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times | |
| if isinstance(first[key], torch.Tensor): | |
| if 'mask' in key: | |
| pad_val = 0.0 | |
| elif key == 'tokens': | |
| pad_val = self.special_token_map["<pad>"] | |
| else: | |
| pad_val = False # Should not pad for other tensors | |
| stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0) | |
| elif isinstance(first[key], list): | |
| stacked_item[key] = merged_list | |
| elif first[key] is None: | |
| pass | |
| else: | |
| raise ValueError(f"Unsupported type of {key}: {type(first[key])}.") | |
| stacked_item = TokenizerEncodeOutput(stacked_item) | |
| else: | |
| raise TypeError(f"Making batch on type {type(first)} is not supported.") | |
| return stacked_item | |
| stacked_outputs = [] | |
| for cond_results, uncond_results in zip(cond_results_list, uncond_results_list): | |
| stacked_outputs.append(make_batch(cond_results, uncond_results)) | |
| if output_type == list: | |
| return stacked_outputs | |
| elif output_type == tuple: | |
| return tuple(stacked_outputs) | |
| elif output_num == 1: | |
| return stacked_outputs[0] | |
| else: | |
| raise ValueError(f"Unsupported output type: {output_type}.") | |
| def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None): | |
| if rng is None: | |
| rng = slice(None) | |
| image_slices = [ | |
| slice(start, end + 1) | |
| for start, end in zip(extra_token_pos[f'<{prefix}>_start'][rng], extra_token_pos[f'<{prefix}>_end'][rng]) | |
| ] if f'<{prefix}>_start' in extra_token_pos and f'<{prefix}>_end' in extra_token_pos else [] | |
| if image_slices: | |
| image_mask = torch.zeros_like(tokens, dtype=torch.bool) | |
| for image_slice in image_slices: | |
| image_mask[image_slice] = True | |
| else: | |
| image_mask = None | |
| return image_slices, image_mask | |
| def encode_general( | |
| self, | |
| sections: Optional[List[Dict[str, Any]]] = None, | |
| max_token_length: Optional[int] = None, | |
| add_eos='auto', | |
| use_text_mask=True, | |
| add_pad='auto', | |
| add_bos=True, | |
| drop_last='auto', | |
| ): | |
| """ | |
| General encode function to encode a sequence with multiple sections of text and images. | |
| Each section is a dict with a `type` key and other keys depending on the type. | |
| Supported section types: | |
| - text: dict with keys: | |
| - text: str or List[int], text to be encoded. Either `text` or `tokens` should be provided. | |
| - tokens: List[int], pre-encoded text tokens. Either `text` or `tokens` should be provided. | |
| - uncond_enabled: bool, whether to enable uncondition for this text section. | |
| - uncond_p: float, probability to drop the text section for uncondition. | |
| - max_length: int, maximum length of the text section. | |
| - ignore: bool, whether to ignore this text section in the text mask. | |
| - start_offset: int, start offset of the text mask. | |
| - end_offset: int, end offset of the text mask. | |
| - gen_image: dict with keys: | |
| - token_length: int, number of image tokens. | |
| - add_timestep_token: bool, whether to add timestep token before the image tokens. | |
| - add_guidance_token: bool, whether to add guidance token before the image tokens. | |
| - use_front_boi_token: bool, whether to put the <boi> token at the front of size, ratio and timestep tokens. | |
| - add_image_shape_token: bool, whether to add image shape token before the image tokens. | |
| - base_size: int, base size of the image. | |
| - ratio_idx: int, ratio index of the image. | |
| - joint_image: dict with keys: | |
| - token_length: List[int], number of image tokens for the two images. | |
| - add_timestep_token: bool, whether to add timestep token before the image tokens. | |
| - use_front_boi_token: bool, whether to put the <boi> token at the front of size, ratio and timestep tokens. | |
| - add_image_shape_token: bool, whether to add image shape token before the image tokens. | |
| - base_size: int, base size of the image. | |
| - ratio_idx: int, ratio index of the image. | |
| Parameters | |
| ---------- | |
| sections: List[Dict[str, Any]] | |
| List of sections to be encoded. | |
| max_token_length: int | |
| Maximum length of the encoded token sequence. | |
| add_eos: bool or 'auto' | |
| Whether to add eos token at the end of the sequence. If True, always add eos | |
| token. If 'auto', add eos token only when the total_length is not reached and the last token is not <eos>. | |
| use_text_mask: bool | |
| Whether to generate text mask. | |
| add_pad: bool or 'auto' | |
| Whether to add padding tokens to the sequence. If True and total_length is not reached, | |
| add padding tokens. | |
| add_bos: bool | |
| Whether to add bos token at the beginning of the sequence. | |
| drop_last: bool or 'auto' | |
| - If auto, drop last tokens exceeding the total_length if the total_length is provided. | |
| If cut point is in the middle of the image tokens, an error will raised. | |
| - If True, drop last tokens exceeding the total_length. If cut point is in the | |
| middle of the image tokens, all the successive image tokens will be dropped. | |
| - If False, keep the last tokens exceeding the total_length, even if the total_length | |
| is reached. | |
| Returns | |
| ------- | |
| TokenizerEncodeOutput | |
| Encoded token sequence and extra information. | |
| """ | |
| if sections is None: | |
| raise ValueError("sections must be provided.") | |
| template = '-'.join([section['type'] for section in sections]) | |
| sections = deepcopy(sections) | |
| token_source = defaultdict(list) | |
| text_mask_specs = [] | |
| for section in sections: | |
| if section['type'] == 'text': | |
| text = self.encode_text( | |
| section['text'] if 'text' in section else section['tokens'], | |
| uncond_enabled=section.get('uncond_enabled'), | |
| uncond_p=section.get('uncond_p'), | |
| max_length=section.get('max_length'), | |
| ) | |
| token_source['text'].append(text) | |
| text_mask_specs.append(dict( | |
| ignore=section.get('ignore', False), | |
| start_offset=section.get('start_offset', 0), | |
| end_offset=section.get('end_offset', 0), | |
| )) | |
| elif section['type'] == 'gen_image': | |
| token_source['gen_image'].append(dict( | |
| length=section['token_length'], | |
| timestep=section.get('add_timestep_token', False), | |
| guidance=section.get('add_guidance_token', False), | |
| front_boi=section.get('use_front_boi_token', False), | |
| image_shape=section.get('add_image_shape_token', False), | |
| base_size=section.get('base_size'), | |
| ratio_idx=section.get('ratio_idx'), | |
| )) | |
| elif section['type'] == 'joint_image': | |
| token_source['joint_image'].append(dict( | |
| length=section['token_length'], | |
| timestep=section.get('add_timestep_token', False), | |
| front_boi=section.get('use_front_boi_token', False), | |
| image_shape=section.get('add_image_shape_token', False), | |
| base_size=section.get('base_size'), | |
| ratio_idx=section.get('ratio_idx'), | |
| )) | |
| else: | |
| raise ValueError(f"Invalid section type: {section['type']}") | |
| # Combine text and image tokens | |
| full_token_seq, extra_token_pos = self.encode_sequence( | |
| template=template, | |
| token_source=dict(token_source), | |
| total_length=max_token_length, | |
| add_eos=add_eos, | |
| add_pad=add_pad, | |
| add_bos=add_bos, | |
| drop_last=drop_last, | |
| ) | |
| full_seq_token_tensor = torch.tensor(full_token_seq, dtype=torch.long) | |
| timestep_scatter_index = torch.tensor(extra_token_pos['timestep'], dtype=torch.long) \ | |
| if 'timestep' in extra_token_pos else None | |
| guidance_scatter_index = torch.tensor(extra_token_pos['guidance'], dtype=torch.long) \ | |
| if 'guidance' in extra_token_pos else None | |
| cond_timestep_scatter_index = torch.tensor(extra_token_pos['cond_timestep'], dtype=torch.long) \ | |
| if 'cond_timestep' in extra_token_pos else None | |
| gen_timestep_scatter_index = torch.tensor(extra_token_pos['gen_timestep'], dtype=torch.long) \ | |
| if 'gen_timestep' in extra_token_pos else None | |
| # Gen image mask | |
| gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', full_seq_token_tensor) | |
| # Joint image | |
| joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', full_seq_token_tensor) | |
| # Conditional vae image | |
| cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos( | |
| extra_token_pos, 'vae_img', full_seq_token_tensor) | |
| # Conditional vit image | |
| cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos( | |
| extra_token_pos, 'vit_img', full_seq_token_tensor) | |
| # All image slices (gen_image, joint_image) | |
| all_image_slices = [ | |
| slice(start, end + 1) | |
| for start, end in zip(extra_token_pos['<all_img>_start'], extra_token_pos['<all_img>_end']) | |
| ] if '<all_img>_start' in extra_token_pos and '<all_img>_end' in extra_token_pos else [] | |
| # Text mask | |
| text_slices = [ | |
| slice(start, end + 1) | |
| for start, end in zip(extra_token_pos['<text>_start'], extra_token_pos['<text>_end']) | |
| ] if '<text>_start' in extra_token_pos and '<text>_end' in extra_token_pos else [] | |
| assert len(text_slices) <= len(text_mask_specs), \ | |
| (f"Number of text slices ({len(text_slices)}) should be less than or equal to " | |
| f"number of text mask specs ({len(text_mask_specs)})") | |
| if use_text_mask: | |
| text_mask = torch.zeros_like(full_seq_token_tensor, dtype=torch.float32) | |
| for text_slice, mask_spec in zip(text_slices, text_mask_specs): | |
| if not mask_spec['ignore']: | |
| real_slice = slice( | |
| text_slice.start + mask_spec['start_offset'], | |
| text_slice.stop + mask_spec['end_offset'] | |
| ) | |
| text_mask[real_slice] = 1.0 | |
| else: | |
| text_mask = None | |
| # real_pos is the first position of the <pad> token | |
| real_pos = torch.tensor(extra_token_pos.get('first_pad', [full_seq_token_tensor.shape[0]]), dtype=torch.long) | |
| return TokenizerEncodeOutput( | |
| tokens=full_seq_token_tensor, | |
| timestep_scatter_index=timestep_scatter_index, | |
| guidance_scatter_index=guidance_scatter_index, | |
| text_slices=text_slices, | |
| gen_image_slices=gen_image_slices, | |
| joint_image_slices=joint_image_slices, | |
| cond_vae_image_slices=cond_vae_image_slices, | |
| cond_vit_image_slices=cond_vit_image_slices, | |
| text_mask=text_mask, | |
| gen_image_mask=gen_image_mask, | |
| cond_vae_image_mask=cond_vae_image_mask, | |
| cond_vit_image_mask=cond_vit_image_mask, | |
| real_pos=real_pos, | |
| all_image_slices=all_image_slices, | |
| cond_timestep_scatter_index=cond_timestep_scatter_index, | |
| gen_timestep_scatter_index=gen_timestep_scatter_index, | |
| ) | |
| def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False): | |
| if not cot_text: # None or empty | |
| return [] | |
| if '<think>' in cot_text and '</think>' in cot_text: | |
| before_think_sec = cot_text.split('<think>')[0] | |
| after_think_sec = cot_text.split('</think>')[1] | |
| think_sec = cot_text.split('<think>')[1].split('</think>')[0] | |
| return self.get_cot_sections(before_think_sec, uncond_kwargs, drop_think=drop_think) + \ | |
| ([ | |
| dict(type="text", text="<think>"), | |
| dict(type="text", text=think_sec, max_length=cot_max_length, **uncond_kwargs), | |
| dict(type="text", text="</think>") | |
| ] if not drop_think else []) + \ | |
| self.get_cot_sections(after_think_sec, uncond_kwargs, drop_think=drop_think) | |
| if '<recaption>' in cot_text and '</recaption>' in cot_text: | |
| before_recaption_sec = cot_text.split('<recaption>')[0] | |
| after_recaption_sec = cot_text.split('</recaption>')[1] | |
| recaption_sec = cot_text.split('<recaption>')[1].split('</recaption>')[0] | |
| return self.get_cot_sections(before_recaption_sec, uncond_kwargs, drop_think=drop_think) + \ | |
| [ | |
| dict(type="text", text="<recaption>"), | |
| dict(type="text", text=recaption_sec, max_length=cot_max_length, **uncond_kwargs), | |
| dict(type="text", text="</recaption>") | |
| ] + \ | |
| self.get_cot_sections(after_recaption_sec, uncond_kwargs, drop_think=drop_think) | |
| return [ | |
| dict(type="text", text=cot_text, **uncond_kwargs), | |
| ] | |
| def apply_general_template( | |
| self, | |
| message_list, | |
| max_length=None, | |
| add_assistant_prefix=False, | |
| answer="auto", | |
| bot_task="auto", | |
| sequence_template="instruct", | |
| uncond_p=0.0, | |
| cfg_factor=1, | |
| batchify=False, | |
| image_base_size=1024, | |
| drop_think=False, | |
| ): | |
| # If cfg_factor > 1, we need to repeat the unconditioned part | |
| if batchify: | |
| assert isinstance(message_list[0], list), \ | |
| f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]." | |
| return self.batch_gen_infer( | |
| infer_fn=self.apply_general_template, | |
| prompt_list=[[]], | |
| infer_fn_kwargs_list=[dict( | |
| message_list=message_list_i, | |
| max_length=max_length, | |
| add_assistant_prefix=add_assistant_prefix, | |
| answer=answer, | |
| bot_task=bot_task, | |
| sequence_template=sequence_template, | |
| image_base_size=image_base_size, | |
| drop_think=drop_think, | |
| ) for message_list_i in message_list], | |
| do_classifier_free_guidance=cfg_factor > 1, | |
| condition_repeat_times=1, | |
| uncondition_repeat_times=cfg_factor - 1, | |
| ) | |
| conv = Conversation() | |
| uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p) | |
| def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix, | |
| answer_prefix="", answer_suffix=""): | |
| _sub_sections = [] | |
| while _cur_message_idx < len(message_list) and _message_list[_cur_message_idx]['role'] == role: | |
| message = _message_list[_cur_message_idx] | |
| if message['type'] == 'text': | |
| text = message['content'] | |
| if role == "system": | |
| _sub_sections.append(dict(type="text", text=text)) | |
| elif role == "assistant": | |
| if ("<recaption>" in text and "</recaption>" in text) or ( | |
| "<think>" in text and "</think>" in text): | |
| _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think)) | |
| else: | |
| _sub_sections.append(dict(type="text", text=text, **uncond_kwargs)) | |
| else: | |
| _sub_sections.append(dict( | |
| type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs)) | |
| elif message['type'] == 'gen_image': | |
| info = message['content'] | |
| assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}" | |
| if role == "assistant": | |
| _sub_sections.append(dict(type="text", text=answer_prefix)) | |
| _sub_sections.append(dict(type=message['type'], **info.meta_info)) | |
| if role == "assistant": | |
| _sub_sections.append(dict(type="text", text=answer_suffix)) | |
| elif message['type'] == 'joint_image': | |
| info = message['content'] | |
| assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}" | |
| _sub_sections.append(dict(type=message['type'], **info.meta_info)) | |
| else: | |
| raise ValueError(f"Unknown message type: {message['type']}") | |
| _cur_message_idx += 1 | |
| if len(_sub_sections) > 0: | |
| # Add role prefix and suffix | |
| _sub_sections.insert(0, dict(type='text', text=prefix)) | |
| _sub_sections.append(dict(type='text', text=suffix)) | |
| return _sub_sections, _cur_message_idx | |
| # Define assistant prefix and suffix | |
| if (answer == "auto" and sequence_template == "instruct") or answer is True: | |
| answer_prefix, answer_suffix = "<answer>", "</answer>" | |
| else: | |
| answer_prefix, answer_suffix = "", "" | |
| if sequence_template == "pretrain": | |
| system_suffix = "" | |
| user_prefix = "" | |
| user_suffix = "" | |
| bot_prefix = "" | |
| bot_suffix = "" | |
| else: | |
| system_suffix = f"{conv.sep}" | |
| user_prefix = f"{conv.roles[0]}: " | |
| user_suffix = f"{conv.sep}" | |
| bot_prefix = f"{conv.roles[1]}: " | |
| bot_suffix = f"{conv.sep}" | |
| # Process successive user and assistant messages | |
| sections = [] | |
| cur_message_idx = 0 | |
| final_role = None | |
| while cur_message_idx < len(message_list): | |
| # Process successive system messages | |
| sub_sections, cur_message_idx = process_successive_message( | |
| message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix) | |
| # Add to the template and sections | |
| sections.extend(sub_sections) | |
| if len(sub_sections) > 0: | |
| final_role = "system" | |
| # Process successive user messages | |
| sub_sections, cur_message_idx = process_successive_message( | |
| message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix) | |
| # Add to the template and sections | |
| sections.extend(sub_sections) | |
| if len(sub_sections) > 0: | |
| final_role = "user" | |
| # Process successive assistant messages | |
| sub_sections, cur_message_idx = process_successive_message( | |
| message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix, | |
| answer_prefix=answer_prefix, answer_suffix=answer_suffix, | |
| ) | |
| # Add to the template and sections | |
| sections.extend(sub_sections) | |
| if len(sub_sections) > 0: | |
| final_role = "assistant" | |
| if add_assistant_prefix: | |
| if final_role == "assistant": | |
| # Avoid adding prefix twice | |
| _bot_prefix = "" | |
| # Remove the final bot_suffix | |
| if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix: | |
| sections = sections[:-1] | |
| else: | |
| _bot_prefix = bot_prefix | |
| # We can add special tokens for the bot lastest message according to different tasks | |
| bot_response_prefix = dict( | |
| auto=_bot_prefix, | |
| image="", | |
| think=f"{_bot_prefix}<think>", | |
| recaption=f"{_bot_prefix}<recaption>", | |
| img_ratio=f"{_bot_prefix}{answer_prefix}<boi><img_size_{image_base_size}>", | |
| )[bot_task] | |
| sections.append(dict(type='text', text=bot_response_prefix)) | |
| output = self.encode_general( | |
| sections=sections, | |
| use_text_mask=False, | |
| add_eos=False, | |
| add_pad=False, | |
| ) | |
| if max_length is not None: | |
| if output.tokens.shape[-1] > max_length: | |
| raise ValueError( | |
| f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n" | |
| f"Please set a larger max_length or check the input messages:\n{message_list}" | |
| ) | |
| return output, sections | |
| def apply_chat_template( | |
| self, | |
| batch_prompt: Optional[List[str]] = None, | |
| batch_message_list: Optional[List[List[Dict[str, Any]]]] = None, | |
| mode: str = "gen_text", | |
| batch_gen_image_info: Optional[List[ImageInfo]] = None, | |
| batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None, | |
| batch_system_prompt: Optional[List[str]] = None, | |
| batch_cot_text: Optional[List[str]] = None, | |
| max_length: Optional[int] = None, | |
| bot_task: str = "auto", # auto/image/think/recaption/img_ratio | |
| image_base_size: int = 1024, | |
| sequence_template: str = "pretrain", | |
| cfg_factor: int = 1, | |
| add_assistant_prefix: Optional[bool] = None, | |
| drop_think: bool = False, | |
| ) -> Dict[str, Any]: | |
| assert bot_task in ["image", "auto", "think", "recaption", "img_ratio"], \ | |
| f"bot_task should be one of ['image', 'auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}." | |
| if batch_message_list is None: | |
| # Simple text-to-image or text-cot-to-image task | |
| batch_size = len(batch_prompt) | |
| # Batchify inputs | |
| if not isinstance(batch_system_prompt, list): | |
| batch_system_prompt = [batch_system_prompt] * batch_size | |
| if not isinstance(batch_gen_image_info, list): | |
| batch_gen_image_info = [batch_gen_image_info] * batch_size | |
| if batch_cot_text is not None: | |
| assert len(batch_cot_text) == batch_size, \ | |
| (f"batch_cot_text should have the same length as batch_size ({batch_size}), " | |
| f"but got {len(batch_cot_text)}.") | |
| else: | |
| batch_cot_text = [None] * batch_size | |
| if batch_cond_image_info is not None: | |
| assert len(batch_cond_image_info) == batch_size, \ | |
| (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), " | |
| f"but got {len(batch_cond_image_info)}.") | |
| batch_cond_image_info = [ | |
| cond_image_info if isinstance(cond_image_info, list) else [cond_image_info] | |
| for cond_image_info in batch_cond_image_info | |
| ] | |
| else: | |
| batch_cond_image_info = [[] for _ in range(batch_size)] | |
| # Convert single round materials into standard message list | |
| batch_message_list = [] | |
| for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip( | |
| batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info, | |
| batch_cond_image_info, | |
| ): | |
| message_list = [] | |
| # 1. system prompt section | |
| if system_prompt: | |
| message_list.append(dict( | |
| role="system", type="text", content=system_prompt, context_type="str")) | |
| # 2. user inputs sections | |
| # 2.1 image inputs | |
| if len(cond_image_info_list) > 0: | |
| message_list.extend([ | |
| dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info") | |
| for cond_image_info in cond_image_info_list | |
| ]) | |
| # 2.2 text inputs | |
| message_list.append(dict( | |
| role="user", type="text", content=prompt, context_type="str")) | |
| # 3. assistant answer sections | |
| if cot_text is not None: | |
| message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str")) | |
| if mode == "gen_image": | |
| message_list.append(dict( | |
| role="assistant", type="gen_image", content=gen_image_info, context_type="image_info")) | |
| # --- | |
| batch_message_list.append(message_list) | |
| output, sections = self.apply_general_template( | |
| message_list=batch_message_list, | |
| max_length=max_length, | |
| add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"), | |
| bot_task=bot_task, | |
| sequence_template=sequence_template, | |
| cfg_factor=cfg_factor, | |
| batchify=True, | |
| image_base_size=image_base_size, | |
| drop_think=drop_think, | |
| ) | |
| return dict(output=output, sections=sections) | |