| from PIL import ImageOps |
| from PIL.Image import Image |
|
|
| import torch |
|
|
| from typing import Union, List |
| from tqdm import tqdm |
|
|
| from transformers.image_utils import ImageInput |
| from transformers.tokenization_utils_base import TextInput |
| from transformers import CLIPImageProcessor |
| from transformers.processing_utils import ( |
| ProcessorMixin, |
| ) |
| from transformers import AutoTokenizer, PreTrainedTokenizer |
|
|
| from .image_processing_instellavl import InstellaVLImageProcessor |
| from .mm_utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, KeywordsStoppingCriteria |
| from .conversation import conv_templates |
|
|
def tokenizer_image_token(prompt: str, tokenizer: PreTrainedTokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None) -> Union[torch.Tensor, List[int]]:
    r"""
    Tokenize a prompt containing image placeholders, inserting ``image_token_index`` for each one.

    The prompt is split on ``DEFAULT_IMAGE_TOKEN``. Each text chunk is tokenized independently,
    and the chunk token ids are concatenated with one ``image_token_index`` between consecutive chunks.

    If the first chunk's first token equals ``tokenizer.bos_token_id``, that token is emitted once
    at the start, and the first token of *every* chunk (including the first) is skipped.

    Args:
        prompt (str): Input text containing zero or more ``DEFAULT_IMAGE_TOKEN`` placeholders.
        tokenizer (PreTrainedTokenizer): Tokenizer applied to each text chunk (``tokenizer(chunk).input_ids``).
        image_token_index (int): Id inserted in place of each placeholder. Default is IMAGE_TOKEN_INDEX.
        return_tensors (str, optional): ``None`` returns a plain list of ids; ``"pt"`` returns a 1-D
            ``torch.long`` tensor.

    Returns:
        list[int] or torch.Tensor: The assembled input ids.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``"pt"``.
    """
    # Validate up front so an unsupported option fails before any tokenization work.
    if return_tensors is not None and return_tensors != "pt":
        raise ValueError(f"Unsupported tensor type: {return_tensors}")

    # str.split always yields at least one element, so prompt_chunks[0] is safe below.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    input_ids: List[int] = []
    offset = 0
    first_chunk = prompt_chunks[0]
    if len(first_chunk) > 0 and first_chunk[0] == tokenizer.bos_token_id:
        # Keep a single leading BOS and skip the first token of each chunk below.
        # NOTE(review): the skip is unconditional, which assumes the tokenizer prepends BOS to
        # every chunk. If BOS only appears literally in the first chunk's text, later chunks lose
        # a real token. Confirm against training-time tokenization before changing.
        offset = 1
        input_ids.append(first_chunk[0])

    for position, chunk_ids in enumerate(prompt_chunks):
        if position > 0:
            # Exactly one image token stands in for each placeholder between chunks.
            input_ids.append(image_token_index)
        input_ids.extend(chunk_ids[offset:])

    if return_tensors == "pt":
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids
|
|
|
|
class InstellaVLProcessor(ProcessorMixin):
    r"""
    Processor that pairs an image processor with a tokenizer for InstellaVL.

    ``encode`` builds one chat-formatted prompt, with image tensors when images are given.
    ``batch_encode`` applies ``encode`` to several samples.
    ``decode`` turns generated ids back into text with ``self.tokenizer``, which ``encode``
    overwrites with the tokenizer it used.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "GPTNeoXTokenizerFast"

    def __init__(self, image_processor: InstellaVLImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
        super().__init__(image_processor, tokenizer, **kwargs)

    def pad_sequence(self, input_ids: Union[List[torch.Tensor], List[List[torch.Tensor]]], batch_first: bool, padding_value: int, tokenizer: AutoTokenizer):
        """Pad variable-length id tensors into one tensor, honoring ``tokenizer.padding_side``.

        For left padding, each sequence is reversed along dim 0 and right-padded.
        The padded result is then reversed back along its time axis.

        Args:
            input_ids: Sequences to pad; each is flipped along dim 0 for left padding.
            batch_first: Forwarded to ``torch.nn.utils.rnn.pad_sequence``; selects a (B, T, ...)
                output if True, else (T, B, ...).
            padding_value: Value used for padded positions.
            tokenizer: Only ``tokenizer.padding_side`` is read.

        Returns:
            torch.Tensor: The padded batch.
        """
        pad_left = tokenizer.padding_side == "left"
        if pad_left:
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if pad_left:
            # The time axis is dim 1 for batch-first output and dim 0 otherwise; flipping the
            # wrong one would reorder the batch instead of moving padding to the left.
            time_dim = 1 if batch_first else 0
            padded = torch.flip(padded, [time_dim])
        return padded

    def encode(self,
               text: TextInput = None,
               images: ImageInput = None,
               image_processor: CLIPImageProcessor = None,
               tokenizer: AutoTokenizer = None,
               model_cfg: dict = None,
               ) -> dict:
        """Build model inputs for a single sample using the ``"instella"`` conversation template.

        Args:
            text: User prompt. Any ``DEFAULT_IMAGE_TOKEN`` occurrences are removed, and one
                placeholder is prepended when image tensors were produced.
            images: A PIL ``Image``, or a list/tuple of them. Each image is EXIF-transposed
                *in place*, so the caller's objects are modified.
            image_processor: Passed through to ``self.image_processor.process``.
            tokenizer: Tokenizer used for the prompt. Falls back to ``self.tokenizer`` when None.
                It is stored on ``self.tokenizer`` afterwards.
            model_cfg: Passed through to ``self.image_processor.process``.

        Returns:
            dict: Contains the following keys.
                - ``"input_ids"``: tensor of shape (1, L).
                - ``"stopping_criteria"``: list holding one ``KeywordsStoppingCriteria``.
                - ``"eos_token_id"``: list of terminator ids.
                - When images are given, also ``"image_tensor"`` and ``"image_sizes"``
                  (PIL ``size`` tuples, i.e. (width, height)).

        Raises:
            ValueError: If ``text`` is None.
            TypeError: If ``images`` is not a PIL Image, list, or tuple.
        """
        if text is None:
            raise ValueError("Text must be provided for encoding.")
        if tokenizer is None:
            # Previously this path crashed when calling None; reuse the processor's tokenizer instead.
            tokenizer = self.tokenizer

        image_tensor = None
        image_sizes = None
        if images is not None:
            if isinstance(images, Image):
                images = [images]
            elif isinstance(images, tuple):
                images = list(images)
            elif not isinstance(images, list):
                # Without this check image_sizes would stay unbound and fail much later.
                raise TypeError(f"Unsupported images type: {type(images).__name__}; expected PIL Image or list of them.")
            image_sizes = []
            for img in images:
                ImageOps.exif_transpose(img, in_place=True)
                image_sizes.append(img.size)
            image_tensor = self.image_processor.process(images, image_processor, model_cfg)['pixel_values']

        # Drop user-supplied placeholders; exactly one is re-added below when image tensors exist.
        text = text.replace(DEFAULT_IMAGE_TOKEN, "").strip()
        # The membership test guards against replace() re-forming a placeholder (e.g. "<im<image>age>").
        if images is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in text:
            question = DEFAULT_IMAGE_TOKEN + "\n" + text
        else:
            question = text

        conv = conv_templates["instella"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)  # empty assistant turn to be generated
        prompt_question = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
        # NOTE(review): the second terminator is looked up by token string; presumably it is a
        # special token in this tokenizer's vocab. Confirm it does not resolve to the unk id.
        terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("|||IP_ADDRESS|||")]

        out = {
            "input_ids": input_ids,
            "stopping_criteria": [stopping_criteria],
            "eos_token_id": terminators,
        }
        if images is not None:
            out = {
                "image_tensor": image_tensor,
                "image_sizes": image_sizes,
                **out,
            }
        # Remember the tokenizer so decode() uses the same vocabulary.
        self.tokenizer = tokenizer
        return out

    def batch_encode(self,
                     texts: List[TextInput] = None,
                     images: List[ImageInput] = None,
                     image_processor: CLIPImageProcessor = None,
                     tokenizer: AutoTokenizer = None,
                     model_cfg: dict = None,
                     ):
        """Encode several samples by calling ``encode`` on each (text, images) pair.

        Args:
            texts: List of prompts.
            images: Per-sample image inputs (same length as ``texts``), or None for text-only.
            image_processor, tokenizer, model_cfg: Forwarded unchanged to ``encode``.

        Returns:
            list[dict]: One ``encode`` result per sample, in input order.

        Raises:
            ValueError: If ``texts`` is None or the lengths of ``texts`` and ``images`` differ.
            TypeError: If ``texts`` is not a list.
        """
        if texts is None:
            raise ValueError("Text must be provided for batch encoding.")
        if not isinstance(texts, list):
            raise TypeError("Since batch encoding happening, provide batch of texts in a list.")
        if images is None:
            # Fixed: this previously referenced an undefined name `text`.
            images = [None] * len(texts)
        if len(texts) != len(images):
            raise ValueError("The number of texts and images must be equal.")

        return [
            self.encode(txt, img, image_processor, tokenizer, model_cfg)
            for txt, img in tqdm(zip(texts, images), total=len(texts), desc="Total Samples to encode")
        ]

    def decode(self, output_ids: torch.Tensor) -> str:
        """Decode row 0 of a 2-D id tensor with ``self.tokenizer``, skipping special tokens and stripping whitespace."""
        return self.tokenizer.decode(output_ids[0, :], skip_special_tokens=True).strip()

    def batch_decode(self, output_ids_lst: List[torch.Tensor]) -> List[str]:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("Batch decode is not implemented for InstellaVLProcessor")
# Register this custom processor with transformers' AutoProcessor auto class (the method's default target).
InstellaVLProcessor.register_for_auto_class(auto_class="AutoProcessor")
|
|