# Source: TCube_Merging / clip / custom_medclip.py
# Uploaded by razaimam45 ("Upload 108 files", commit a96891a, verified).
import math
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet
from .clip import load, tokenize
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from data.imagnet_prompts import imagenet_classes
from data.fewshot_datasets import fewshot_datasets
from data.cls_to_names import *
# from data.medclip_datasets_clsnames import *
import os
# Set TOKENIZERS_PARALLELISM to "false" for this process (presumably read by the
# Hugging Face `tokenizers` library behind the MedCLIP tokenizer -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Module-level SimpleTokenizer instance; not referenced by the code visible in this file.
_tokenizer = _Tokenizer()
# CLIP cache directory; only referenced by the commented-out ClipImageEncoder below.
DOWNLOAD_ROOT='~/.cache/clip'
# class ClipImageEncoder(nn.Module):
# def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
# super(ClipImageEncoder, self).__init__()
# clip, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
# self.encoder = clip.visual
# del clip.transformer
# torch.cuda.empty_cache()
# self.cls_head = nn.Linear(embed_dim, n_class)
# @property
# def dtype(self):
# return self.encoder.conv1.weight.dtype
# def forward(self, image):
# x = self.encoder(image.type(self.dtype))
# output = self.cls_head(x)
# return output
class TextEncoder(nn.Module):
    """Run pre-embedded prompts through the MedCLIP text model and project them.

    The wrapped text model is called with ``inputs_embeds`` rather than token
    ids, so learnable context vectors can be spliced into the prompt
    before encoding.
    """

    def __init__(self, medclip_text_model):
        super().__init__()
        self.medclip_text_model = medclip_text_model

    def forward(self, prompts_embeddings, tokenized_prompts):
        """Return the projected, pooled text embedding for each prompt.

        Args:
            prompts_embeddings: token embeddings passed to the text model as
                ``inputs_embeds``.
            tokenized_prompts: mapping that provides the ``'attention_mask'``
                entry forwarded to the text model.

        Returns:
            The output of the text model's ``projection_head`` applied to the
            pooled hidden states.
        """
        encoder_out = self.medclip_text_model.model(
            inputs_embeds=prompts_embeddings,
            attention_mask=tokenized_prompts['attention_mask'],
        )
        hidden = encoder_out['hidden_states']
        # Use hidden states 1, 2 and the last one, stacked on a new leading
        # axis (layer, batch, seqlen, emb_dim per the original annotation).
        picked = torch.stack([hidden[1], hidden[2], hidden[-1]])
        # Put batch first, then average over the sequence axis and finally
        # over the three selected layers.
        pooled = picked.permute(1, 0, 2, 3).mean(2).mean(1)
        return self.medclip_text_model.projection_head(pooled)
class PromptLearner(nn.Module):
    """Learnable prompt context (and optionally class tokens) for MedCLIP.

    The module keeps a trainable context tensor ``ctx`` and, per class, the
    frozen embeddings of the tokens that surround it (``token_prefix`` and
    ``token_suffix``). ``forward`` concatenates them into full prompt
    embeddings that can be passed to the text encoder as ``inputs_embeds``.

    Args:
        medclip_model: model exposing ``dtype`` and
            ``text_model.tokenizer`` / ``text_model.model.embeddings.word_embeddings``.
        classnames: class names used to build one prompt per class.
        device: device the tokenized prompts are moved to.
        batch_size: if not ``None``, the context is replicated ``batch_size``
            times so every sample gets its own context (test-time adaptation).
        n_ctx: number of context vectors for random initialisation; overwritten
            by the word count of ``ctx_init`` when that is given.
        ctx_init: optional text used to initialise the context. Underscores are
            replaced by spaces; a ``[CLS]`` word marks where the class name
            is inserted and forces ``ctx_position='middle'``.
        ctx_position: ``'end'``, ``'middle'`` or ``'front'``: where the class
            name sits relative to the context.
        learned_cls: if true, a single learnable vector per class replaces the
            class-name tokens.
    """

    # Every prompt is padded/truncated to this many tokens by the tokenizer.
    _MAX_PROMPT_TOKENS = 25

    def __init__(self, medclip_model, classnames, device, batch_size=None, n_ctx=16,
                 ctx_init=None, ctx_position='end', learned_cls=False):
        super().__init__()
        n_cls = len(classnames)
        self.learned_cls = learned_cls
        dtype = medclip_model.dtype
        self.dtype = dtype
        ctx_dim = 768  # hardcoded for now!!! medclip_model.ln_final.weight.shape[0]
        self.ctx_dim = ctx_dim
        self.batch_size = batch_size
        self.device = device
        self.medclip_model = medclip_model
        # Always defined, so forward() with ctx_position='middle' also works
        # when the context is randomly initialised (no ctx_init).
        self.split_idx = None

        if ctx_init:
            # use given words to initialize context vectors
            print("Initializing the contect with given words: [{}]".format(ctx_init))
            ctx_init = ctx_init.replace("_", " ")
            if '[CLS]' in ctx_init:
                self.split_idx = ctx_init.split(" ").index("[CLS]")
                ctx_init = ctx_init.replace("[CLS] ", "")
                ctx_position = "middle"
            # NOTE(review): this counts whitespace-separated words; if the
            # tokenizer splits a word into several sub-tokens the slice below
            # covers fewer tokens than the text -- confirm for the prompts used.
            n_ctx = len(ctx_init.split(" "))
            _, embedding = self._embed_prompts(ctx_init)
            ctx_vectors = embedding[0, 1:1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            print("Random initialization: initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.prompt_prefix = prompt_prefix
        self.n_ctx = n_ctx
        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # batch-wise prompt tuning for test-time adaptation
        if self.batch_size is not None:
            ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1)  # (N, L, D)
        self.ctx_init_state = ctx_vectors.detach().clone()
        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        if self.learned_cls:
            print("Random initialization: initializing a learnable class token")
        classnames, name_lens, prompts, cls_vectors = self._build_class_prompts(classnames)
        if cls_vectors is not None:
            self.cls_init_state = cls_vectors.detach().clone()
            self.cls = nn.Parameter(cls_vectors)  # to be optimized

        tokenized_prompts, embedding = self._embed_prompts(prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # first token
        self.register_buffer("token_suffix", embedding[:, self._suffix_start():, :])

        self.ctx_init = ctx_init
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = ctx_position
        self.n_cls = n_cls
        self.classnames = classnames

    def _embed_prompts(self, prompts):
        """Tokenize ``prompts`` and return ``(tokenized, word_embeddings)``."""
        text_model = self.medclip_model.text_model
        tokenized = text_model.tokenizer(
            prompts,
            padding='max_length',
            max_length=self._MAX_PROMPT_TOKENS,
            truncation=True,
            return_tensors='pt',
        ).to(self.device)
        with torch.no_grad():
            embedding = text_model.model.embeddings.word_embeddings(
                tokenized['input_ids']
            ).type(self.dtype)
        return tokenized, embedding

    def _build_class_prompts(self, classnames):
        """Return ``(classnames, name_lens, prompts, cls_vectors)`` for the classes.

        ``cls_vectors`` is ``None`` unless ``learned_cls`` is set, in which case
        it holds one freshly initialised vector per class.
        """
        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            tokenizer = self.medclip_model.text_model.tokenizer
            # the two special tokens added by encode() are not counted
            name_lens = [len(tokenizer.encode(name)) - 2 for name in classnames]
            prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
            return classnames, name_lens, prompts, None

        # assume each learnable cls_token is only 1 word
        cls_vectors = torch.empty(len(classnames), 1, self.ctx_dim, dtype=self.dtype)
        nn.init.normal_(cls_vectors, std=0.02)
        cls_token = "X"
        name_lens = [1 for _ in classnames]
        prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
        return classnames, name_lens, prompts, cls_vectors

    def _suffix_start(self):
        """Index of the first suffix token: skips the first token and the context
        and, with ``learned_cls``, the placeholder class token as well."""
        return 1 + self.n_ctx + (1 if self.learned_cls else 0)

    def reset(self):
        """Restore ``ctx`` (and ``cls`` when learned) to their initial values."""
        # In-place writes to leaf parameters that require grad are only legal
        # with autograd disabled; guard here so callers need not remember to.
        with torch.no_grad():
            self.ctx.copy_(self.ctx_init_state)
            if self.learned_cls:
                self.cls.copy_(self.cls_init_state)

    def reset_classnames(self, classnames, arch):
        """Rebuild the per-class prompt tokens for a new list of class names.

        ``arch`` is accepted for interface compatibility and is not used.
        """
        self.n_cls = len(classnames)
        classnames, name_lens, prompts, cls_vectors = self._build_class_prompts(classnames)
        if cls_vectors is not None:
            # NOTE(review): only the stored initial state is refreshed here; the
            # ``cls`` parameter itself keeps its previous shape and values.
            self.cls_init_state = cls_vectors.detach().clone()

        tokenized_prompts, embedding = self._embed_prompts(prompts)

        self.token_prefix = embedding[:, :1, :]
        # Same slicing rule as in __init__, so the learned-cls placeholder
        # token is not duplicated into the suffix.
        self.token_suffix = embedding[:, self._suffix_start():, :]

        self.name_lens = name_lens
        self.tokenized_prompts = tokenized_prompts
        self.classnames = classnames

    def forward(self, init=None):
        """Assemble the full prompt embeddings for every class.

        Args:
            init: optional context tensor used instead of ``self.ctx``
                (used when computing CLIP directional loss).

        Returns:
            Prompt embeddings built by concatenating prefix, context, class
            tokens and suffix according to ``class_token_position``.

        Raises:
            ValueError: if ``class_token_position`` is not one of
                ``'end'``, ``'middle'`` or ``'front'``.
        """
        ctx = self.ctx if init is None else init
        if ctx.dim() == 2:
            # one shared context -> one copy per class
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        elif ctx.size(0) != self.n_cls or (self.batch_size is not None and ctx.dim() == 3):
            # per-sample context -> (batch, n_cls, n_ctx, dim). The batch_size
            # check keeps this correct when batch_size happens to equal n_cls.
            ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        if self.batch_size is not None:
            # This way only works for single-gpu setting (could pass batch size as an argument for forward())
            prefix = prefix.repeat(self.batch_size, 1, 1, 1)
            suffix = suffix.repeat(self.batch_size, 1, 1, 1)

        if self.learned_cls:
            assert self.class_token_position == "end"

        if self.class_token_position == "end":
            if self.learned_cls:
                # prefix | ctx | learned class token | suffix
                parts = [prefix, ctx, self.cls, suffix]
            else:
                # prefix | ctx | suffix (suffix starts with the class name)
                parts = [prefix, ctx, suffix]
            return torch.cat(parts, dim=-2)

        if self.class_token_position == "middle":
            # TODO: to work with a batch of prompts
            if self.split_idx is not None:
                # split the ctx at the position of [CLS] in `ctx_init`
                half_n_ctx = self.split_idx
            else:
                half_n_ctx = self.n_ctx // 2
            prompts = []
            for i, name_len in enumerate(self.name_lens):
                prompts.append(torch.cat(
                    [
                        prefix[i:i + 1, :, :],            # (1, 1, dim)
                        ctx[i:i + 1, :half_n_ctx, :],     # first part of the context
                        suffix[i:i + 1, :name_len, :],    # (1, name_len, dim)
                        ctx[i:i + 1, half_n_ctx:, :],     # rest of the context
                        suffix[i:i + 1, name_len:, :],    # (1, *, dim)
                    ],
                    dim=1,
                ))
            return torch.cat(prompts, dim=0)

        if self.class_token_position == "front":
            prompts = []
            for i, name_len in enumerate(self.name_lens):
                prompts.append(torch.cat(
                    [
                        prefix[i:i + 1, :, :],            # (1, 1, dim)
                        suffix[i:i + 1, :name_len, :],    # (1, name_len, dim)
                        ctx[i:i + 1, :, :],               # (1, n_ctx, dim)
                        suffix[i:i + 1, name_len:, :],    # (1, *, dim)
                    ],
                    dim=1,
                ))
            return torch.cat(prompts, dim=0)

        raise ValueError
from MedCLIP.medclip import MedCLIPModel, MedCLIPVisionModel, MedCLIPVisionModelViT
from MedCLIP.medclip import MedCLIPProcessor
def load_medclip_to_cpu(pretrained_dir="./MedCLIP/pretrained/medclip-vit/"):
    """Build a MedCLIP model with the ViT vision tower and load its weights.

    Args:
        pretrained_dir: directory passed to ``MedCLIPModel.from_pretrained``.
            Defaults to the checkpoint location this file has always used.

    Returns:
        The model in eval mode, with an extra ``dtype`` attribute taken from
        the vision patch-embedding projection weights.
    """
    model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
    # Load the weights once, from the requested directory. A preceding
    # argument-less ``from_pretrained()`` call is presumably redundant, since
    # its result would be replaced by this load -- TODO confirm against the
    # bundled MedCLIP implementation.
    model.from_pretrained(pretrained_dir)
    # for vit
    model.dtype = model.vision_model.model.embeddings.patch_embeddings.projection.weight.dtype
    # for Resnet
    # model.dtype = model.vision_model.model.conv1.weight.dtype
    model.eval()
    return model
class ClipTestTimeTuning(nn.Module):
    """MedCLIP wrapper whose text prompts are produced by a ``PromptLearner``.

    Image features come from the MedCLIP vision model, text features from the
    learnable prompts passed through ``TextEncoder``; ``inference`` returns the
    scaled similarity logits between them.
    """

    def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
                 n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
        super().__init__()
        self.device = device

        medclip = load_medclip_to_cpu()
        self.medclip_model = medclip
        self.dtype = medclip.dtype
        self.medclip_model = medclip.to(self.device)

        self.image_encoder = self.medclip_model.vision_model
        self.text_encoder = TextEncoder(self.medclip_model.text_model)
        self.logit_scale = self.medclip_model.logit_scale.data

        # prompt tuning
        self.prompt_learner = PromptLearner(
            self.medclip_model,
            classnames,
            self.device,
            batch_size,
            n_ctx,
            ctx_init,
            ctx_position,
            learned_cls,
        )
        self.criterion = criterion
        self.l2_norm_cal = False

    def reset(self):
        """Restore the initial state of the prompt learner (tunable prompt)."""
        self.prompt_learner.reset()

    def reset_classnames(self, classnames, arch):
        """Forward a new class-name list to the prompt learner."""
        self.prompt_learner.reset_classnames(classnames, arch)

    def get_text_features(self):
        """Encode the current prompts and return L2-normalised text features."""
        learner = self.prompt_learner
        encoded = self.text_encoder(learner(), learner.tokenized_prompts)
        normalised = encoded / encoded.norm(dim=-1, keepdim=True)
        # A single feature set is stacked and averaged over the new axis.
        return torch.mean(torch.stack([normalised], dim=0), dim=0)

    def inference(self, image):
        """Return similarity logits between ``image`` and every class prompt.

        When ``l2_norm_cal`` is set, the mean L2 distance of the text features
        from their centroid is also stored on the module: as a float in
        ``l2_norm_mean`` and as a tensor in ``l2_norm_mean_training``.
        """
        # Only the image encoder runs without gradient tracking.
        with torch.no_grad():
            image_features = self.image_encoder(image.type(self.dtype))

        text_features = self.get_text_features()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # [c-tpt] dispersion of the text features around their mean
        if self.l2_norm_cal:
            centred = text_features - text_features.mean(0)
            l2_norm_mean = torch.linalg.norm(centred, dim=-1).mean()
            # for saving to csv file
            self.l2_norm_mean = l2_norm_mean.item()
            # for training
            self.l2_norm_mean_training = l2_norm_mean

        scale = self.logit_scale.exp()
        return scale * image_features @ text_features.t()

    def forward(self, input):
        """Dispatch on the input type.

        A tuple of three views goes to ``contrast_prompt_tuning``, a 2-D tensor
        to ``directional_prompt_tuning``, anything else to ``inference``.
        NOTE(review): the first two methods are not defined in this class as
        shown here -- confirm they are provided elsewhere.
        """
        if isinstance(input, tuple):
            view_0, view_1, view_2 = input
            return self.contrast_prompt_tuning(view_0, view_1, view_2)
        if len(input.size()) == 2:
            return self.directional_prompt_tuning(input)
        return self.inference(input)
def get_coop(clip_arch, test_set, device, n_ctx, ctx_init=None, learned_cls=False):
    """Create a ``ClipTestTimeTuning`` model for the classes of ``test_set``.

    The class-name list is the module-level variable ``<test_set>_classes``
    (lower-cased).

    Args:
        clip_arch: forwarded as ``arch``.
        test_set: dataset name used to pick the class-name list.
        device: device the model is created on.
        n_ctx: number of context vectors.
        ctx_init: optional text used to initialise the context.
        learned_cls: whether class tokens are learned.

    Raises:
        NameError: if no ``<test_set>_classes`` variable exists in this module.
    """
    lookup = "{}_classes".format(test_set.lower())
    # Plain name lookup instead of eval(): the dataset name can only select an
    # existing module-level variable, never be executed as an expression.
    try:
        classnames = globals()[lookup]
    except KeyError:
        raise NameError(
            "name '{}' is not defined; no class-name list for test set {!r}".format(
                lookup, test_set
            )
        ) from None
    model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch,
                               n_ctx=n_ctx, ctx_init=ctx_init, learned_cls=learned_cls)
    return model