# mol-instruct-base-llava-untuned / processing_llava.py
# Commit by Neroism8422:
# Basic LLaVA version of the original Mol-Instruct model: untuned, with only a CLIP vision encoder added on.
# Revision aea55e2 (verified)
from transformers import CLIPImageProcessor, AutoTokenizer
from transformers.processing_utils import ProcessorMixin
from transformers.image_utils import ImageInput
from typing import List, Optional, Union
import torch
from PIL import Image
class LlavaProcessor(ProcessorMixin):
    """Pair a CLIP image processor with a text tokenizer for a LLaVA-style model.

    In addition to what ``ProcessorMixin`` provides, this class exposes the
    attributes ``patch_size``, ``vision_feature_select_strategy`` and
    ``num_additional_image_tokens`` and adds an ``image_sizes`` entry to the
    image outputs, for LlamaFactory compatibility.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        """Build the processor.

        Args:
            image_processor: Processor applied to ``images``.
            tokenizer: Tokenizer applied to ``text``.
            **kwargs: Optional overrides ``patch_size``,
                ``vision_feature_select_strategy`` and
                ``num_additional_image_tokens``. A ``chat_template`` entry is
                forwarded to ``ProcessorMixin``. Any other keys are ignored.
        """
        # Forward chat_template only when it was supplied, so versions of
        # ProcessorMixin that do not know the keyword are not handed it.
        mixin_kwargs = {}
        if "chat_template" in kwargs:
            mixin_kwargs["chat_template"] = kwargs.pop("chat_template")
        super().__init__(image_processor, tokenizer, **mixin_kwargs)
        # Attributes required for LlamaFactory compatibility. Explicit values
        # (e.g. from a saved processor config) now take precedence over the
        # defaults instead of being silently dropped.
        self.patch_size = kwargs.pop("patch_size", getattr(image_processor, "patch_size", 14))
        self.vision_feature_select_strategy = kwargs.pop(
            "vision_feature_select_strategy", "default"
        )
        # Default 0: no additional tokens beyond the image patches.
        self.num_additional_image_tokens = kwargs.pop("num_additional_image_tokens", 0)

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        images: ImageInput = None,
        **kwargs,
    ):
        """Tokenize ``text`` and/or preprocess ``images``.

        Args:
            text: A string or list of strings to tokenize.
            images: Image(s) for the image processor. A single ``str`` or
                ``PIL.Image.Image`` is wrapped in a list first.
            **kwargs: Forwarded to both the image processor and the tokenizer,
                except ``return_tensors``, which is forwarded to the tokenizer
                only. Image outputs are always PyTorch tensors.

        Returns:
            If only ``images`` is given, the image processor output with an
            added ``image_sizes`` tensor. If ``text`` is given, a dict merging
            the tokenizer output with that image output (image keys win on a
            name collision). If neither is given, an empty dict.
        """
        # Image outputs are always "pt". Pop the caller's value so it is not
        # passed to the image processor a second time (duplicate keyword ->
        # TypeError); it still reaches the tokenizer below.
        return_tensors = kwargs.pop("return_tensors", None)

        image_inputs = {}
        if images is not None:
            if isinstance(images, (str, Image.Image)):
                images = [images]
            image_inputs = self.image_processor(images, return_tensors="pt", **kwargs)
            # Add image_sizes for LlamaFactory compatibility.
            if "pixel_values" in image_inputs:
                pixel_values = image_inputs["pixel_values"]
                # Use the size the model actually receives instead of guessing
                # it from the processor config. Assumes channels-first
                # pixel_values, i.e. (batch, channels, height, width).
                height, width = pixel_values.shape[-2:]
                batch_size = pixel_values.shape[0]
                image_inputs["image_sizes"] = torch.tensor(
                    [[int(height), int(width)]] * batch_size
                )

        if text is None:
            return image_inputs
        text_inputs = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
        return {**text_inputs, **image_inputs}

    def batch_decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)