|
|
| import math |
| from typing import List, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.models import resnet50, ResNet |
|
|
| from .clip import load, tokenize |
| from .simple_tokenizer import SimpleTokenizer as _Tokenizer |
| from data.imagnet_prompts import imagenet_classes |
| from data.fewshot_datasets import fewshot_datasets |
| from data.cls_to_names import * |
| |
| import os |
# Set before the tokenizer below is instantiated.
# NOTE(review): presumably read by the HuggingFace `tokenizers` library to
# turn off its internal parallelism -- confirm against that library's docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Module-level tokenizer instance (SimpleTokenizer from .simple_tokenizer);
# it is not referenced by any definition visible in this file.
_tokenizer = _Tokenizer()


# Cache directory string; not referenced by any definition visible in this file.
DOWNLOAD_ROOT='~/.cache/clip'
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
|
|
class TextEncoder(nn.Module):
    """Run a MedCLIP text model on prompt *embeddings* instead of token ids.

    The wrapped object must expose ``.model`` (called with ``inputs_embeds``
    and ``attention_mask``) and ``.projection_head``.
    """

    def __init__(self, medclip_text_model):
        super().__init__()
        self.medclip_text_model = medclip_text_model

    def forward(self, prompts_embeddings, tokenized_prompts):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts_embeddings: Embedded prompts, passed as ``inputs_embeds``.
            tokenized_prompts: Mapping that provides ``'attention_mask'``.

        Returns:
            The output of ``projection_head`` applied to the pooled states.
        """
        mask = tokenized_prompts['attention_mask']
        output = self.medclip_text_model.model(inputs_embeds=prompts_embeddings, attention_mask=mask)

        # Pool the hidden states at indices 1, 2 and -1.
        # assumes each hidden state is (batch, seq, dim) -- TODO confirm
        hidden = output['hidden_states']
        selected = torch.stack([hidden[1], hidden[2], hidden[-1]])
        # Move the stacked axis behind axis 0, then average axis 2 and axis 1.
        pooled = selected.permute(1, 0, 2, 3).mean(2).mean(1)

        return self.medclip_text_model.projection_head(pooled)
|
|
|
|
class PromptLearner(nn.Module):
    """Learnable prompt context (and optionally class tokens) in embedding space.

    Prompts are assembled from three pieces of the word-embedding sequence of
    ``"<prompt_prefix> <class name>."``: the first token (``token_prefix``),
    the learnable context ``ctx`` that replaces the prefix words, and the
    remaining tokens (``token_suffix``, which start with the class name).
    With ``learned_cls=True`` the class name is replaced by one learnable
    vector per class (``cls``).
    """

    def __init__(self, medclip_model, classnames, device, batch_size=None, n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
        """Build the initial context, class prompts and frozen prefix/suffix.

        Args:
            medclip_model: Object exposing ``dtype`` and
                ``text_model.tokenizer`` / ``text_model.model.embeddings.word_embeddings``.
            classnames: Class names; underscores are replaced by spaces.
            device: Device the tokenized prompts are moved to.
            batch_size: If given, the context is repeated ``batch_size`` times
                along a new leading axis.
            n_ctx: Number of context vectors for random initialization
                (overridden by the word count of ``ctx_init`` when given).
            ctx_init: Optional initial context words (``_`` means space). A
                ``[CLS]`` word marks where the class name is inserted and
                forces ``ctx_position`` to ``"middle"``.
            ctx_position: ``"end"``, ``"middle"`` or ``"front"``.
            learned_cls: Learn one class vector per class instead of using
                the embedded class name.
        """
        super().__init__()
        n_cls = len(classnames)
        self.learned_cls = learned_cls
        dtype = medclip_model.dtype
        self.dtype = dtype
        ctx_dim = 768
        self.ctx_dim = ctx_dim
        self.batch_size = batch_size
        self.device = device
        self.medclip_model = medclip_model
        # Always defined, so the "middle" layout also works with random
        # initialization (previously only set when ctx_init was given).
        self.split_idx = None

        if ctx_init:
            print("Initializing the contect with given words: [{}]".format(ctx_init))
            ctx_init = ctx_init.replace("_", " ")
            if '[CLS]' in ctx_init:
                ctx_list = ctx_init.split(" ")
                self.split_idx = ctx_list.index("[CLS]")
                ctx_init = ctx_init.replace("[CLS] ", "")
                ctx_position = "middle"
            # NOTE(review): n_ctx counts whitespace-separated words; this
            # matches the token count only if the tokenizer maps every
            # context word to a single token -- confirm for the words used.
            n_ctx = len(ctx_init.split(" "))

            _, embedding = self._embed_prompts(ctx_init)
            # Skip the first token, keep the n_ctx context positions.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            print("Random initialization: initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.prompt_prefix = prompt_prefix
        self.n_ctx = n_ctx

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        if self.batch_size is not None:
            ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1)
        # Snapshot used by reset() to restore the initial context.
        self.ctx_init_state = ctx_vectors.detach().clone()
        self.ctx = nn.Parameter(ctx_vectors)

        if self.learned_cls:
            print("Random initialization: initializing a learnable class token")
        classnames, name_lens, prompts, cls_vectors = self._class_prompts(classnames)
        if cls_vectors is not None:
            self.cls_init_state = cls_vectors.detach().clone()
            self.cls = nn.Parameter(cls_vectors)

        tokenized_prompts, embedding = self._embed_prompts(prompts)

        # Frozen pieces surrounding the learnable context.
        self.register_buffer("token_prefix", embedding[:, :1, :])
        self.register_buffer("token_suffix", embedding[:, self._suffix_start():, :])

        self.ctx_init = ctx_init
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = ctx_position
        self.n_cls = n_cls
        self.classnames = classnames

    def _suffix_start(self):
        """Return the index of the first suffix token in the embedded prompt."""
        # One first token + n_ctx context tokens (+ one placeholder token
        # that the learnable class vector replaces).
        return 1 + self.n_ctx + (1 if self.learned_cls else 0)

    def _class_prompts(self, classnames):
        """Return ``(classnames, name_lens, prompts, cls_vectors)`` for the classes.

        ``cls_vectors`` is ``None`` unless ``learned_cls`` is set.
        """
        if not self.learned_cls:
            tokenizer = self.medclip_model.text_model.tokenizer
            classnames = [name.replace("_", " ") for name in classnames]
            # "- 2" drops the two extra tokens encode() adds around the name.
            name_lens = [len(tokenizer.encode(name)) - 2 for name in classnames]
            prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
            return classnames, name_lens, prompts, None

        # One learnable vector per class; "X" is only a placeholder word.
        cls_vectors = torch.empty(len(classnames), 1, self.ctx_dim, dtype=self.dtype)
        nn.init.normal_(cls_vectors, std=0.02)
        cls_token = "X"
        name_lens = [1 for _ in classnames]
        prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
        return classnames, name_lens, prompts, cls_vectors

    def _embed_prompts(self, prompts):
        """Tokenize ``prompts`` and return ``(tokenized, word_embeddings)``."""
        text_model = self.medclip_model.text_model
        tokenized = text_model.tokenizer(prompts, padding='max_length', max_length=25, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            embedding = text_model.model.embeddings.word_embeddings(tokenized['input_ids']).type(self.dtype)
        return tokenized, embedding

    def reset(self):
        """Restore ``ctx`` (and ``cls``) to their stored initial states."""
        # In-place copies into leaf parameters that require grad are only
        # legal with autograd disabled, so do not rely on the caller for it.
        with torch.no_grad():
            self.ctx.copy_(self.ctx_init_state)
            if self.learned_cls:
                self.cls.copy_(self.cls_init_state)

    def reset_classnames(self, classnames, arch):
        """Rebuild class prompts, prefix and suffix for a new set of classes.

        Args:
            classnames: New class names.
            arch: Unused; kept for interface compatibility.
        """
        self.n_cls = len(classnames)
        classnames, name_lens, prompts, cls_vectors = self._class_prompts(classnames)
        if cls_vectors is not None:
            self.cls_init_state = cls_vectors.detach().clone()
            if self.cls.shape != cls_vectors.shape:
                # The class count changed: the old parameter can no longer be
                # concatenated with the new prefix/suffix, so replace it.
                # An optimizer holding the old parameter must be rebuilt.
                self.cls = nn.Parameter(cls_vectors.to(self.cls.device))

        tokenized_prompts, embedding = self._embed_prompts(prompts)

        self.token_prefix = embedding[:, :1, :]
        # Same offset as in __init__ (skips the placeholder when learned_cls).
        self.token_suffix = embedding[:, self._suffix_start():, :]

        self.name_lens = name_lens
        self.tokenized_prompts = tokenized_prompts
        self.classnames = classnames

    def forward(self, init=None):
        """Assemble the prompt embeddings.

        Args:
            init: Optional context to use instead of ``self.ctx``.

        Returns:
            Prompt embeddings concatenated along the token axis; in batched
            mode (``batch_size`` set) they carry an extra leading axis.

        Raises:
            ValueError: If ``class_token_position`` is not supported.
        """
        ctx = init if init is not None else self.ctx
        batched = self.batch_size is not None

        if ctx.dim() == 2:
            # Shared context: one copy per class (and per batch item).
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
            if batched:
                ctx = ctx.unsqueeze(0).expand(self.batch_size, -1, -1, -1)
        elif ctx.dim() == 3 and (batched or ctx.size(0) != self.n_cls):
            # Leading axis is the batch; add the class axis. In batched mode
            # this must happen even when batch_size happens to equal n_cls.
            ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        if batched:
            prefix = prefix.repeat(self.batch_size, 1, 1, 1)
            suffix = suffix.repeat(self.batch_size, 1, 1, 1)

        if self.learned_cls:
            assert self.class_token_position == "end"

        if self.class_token_position == "end":
            pieces = [prefix, ctx]
            if self.learned_cls:
                cls = self.cls
                if batched:
                    cls = cls.unsqueeze(0).expand(self.batch_size, -1, -1, -1)
                pieces.append(cls)
            pieces.append(suffix)
            prompts = torch.cat(pieces, dim=-2)
        elif self.class_token_position == "middle":
            # NOTE(review): the per-class slicing below indexes axis 0 as the
            # class axis, i.e. it is written for the unbatched layout.
            if self.split_idx is not None:
                half_n_ctx = self.split_idx
            else:
                half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompts.append(
                    torch.cat([prefix_i, ctx_i_half1, class_i, ctx_i_half2, suffix_i], dim=1)
                )
            prompts = torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            # NOTE(review): same unbatched-layout assumption as "middle".
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompts.append(torch.cat([prefix_i, class_i, ctx_i, suffix_i], dim=1))
            prompts = torch.cat(prompts, dim=0)
        else:
            raise ValueError(
                "Unsupported class token position: {!r}".format(self.class_token_position)
            )

        return prompts
|
|
| from MedCLIP.medclip import MedCLIPModel, MedCLIPVisionModel, MedCLIPVisionModelViT |
| from MedCLIP.medclip import MedCLIPProcessor |
|
|
def load_medclip_to_cpu():
    """Build a ViT-based MedCLIP model, load its weights and return it in eval mode.

    The returned model gets an extra ``dtype`` attribute taken from the
    weight of the vision patch-embedding projection.
    """
    local_checkpoint_dir = "./MedCLIP/pretrained/medclip-vit/"

    model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
    # NOTE(review): weights are loaded twice -- once with from_pretrained's
    # defaults, then from the local directory. The first call looks
    # redundant; confirm against MedCLIP's from_pretrained before removing.
    model.from_pretrained()
    model.from_pretrained(local_checkpoint_dir)

    patch_projection = model.vision_model.model.embeddings.patch_embeddings.projection
    model.dtype = patch_projection.weight.dtype

    model.eval()
    return model
|
|
class ClipTestTimeTuning(nn.Module):
    """MedCLIP image/text encoders combined with a learnable prompt.

    ``inference`` returns scaled similarities between normalized image
    features and the normalized text features of the current prompts.
    """

    def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
                 n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
        """Load MedCLIP, move it to ``device`` and create the prompt learner.

        ``arch`` is accepted but not used by this constructor.
        """
        super().__init__()
        self.device = device

        backbone = load_medclip_to_cpu()
        self.dtype = backbone.dtype
        self.medclip_model = backbone.to(self.device)

        self.image_encoder = self.medclip_model.vision_model
        self.text_encoder = TextEncoder(self.medclip_model.text_model)
        self.logit_scale = self.medclip_model.logit_scale.data

        self.prompt_learner = PromptLearner(
            self.medclip_model, classnames, self.device, batch_size,
            n_ctx, ctx_init, ctx_position, learned_cls,
        )
        self.criterion = criterion
        # When True, inference() also records the prompt-feature dispersion.
        self.l2_norm_cal = False

    def reset(self):
        """Restore the prompt learner's parameters to their initial state."""
        self.prompt_learner.reset()

    def reset_classnames(self, classnames, arch):
        """Forward a new class list to the prompt learner."""
        self.prompt_learner.reset_classnames(classnames, arch)

    def get_text_features(self):
        """Return L2-normalized text features for the current prompts."""
        learner = self.prompt_learner
        raw = self.text_encoder(learner(), learner.tokenized_prompts)
        normalized = raw / raw.norm(dim=-1, keepdim=True)
        # Averaging over a single stacked entry leaves the values unchanged.
        return normalized.unsqueeze(0).mean(dim=0)

    def inference(self, image):
        """Compute image-to-text logits for ``image``.

        The image encoder runs without gradient tracking; the text features
        are computed outside that context.
        """
        with torch.no_grad():
            image_features = self.image_encoder(image.type(self.dtype))

        text_features = self.get_text_features()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        if self.l2_norm_cal:
            # Mean L2 distance of each text feature to their centroid.
            centered = text_features - text_features.mean(0)
            dispersion = torch.linalg.norm(centered, dim=-1).mean()
            self.l2_norm_mean = dispersion.item()      # plain float
            self.l2_norm_mean_training = dispersion    # tensor

        scaled_image_features = self.logit_scale.exp() * image_features
        return scaled_image_features @ text_features.t()

    def forward(self, input):
        """Dispatch on the input: tuple of three views, 2-D tensor, or image.

        NOTE(review): ``contrast_prompt_tuning`` and
        ``directional_prompt_tuning`` are not defined in the visible class,
        so only the final ``inference`` path is usable as shown.
        """
        if isinstance(input, tuple):
            view_0, view_1, view_2 = input
            return self.contrast_prompt_tuning(view_0, view_1, view_2)
        if len(input.size()) == 2:
            return self.directional_prompt_tuning(input)
        return self.inference(input)
|
|
|
|
def get_coop(clip_arch, test_set, device, n_ctx, ctx_init=None, learned_cls=False):
    """Create a ``ClipTestTimeTuning`` model for the classes of ``test_set``.

    The class list is the module-level name ``"<test_set.lower()>_classes"``.

    Args:
        clip_arch: Passed to ``ClipTestTimeTuning`` as ``arch``.
        test_set: Dataset name used to select the class list.
        device: Device for the model.
        n_ctx: Number of context tokens.
        ctx_init: Optional initial context words.
        learned_cls: Whether class tokens are learned.

    Raises:
        NameError: If no matching ``*_classes`` name exists in this module.
    """
    registry_key = "{}_classes".format(test_set.lower())
    # Plain dictionary lookup instead of eval(): the dataset name is never
    # executed as code. NameError is kept, as eval() raised it before.
    try:
        classnames = globals()[registry_key]
    except KeyError:
        raise NameError("name '{}' is not defined".format(registry_key)) from None

    model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch,
                               n_ctx=n_ctx, ctx_init=ctx_init, learned_cls=learned_cls)

    return model
|
|
|
|