| |
|
|
| import json |
| import os |
| from pathlib import Path |
| from typing import Union, List |
| import urllib |
|
|
| import torch |
| from torchvision.transforms import Compose, ToTensor, Normalize, Resize, InterpolationMode |
| from tqdm import tqdm |
|
|
| from clip import _tokenizer |
| from clip.model import convert_weights, CLIP, restore_model |
|
|
| __all__ = ["load", "tokenize", "available_models", "image_transform", "load_from_name"] |
|
|
| _MODELS = { |
| "ViT-B-16": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-base.pt", |
| "ViT-L-14": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-large.pt", |
| "RN50": "https://huggingface.co/TencentARC/QA-CLIP/resolve/main/QA-CLIP-RN50.pt", |
| } |
| _MODEL_INFO = { |
| "ViT-B-16": { |
| "struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese", |
| "input_resolution": 224 |
| }, |
| "ViT-L-14": { |
| "struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese", |
| "input_resolution": 224 |
| }, |
| "RN50": { |
| "struct": "RN50@RBT3-chinese", |
| "input_resolution": 224 |
| }, |
| } |
|
|
|
|
| def _download(url: str, root: str): |
| os.makedirs(root, exist_ok=True) |
| filename = os.path.basename(url) |
|
|
| download_target = os.path.join(root, filename) |
|
|
| if os.path.exists(download_target) and not os.path.isfile(download_target): |
| raise RuntimeError(f"{download_target} exists and is not a regular file") |
|
|
| if os.path.isfile(download_target): |
| return download_target |
|
|
| with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: |
| with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, |
| unit_divisor=1024) as loop: |
| while True: |
| buffer = source.read(8192) |
| if not buffer: |
| break |
|
|
| output.write(buffer) |
| loop.update(len(buffer)) |
|
|
| return download_target |
|
|
|
|
| def _convert_image_to_rgb(image): |
| return image.convert("RGB") |
|
|
|
|
| def available_models() -> List[str]: |
| """Returns the names of available CLIP models""" |
| return list(_MODELS.keys()) |
|
|
|
|
| def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", |
| download_root: str = None, vision_model_name: str = None, text_model_name: str = None, input_resolution: int = None): |
| if name in _MODELS: |
| model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) |
| model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution'] |
| elif os.path.isfile(name): |
| assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'" |
| model_path = name |
| model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution |
| else: |
| raise RuntimeError(f"Model {name} not found; available models = {available_models()}") |
|
|
| with open(model_path, 'rb') as opened_file: |
| |
| checkpoint = torch.load(opened_file, map_location="cpu") |
|
|
| model = create_model(model_name, checkpoint) |
| if str(device) == "cpu": |
| model.float() |
| else: |
| model.to(device) |
| return model, image_transform(model_input_resolution) |
|
|
|
|
| def load(model, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", clip_path=None, |
| bert_path=None, use_flash_attention=False): |
| """Load CLIP and BERT model weights |
| """ |
|
|
| bert_state_dict = torch.load(bert_path, map_location="cpu") if bert_path else None |
| clip_state_dict = torch.load(clip_path, map_location="cpu") if clip_path else None |
|
|
| restore_model(model, clip_state_dict, bert_state_dict, use_flash_attention).to(device) |
|
|
| if str(device) == "cpu": |
| model.float() |
| return model |
|
|
|
|
| def tokenize(texts: Union[str, List[str]], context_length: int = 52) -> torch.LongTensor: |
| """ |
| Returns the tokenized representation of given input string(s) |
| Parameters |
| ---------- |
| texts : Union[str, List[str]] |
| An input string or a list of input strings to tokenize |
| context_length : int |
| The context length to use; all baseline models use 52 as the context length |
| Returns |
| ------- |
| A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] |
| """ |
| if isinstance(texts, str): |
| texts = [texts] |
|
|
| all_tokens = [] |
| for text in texts: |
| all_tokens.append([_tokenizer.vocab['[CLS]']] + _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(text))[ |
| :context_length - 2] + [_tokenizer.vocab['[SEP]']]) |
|
|
| result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) |
|
|
| for i, tokens in enumerate(all_tokens): |
| assert len(tokens) <= context_length |
| result[i, :len(tokens)] = torch.tensor(tokens) |
|
|
| return result |
|
|
|
|
| def _convert_to_rgb(image): |
| return image.convert('RGB') |
|
|
|
|
| def image_transform(image_size=224): |
| transform = Compose([ |
| Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), |
| _convert_to_rgb, |
| ToTensor(), |
| Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), |
| ]) |
| return transform |
|
|
|
|
| def create_model(model_name, checkpoint=None): |
| vision_model, text_model = model_name.split('@') |
| |
| vision_model_config_file = Path( |
| __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json" |
| print('Loading vision model config from', vision_model_config_file) |
| assert os.path.exists(vision_model_config_file) |
|
|
| text_model_config_file = Path( |
| __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json" |
| print('Loading text model config from', text_model_config_file) |
| assert os.path.exists(text_model_config_file) |
|
|
| with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft: |
| model_info = json.load(fv) |
| for k, v in json.load(ft).items(): |
| model_info[k] = v |
| if isinstance(model_info['vision_layers'], str): |
| model_info['vision_layers'] = eval(model_info['vision_layers']) |
| print('Model info', model_info) |
| model = CLIP(**model_info) |
| convert_weights(model) |
| if checkpoint: |
| sd = checkpoint["state_dict"] |
| if next(iter(sd.items()))[0].startswith('module'): |
| sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k} |
| model.load_state_dict(sd) |
| return model |
|
|