Instructions to use ibm-granite/granite-vision-3.3-2b-embedding with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ibm-granite/granite-vision-3.3-2b-embedding with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="ibm-granite/granite-vision-3.3-2b-embedding", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("ibm-granite/granite-vision-3.3-2b-embedding", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import ClassVar, Optional | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from transformers import LlavaNextPreTrainedModel | |
| from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration | |
| from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape | |
| from .colgranitevision_config import ColGraniteVisionConfig | |
class LlavaNextWithCustomPacking(LlavaNextForConditionalGeneration):
    """LLaVA-NeXT model whose image packing can put the base-image features last.

    Overrides ``pack_image_features`` with an extra ``base_image_feature_location``
    argument. It controls whether the features of patch 0 (the base image) are
    concatenated after (``"last"``, the default here) or before the tile features.
    """

    def pack_image_features(
        self,
        image_features,
        image_sizes,
        vision_feature_select_strategy,
        image_newline=None,
        base_image_feature_location="last",
    ):
        """
        Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

        Args:
            image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
                List of image feature tensors, one per image. For an image with more than one patch,
                patch 0 is treated as the base image and the remaining patches as the tile grid.
            image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
                Actual size of each image (H, W).
            vision_feature_select_strategy (`str`)
                The feature selection strategy used to select the vision feature from the vision backbone.
                Here it only decides whether the shape-mismatch warning is emitted.
            image_newline (`torch.Tensor` of shape `(embed_dim)`, *optional*)
                New line embedding vector. For multi-patch images it is appended at the end of every
                row of the unpadded tile grid; for single-patch images it is appended once after the features.
            base_image_feature_location (`str`, *optional*, defaults to `"last"`)
                `"last"` places the base-image features after the tile features; any other value
                places them before.
        Returns:
            image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
                Packed visual vectors of all images, concatenated in input order.
            feature_lens (`torch.Tensor` of dtype `torch.long` and shape `(num_images,)`)
                Token length of each image in image_features.
        """
        import warnings

        # Side length (in vision patches) of one tile; the same for every image, so computed once.
        height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                if (
                    image_feature.numel() % (num_patch_height * num_patch_width * height * width) != 0
                    and vision_feature_select_strategy == "default"
                ):
                    # Library code should not print to stdout; emit a proper (filterable) warning.
                    warnings.warn(
                        "Image feature shape does not line up with the provided patch size. "
                        "You may be using the `default` vision_feature_select_strategy with a"
                        " visual encoder that does not have CLS.",
                        stacklevel=2,
                    )
                # (rows, cols, h, w, dim) tile grid -> (dim, rows * h, cols * w) feature map.
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    # Add the newline embedding as one extra column, i.e. at the end of every row.
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None]
                            .expand(*image_feature.shape[:-1], 1)
                            .to(image_feature.device, image_feature.dtype),
                        ),
                        dim=-1,
                    )
                # (dim, H, W) -> (H * W, dim) token sequence, row-major.
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                if base_image_feature_location == "last":
                    image_feature = torch.cat((image_feature, base_image_feature), dim=0)
                else:
                    image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens
class ColGraniteVision(LlavaNextPreTrainedModel):
    """
    ColGraniteVision model implementation.

    Wraps a ``LlavaNextWithCustomPacking`` backbone. The last hidden states pass through a
    linear projection, giving ``self.dim``-dimensional, L2-normalised multi-vector embeddings.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related
    config_class = ColGraniteVisionConfig

    def __init__(self, config: ColGraniteVisionConfig):
        super().__init__(config=config)
        model = LlavaNextWithCustomPacking(config=config)
        # Expose the language model's tied-weight keys under their path inside this wrapper.
        if model.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"model.language_model.{k}" for k in model.language_model._tied_weights_keys]
        self.model = model

        # TODO: Wait for ColPali2 to create a ColPaliConfig to allow specifying the embedding dimension.
        # We could do it now but it would break all the models trying to load the model from the checkpoint.
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)

        self.post_init()

    def _num_image_tokens_to_keep(self) -> int:
        """Return how many image-token positions ``forward`` keeps per sequence.

        The value is ``(image_size // patch_size) ** 2`` from the vision config. That is the
        token count of one ``height x width`` patch grid as used by ``pack_image_features``.
        For image_size 384 and patch_size 14 it is 27 * 27 = 729.
        """
        vision_config = self.model.config.vision_config
        side = vision_config.image_size // vision_config.patch_size
        return side * side

    def _select_last_image_tokens(self, hidden_states, attention_mask, input_ids):
        """Keep, for each row, only the last ``k`` positions that hold the image token.

        ``k`` is ``self._num_image_tokens_to_keep()``. The kept positions stay in their
        original order.

        Args:
            hidden_states: tensor indexed as (batch, seq_len, hidden).
            attention_mask: tensor indexed as (batch, seq_len).
            input_ids: token ids indexed as (batch, seq_len).
        Returns:
            Tuple ``(hidden_states, attention_mask)`` gathered to (batch, k, hidden) and (batch, k).
        Raises:
            ValueError: if some row contains fewer than ``k`` image tokens. Without this check,
                topk would pick the ``-1`` filler index and ``torch.gather`` would fail (on CUDA
                with a device-side assert).
        """
        k = self._num_image_tokens_to_keep()
        image_mask = input_ids == self.config.image_token_index
        image_token_counts = image_mask.sum(dim=1)
        if bool((image_token_counts < k).any()):
            raise ValueError(
                f"Expected at least {k} image tokens per sequence, got {image_token_counts.tolist()}."
            )
        positions = torch.arange(image_mask.shape[1], device=image_mask.device).expand_as(image_mask)
        # Non-image positions become -1, so topk (largest first) only picks image-token positions.
        positions = positions.masked_fill(~image_mask, -1)
        last_positions, _ = torch.topk(positions, k=k, dim=1)
        last_positions, _ = torch.sort(last_positions, dim=1)
        gather_index = last_positions.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        return (
            torch.gather(hidden_states, 1, gather_index),
            torch.gather(attention_mask, 1, last_positions),
        )

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Embed a batch as L2-normalised multi-vector embeddings.

        All arguments are passed through to the wrapped model, except ``output_hidden_states``,
        which is always forced to ``True``. ``attention_mask`` must be given as a keyword argument.

        When ``pixel_values`` is given as a non-``None`` keyword argument, it is cast to
        ``self.dtype``. Only the last ``_num_image_tokens_to_keep()`` image-token positions of each
        row are kept; these are found via ``kwargs["input_ids"] == config.image_token_index``.
        NOTE(review): with the default ``base_image_feature_location="last"`` packing these are
        presumably the base-image features — confirm against the processor. Otherwise every
        sequence position is kept.

        Returns:
            Tensor of shape (batch_size, n_tokens, self.dim). Vectors at positions where the
            attention mask is 0 are zeroed.
        Raises:
            KeyError: if ``attention_mask`` (or, with images, ``input_ids``) is not a keyword argument.
            ValueError: if a row has fewer image tokens than need to be kept.
        """
        # Hidden states are always needed; drop any caller value so it cannot clash with the forced one.
        kwargs.pop("output_hidden_states", None)
        pixel_values = kwargs.get("pixel_values")
        has_images = pixel_values is not None
        if has_images:
            kwargs["pixel_values"] = pixel_values.to(dtype=self.dtype)

        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        attention_mask = kwargs["attention_mask"]

        if has_images:
            last_hidden_states, attention_mask = self._select_last_image_tokens(
                last_hidden_states, attention_mask, kwargs["input_ids"]
            )

        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, n_tokens, dim)
        # L2 normalization; the epsilon avoids division by zero for all-zero vectors.
        proj = proj / (proj.norm(dim=-1, keepdim=True) + 1e-8)
        # Zero out embeddings at masked (padding) positions.
        return proj * attention_mask.unsqueeze(-1)

    def get_input_embeddings(self):
        """Return the wrapped language model's input embeddings."""
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the wrapped language model's input embeddings."""
        self.model.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the wrapped language model's output embeddings."""
        return self.model.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the wrapped language model's output embeddings."""
        self.model.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Set the decoder on the wrapped language model."""
        self.model.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the wrapped language model's decoder."""
        return self.model.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the wrapped language model."""
        return self.model.language_model.tie_weights()

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of=None,
    ) -> nn.Embedding:
        """Resize the language model's token embeddings and keep recorded vocab sizes in sync.

        Args:
            new_num_tokens: forwarded to the language model's ``resize_token_embeddings``.
            pad_to_multiple_of: forwarded to the language model's ``resize_token_embeddings``.
        Returns:
            The embedding module returned by the language model.
        """
        model_embeds = self.model.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # Propagate the new size to every place that records the vocabulary size.
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.model.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def patch_size(self) -> int:
        """Return the vision tower's patch size."""
        return self.model.vision_tower.config.patch_size