| from transformers import PreTrainedModel, ViTMAEModel |
| from .configuration_magiv2 import Magiv2Config |
| import torch |
| import numpy as np |
| from transformers import ViTImageProcessor |
| import PIL |
|
|
def move_to_device(inputs, device):
    """Recursively move the tensor-like leaves of ``inputs`` onto ``device``.

    Args:
        inputs: A (possibly nested) structure. Anything exposing ``keys`` is
            treated as a mapping, lists and tuples are walked element-wise,
            NumPy arrays are converted with ``torch.from_numpy``, and any
            other object exposing a ``to`` method has ``to(device)`` called
            on it.
        device: Target device, forwarded unchanged to every ``.to(...)`` call.

    Returns:
        A new structure mirroring ``inputs``: mappings come back as plain
        ``dict``, lists as ``list``, tuples as plain ``tuple``. Leaves that
        have no ``to`` method (e.g. ``None``, strings, numbers) are returned
        unchanged instead of raising ``AttributeError``.
    """
    if hasattr(inputs, "keys"):
        return {k: move_to_device(v, device) for k, v in inputs.items()}
    if isinstance(inputs, list):
        return [move_to_device(v, device) for v in inputs]
    if isinstance(inputs, tuple):
        return tuple(move_to_device(v, device) for v in inputs)
    if isinstance(inputs, np.ndarray):
        return torch.from_numpy(inputs).to(device)
    if hasattr(inputs, "to"):
        return inputs.to(device)
    # Non-tensor metadata nested in a batch has no notion of a device, so it
    # is passed through untouched rather than crashing the whole traversal.
    return inputs
|
|
class Magiv2Model(PreTrainedModel):
    """Model that turns PIL image crops into one embedding vector per crop.

    It owns a ``ViTImageProcessor`` built from
    ``config.crop_embedding_image_preprocessing_config`` and a ``ViTMAEModel``
    built from ``config.crop_embedding_model_config``.
    """

    config_class = Magiv2Config

    def __init__(self, config):
        """Build the image processor and the crop-embedding backbone.

        Args:
            config: A configuration object providing
                ``crop_embedding_image_preprocessing_config`` (passed to
                ``ViTImageProcessor.from_dict``) and
                ``crop_embedding_model_config`` (passed to ``ViTMAEModel``).
        """
        super().__init__(config)
        self.config = config
        self.processor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
        self.crop_embedding_model = ViTMAEModel(config.crop_embedding_model_config)

    def move_to_device(self, input):
        """Move ``input`` (tensor or nested structure) onto ``self.device``."""
        return move_to_device(input, self.device)

    def forward(self, images, move_to_device_fn=None, mask_ratio=0.0, batch_size=256, convert_to_grayscale=True):
        """Embed a list of PIL images with the crop-embedding model.

        Args:
            images: List of ``PIL.Image.Image`` objects. May be empty.
            move_to_device_fn: Optional callable used to place tensors on the
                target device; defaults to ``self.move_to_device``.
            mask_ratio: Value temporarily written to the backbone's
                ``embeddings.config.mask_ratio`` for the duration of the call.
            batch_size: Number of images fed to the backbone per step.
            convert_to_grayscale: When true, each image is converted to mode
                ``"L"`` before being converted to ``"RGB"``.

        Returns:
            A tensor holding index 0 along the second axis of the backbone's
            ``last_hidden_state`` for every image (presumably the CLS token
            -- confirm against the ViTMAE documentation), concatenated along
            dim 0. For empty input, a ``(0, hidden_size)`` tensor of dtype
            ``self.dtype``.

        Raises:
            AssertionError: If any element of ``images`` is not a PIL image.
        """
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
        if len(images) == 0:
            # Use the model dtype so the empty result matches what the
            # non-empty path produces (its inputs are cast to ``self.dtype``).
            empty = torch.zeros(
                len(images),
                self.config.crop_embedding_model_config.hidden_size,
                dtype=self.dtype,
            )
            return move_to_device_fn(empty)

        assert all(isinstance(image, PIL.Image.Image) for image in images), "please provide a list of PIL images"
        if convert_to_grayscale:
            images = [x.convert("L") for x in images]
        # Always hand RGB arrays to the processor, whatever the mode was.
        images = [np.array(image.convert("RGB")) for image in images]
        images = self.processor(images, return_tensors="pt").pixel_values
        images = move_to_device_fn(images)
        images = images.to(self.dtype)

        # The backbone's mask ratio is overridden only for this call. The
        # restore sits in ``finally`` so an exception while embedding cannot
        # leave the model with the temporary value.
        backbone_config = self.crop_embedding_model.embeddings.config
        old_mask_ratio = backbone_config.mask_ratio
        backbone_config.mask_ratio = mask_ratio
        try:
            embeddings = []
            for i in range(0, len(images), batch_size):
                crops = images[i:i + batch_size]
                embeddings_per_batch = self.crop_embedding_model(crops).last_hidden_state[:, 0]
                embeddings.append(embeddings_per_batch)
            embeddings = torch.cat(embeddings, dim=0)
        finally:
            backbone_config.mask_ratio = old_mask_ratio

        return embeddings