| | |
| | import sys |
| | sys.path.append('versatile_diffusion') |
| | import os |
| | import os.path as osp |
| | import PIL |
| | from PIL import Image |
| | from pathlib import Path |
| | import numpy as np |
| | import numpy.random as npr |
| |
|
| | import torch |
| | import torchvision.transforms as tvtrans |
| | from lib.cfg_helper import model_cfg_bank |
| | from lib.model_zoo import get_model |
| | from lib.model_zoo.ddim_vd import DDIMSampler_VD |
| | from lib.experiments.sd_default import color_adjust, auto_merge_imlist |
| | from torch.utils.data import DataLoader, Dataset |
| |
|
| | from lib.model_zoo.vd import VD |
| | from lib.cfg_holder import cfg_unique_holder as cfguh |
| | from lib.cfg_helper import get_command_line_args, cfg_initiates, load_cfg_yaml |
| | import matplotlib.pyplot as plt |
| | from skimage.transform import resize, downscale_local_mean |
| |
|
| |
|
| | def regularize_image(x): |
| | BICUBIC = PIL.Image.Resampling.BICUBIC |
| | if isinstance(x, str): |
| | x = Image.open(x).resize([512, 512], resample=BICUBIC) |
| | x = tvtrans.ToTensor()(x) |
| | elif isinstance(x, PIL.Image.Image): |
| | x = x.resize([512, 512], resample=BICUBIC) |
| | x = tvtrans.ToTensor()(x) |
| | elif isinstance(x, np.ndarray): |
| | x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC) |
| | x = tvtrans.ToTensor()(x) |
| | elif isinstance(x, torch.Tensor): |
| | pass |
| | else: |
| | assert False, 'Unknown image type' |
| |
|
| | assert (x.shape[1]==512) & (x.shape[2]==512), \ |
| | 'Wrong image size' |
| | return x |
| |
|
| | |
# --- Model / sampler setup (runs once at import time) ---------------------

cfgm_name = 'vd_noema'
pth = 'versatile_diffusion/pretrained/vd-four-flow-v1-0-fp16-deprecated.pth'
cfgm = model_cfg_bank()(cfgm_name)
net = get_model()(cfgm)
# strict=False: the fp16 checkpoint does not cover every module in the net.
sd = torch.load(pth, map_location='cpu')
net.load_state_dict(sd, strict=False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Use the computed `device` rather than hard-coded .cuda(0) so the script
# does not crash on CPU-only machines.
net.clip.to(device)
net.autokl.to(device)

# Build the DDIM sampler directly instead of rebinding the class name.
sampler = DDIMSampler_VD(net)
sampler.model.model.diffusion_model.device = device
sampler.model.model.diffusion_model.half().to(device)
batch_size = 1

# Sampling hyper-parameters shared by generate_image().
n_samples = 1
ddim_steps = 50
ddim_eta = 0
scale = 7.5
xtype = 'image'      # we decode to images
ctype = 'prompt'     # conditioning type used for the text stream
net.autokl.half()

torch.manual_seed(0)  # deterministic sampling noise

net.clip = net.clip.to(device)
| |
|
def generate_image(sub, image_id, annot, strength=0.75, mixing=0.4):
    """Reconstruct one NSD test stimulus for subject ``sub``.

    Starts from the VDVAE low-level guess for ``image_id``, noises it part of
    the way (``strength``), and DDIM-decodes it conditioned on predicted
    CLIP-vision features and on the CLIP-text encoding of ``annot`` (which
    replaces the predicted text features for this image).

    Parameters:
        sub:       subject number (used to locate per-subject feature files).
        image_id:  index of the test image / row in the feature matrices.
        annot:     caption used as the text conditioning for this image.
        strength:  fraction of the DDIM schedule used to noise the init latent.
        mixing:    vision/text mixing weight; passed as mixed_ratio=1-mixing.

    Returns:
        (test_img_path, output_path): file paths of the saved ground-truth
        stimulus and of the reconstruction.
    """
    # Predicted CLIP-vision features for the whole test set (one row/image).
    pred_vision = np.load(
        f'data/predicted_features/subj{sub:02d}/nsd_clipvision_predtest_nsdgeneral.npy')
    pred_vision = torch.tensor(pred_vision).half().to(device)

    # Low-level initial guess from the VDVAE stage, plus the ground-truth
    # stimulus (saved only so the caller can display it side by side).
    zim = Image.open(f'results/vdvae/subj{sub:02d}/{image_id}.png')
    test_img = Image.open(f'data/nsddata_stimuli/test_images/{image_id}.png')
    test_img_path = 'scripts/images/original_image.png'
    test_img.save(test_img_path)

    zim = regularize_image(zim)
    zin = zim * 2 - 1          # [0, 1] -> [-1, 1] expected by the autoencoder
    zin = zin.unsqueeze(0).to(device).half()
    init_latent = net.autokl_encode(zin)

    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
    t_enc = int(strength * ddim_steps)  # how far along the schedule to noise
    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]).to(device))
    z_enc = z_enc.to(device)

    # Unconditional embeddings (empty text, black image) for CFG.
    dummy = ''
    utx = net.clip_encode_text(dummy).to(device).half()
    dummy = torch.zeros((1, 3, 224, 224)).to(device)
    uim = net.clip_encode_vision(dummy).to(device).half()

    # Load the predicted text features once (the original loaded them twice),
    # overwrite this image's row with the encoding of `annot`, then pick out
    # the conditioning rows for this image.
    pred_text = np.load(
        f'data/predicted_features/subj{sub:02d}/nsd_cliptext_predtest_nsdgeneral.npy')
    with torch.no_grad():
        pred_text[image_id] = net.clip_encode_text([annot]).to('cpu').numpy().mean(0)
    pred_text = torch.tensor(pred_text).half().to(device)
    ctx = pred_text[image_id].unsqueeze(0).to(device)
    cim = pred_vision[image_id].unsqueeze(0).to(device)

    z = sampler.decode_dc(
        x_latent=z_enc,
        first_conditioning=[uim, cim],
        second_conditioning=[utx, ctx],
        t_start=t_enc,
        unconditional_guidance_scale=7.5,
        xtype='image',
        first_ctype='vision',
        second_ctype='prompt',
        mixed_ratio=(1 - mixing),
    )

    z = z.to(device).half()
    x = net.autokl_decode(z)

    # NOTE(review): the original color-adjust branch was unreachable
    # (color_adj was hard-coded to 'None', and it referenced an undefined
    # `color_adj_to`), so only the plain [-1,1] -> [0,1] path is kept.
    x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
    x = [tvtrans.ToPILImage()(xi) for xi in x]

    output_path = 'scripts/images/reconstructed.png'
    x[0].save(output_path)

    return test_img_path, output_path