This is the basic LLaVA version of the original Mol-Instruct model: untuned, with only a CLIP vision encoder added on.
aea55e2 verified | from transformers import CLIPImageProcessor, AutoTokenizer | |
| from transformers.processing_utils import ProcessorMixin | |
| from transformers.image_utils import ImageInput | |
| from typing import List, Optional, Union | |
| import torch | |
| from PIL import Image | |
class LlavaProcessor(ProcessorMixin):
    """Pair an image processor with a tokenizer behind a single callable.

    The instance also carries ``patch_size``,
    ``vision_feature_select_strategy`` and ``num_additional_image_tokens``
    attributes, which are set for LlamaFactory compatibility.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Store both components and set the LlamaFactory helper attributes.

        Args:
            image_processor: Image processor forwarded to ``ProcessorMixin``.
            tokenizer: Tokenizer forwarded to ``ProcessorMixin``.
            **kwargs: Accepted for call compatibility; not used here.
        """
        super().__init__(image_processor, tokenizer)
        # Add all required attributes for LlamaFactory compatibility.
        # 14 is the fallback when the image processor exposes no patch_size.
        self.patch_size = getattr(image_processor, "patch_size", 14)
        self.vision_feature_select_strategy = "default"
        # No additional tokens beyond the image patches.
        self.num_additional_image_tokens = 0

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        images: Optional[ImageInput] = None,
        **kwargs
    ):
        """Process images and/or text into model inputs.

        Args:
            text: A string or list of strings passed to the tokenizer.
            images: A single image (path string or PIL image) or a sequence
                of images passed to the image processor.
            **kwargs: Forwarded to both the image processor and the
                tokenizer. ``return_tensors`` is forwarded to the tokenizer
                only, because images are always returned as PyTorch tensors.

        Returns:
            When ``text`` is given, a plain ``dict`` merging the tokenizer
            output with the image features. Otherwise the image processor
            output (or an empty ``dict`` when ``images`` is also ``None``).
            Whenever ``pixel_values`` is produced, an ``image_sizes`` tensor
            holding one ``[height, width]`` row per image is added.
        """
        if images is not None:
            if isinstance(images, (str, Image.Image)):
                images = [images]

            # The image branch pins return_tensors="pt"; strip any
            # caller-supplied value so the call does not fail with
            # "got multiple values for keyword argument 'return_tensors'".
            image_kwargs = {
                key: value
                for key, value in kwargs.items()
                if key != "return_tensors"
            }
            image_inputs = self.image_processor(
                images, return_tensors="pt", **image_kwargs
            )

            # Add image_sizes for LlamaFactory compatibility.
            if "pixel_values" in image_inputs:
                pixel_values = image_inputs["pixel_values"]
                batch_size = pixel_values.shape[0]
                # Read the processed size from the tensor itself rather than
                # from the processor config, whose ``size`` may only carry a
                # "shortest_edge" key.
                # NOTE(review): assumes a channels-first layout
                # (..., height, width) - confirm if ``data_format`` is ever
                # overridden through kwargs.
                height, width = (int(dim) for dim in pixel_values.shape[-2:])
                image_inputs["image_sizes"] = torch.tensor(
                    [[height, width]] * batch_size
                )
        else:
            image_inputs = {}

        if text is not None:
            text_inputs = self.tokenizer(text, **kwargs)
            return {**text_inputs, **image_inputs}
        return image_inputs

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)