| import matplotlib.pyplot as plt |
| import numpy as np |
| from PIL import Image |
|
|
| import torch |
| from torchvision.utils import make_grid |
|
|
|
|
def plt_batch(
        photos: torch.Tensor,
        sketch: torch.Tensor,
        step: int,
        prompt: str,
        save_path: str,
        name: str,
        dpi: int = 300
):
    """Save a side-by-side figure of generated photos and rendered sketches.

    Both batches are tiled with ``make_grid``. The photo grid is passed with
    ``normalize=True``; the sketch grid is used as-is. The figure is written
    to ``{save_path}/{name}.png`` with the word-wrapped prompt as suptitle.

    Args:
        photos: Image batch for the left panel ("Generated sample").
        sketch: Image batch for the right panel; must have the same shape
            as ``photos``.
        step: Step count shown in the right panel's title.
        prompt: Text shown as the figure suptitle (wrapped by
            ``insert_newline``).
        save_path: Output directory.
        name: Output file stem; ``.png`` is appended.
        dpi: Resolution passed to ``savefig``.

    Raises:
        ValueError: If ``photos`` and ``sketch`` shapes differ.
    """
    if photos.shape != sketch.shape:
        raise ValueError("photos and sketch must have the same dimensions")

    # (tensor batch, normalize flag for make_grid, panel title)
    panels = (
        (photos, True, "Generated sample"),
        (sketch, False, f"Rendering result - {step} steps"),
    )

    fig = plt.figure()
    try:
        for position, (images, normalize, title) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, len(panels), position)
            # pad_value=2 becomes 510 after mul(255) and is clamped to 255,
            # so the padding between tiles renders white.
            grid = make_grid(images, normalize=normalize, pad_value=2)
            ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
            ax.imshow(ndarr)
            ax.axis("off")
            ax.set_title(title)

        fig.suptitle(insert_newline(prompt), fontsize=10)
        fig.tight_layout()
        fig.savefig(f"{save_path}/{name}.png", dpi=dpi)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls (e.g. during training) don't accumulate open figures.
        plt.close(fig)
|
|
|
|
def plt_triplet(
        photos: torch.Tensor,
        sketch: torch.Tensor,
        style: torch.Tensor,
        step: int,
        prompt: str,
        save_path: str,
        name: str,
        dpi: int = 300
):
    """Save a three-panel figure: generated photos, style images, rendered sketches.

    Each batch is tiled with ``make_grid``. Only the photo grid is passed
    with ``normalize=True``. The figure is written to
    ``{save_path}/{name}.png`` with the word-wrapped prompt as suptitle.

    Args:
        photos: Image batch for the left panel ("Generated sample").
        sketch: Image batch for the right panel; must have the same shape
            as ``photos``.
        style: Image batch for the middle panel ("Style"); its shape is
            not checked.
        step: Step count shown in the right panel's title.
        prompt: Text shown as the figure suptitle (wrapped by
            ``insert_newline``).
        save_path: Output directory.
        name: Output file stem; ``.png`` is appended.
        dpi: Resolution passed to ``savefig``.

    Raises:
        ValueError: If ``photos`` and ``sketch`` shapes differ.
    """
    if photos.shape != sketch.shape:
        raise ValueError("photos and sketch must have the same dimensions")

    # (tensor batch, normalize flag for make_grid, panel title)
    panels = (
        (photos, True, "Generated sample"),
        (style, False, "Style"),
        (sketch, False, f"Rendering result - {step} steps"),
    )

    fig = plt.figure()
    try:
        for position, (images, normalize, title) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, len(panels), position)
            # pad_value=2 becomes 510 after mul(255) and is clamped to 255,
            # so the padding between tiles renders white.
            grid = make_grid(images, normalize=normalize, pad_value=2)
            ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
            ax.imshow(ndarr)
            ax.axis("off")
            ax.set_title(title)

        fig.suptitle(insert_newline(prompt), fontsize=10)
        fig.tight_layout()
        fig.savefig(f"{save_path}/{name}.png", dpi=dpi)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
def insert_newline(string, point=9):
    """Wrap ``string`` onto lines of at most ``point`` words each.

    A string with ``point`` words or fewer is returned unchanged. A longer
    string is split on whitespace and re-joined with single spaces within a
    line and newlines between lines, so original runs of whitespace
    (including existing newlines) are collapsed.

    Args:
        string: Text to wrap.
        point: Maximum number of words per line; must be at least 1.

    Returns:
        The wrapped text.

    Raises:
        ValueError: If ``point`` is less than 1 (a non-positive chunk size
            would otherwise crash in ``range`` or silently drop all text).
    """
    if point < 1:
        raise ValueError(f"point must be a positive integer, got {point!r}")

    words = string.split()
    if len(words) <= point:
        return string

    return "\n".join(
        " ".join(words[start:start + point])
        for start in range(0, len(words), point)
    )
|
|
|
|
def log_tensor_img(inputs, output_dir, output_prefix="input", norm=False, dpi=300):
    """Tile an image batch with ``make_grid`` and save it as a PNG.

    Args:
        inputs: Image batch passed to ``make_grid``.
        output_dir: Output directory.
        output_prefix: Output file stem; ``.png`` is appended.
        norm: Forwarded to ``make_grid(normalize=...)``.
        dpi: Resolution passed to ``savefig``.
    """
    # pad_value=2 becomes 510 after mul(255) and is clamped to 255,
    # so the padding between tiles renders white.
    grid = make_grid(inputs, normalize=norm, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    # Use a dedicated figure: drawing on (and then closing) whatever figure
    # happens to be current would clobber a caller's open plot.
    fig, ax = plt.subplots()
    try:
        ax.imshow(ndarr)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(f"{output_dir}/{output_prefix}.png", dpi=dpi)
    finally:
        plt.close(fig)
|
|
|
|
def plt_tensor_img(tensor, title, save_path, name, dpi=500):
    """Tile an image batch (``normalize=True``) and save it with a title.

    Unlike ``save_tensor_img`` this does not call ``tight_layout``.

    Args:
        tensor: Image batch passed to ``make_grid``.
        title: Axes title; converted to ``str``.
        save_path: Output directory.
        name: Output file stem; ``.png`` is appended.
        dpi: Resolution passed to ``savefig``.
    """
    # pad_value=2 becomes 510 after mul(255) and is clamped to 255,
    # so the padding between tiles renders white.
    grid = make_grid(tensor, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    # Dedicated figure so a caller's current figure is neither drawn on
    # nor closed.
    fig, ax = plt.subplots()
    try:
        ax.imshow(ndarr)
        ax.axis("off")
        ax.set_title(str(title))
        fig.savefig(f"{save_path}/{name}.png", dpi=dpi)
    finally:
        plt.close(fig)
|
|
|
|
def save_tensor_img(tensor, save_path, name, dpi=500):
    """Tile an image batch (``normalize=True``) and save it as an untitled PNG.

    Args:
        tensor: Image batch passed to ``make_grid``.
        save_path: Output directory.
        name: Output file stem; ``.png`` is appended.
        dpi: Resolution passed to ``savefig``.
    """
    # pad_value=2 becomes 510 after mul(255) and is clamped to 255,
    # so the padding between tiles renders white.
    grid = make_grid(tensor, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    # Dedicated figure so a caller's current figure is neither drawn on
    # nor closed.
    fig, ax = plt.subplots()
    try:
        ax.imshow(ndarr)
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(f"{save_path}/{name}.png", dpi=dpi)
    finally:
        plt.close(fig)
|
|
|
|
def plt_attn(attn, threshold_map, inputs, inds, output_path, dpi=None):
    """Save a three-panel attention diagnostic figure.

    Panels: the input grid with ``inds`` overlaid as red points, the raw
    attention map (fixed colour range [0, 1]), and ``threshold_map``
    min-max normalised to [0, 1] with the same points overlaid.

    Args:
        attn: Map shown in the middle panel with ``vmin=0, vmax=1``.
        threshold_map: Map normalised to [0, 1] for the right panel. A
            constant map (zero range) yields NaNs, which are shown as 0.
        inputs: Image batch tiled with ``make_grid(normalize=True)``; must
            support ``.cpu()``.
        inds: Point coordinates; column 1 is plotted as x and column 0 as y
            (presumably (row, col) pixel indices — confirm with callers).
        output_path: Destination file path.
        dpi: Optional resolution for ``savefig``; ``None`` keeps
            matplotlib's default, matching the previous behaviour.
    """
    fig = plt.figure(figsize=(10, 5))
    try:
        ax = fig.add_subplot(1, 3, 1)
        main_im = make_grid(inputs, normalize=True, pad_value=2)
        main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))  # CHW -> HWC
        ax.imshow(main_im, interpolation='nearest')
        ax.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        ax.set_title("input img")
        ax.axis("off")

        ax = fig.add_subplot(1, 3, 2)
        ax.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
        ax.set_title("attn map")
        ax.axis("off")

        ax = fig.add_subplot(1, 3, 3)
        low, high = threshold_map.min(), threshold_map.max()
        # A constant map gives 0/0; silence numpy's warnings and let
        # nan_to_num below map the resulting NaNs to 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            threshold_map_ = (threshold_map - low) / (high - low)
        ax.imshow(np.nan_to_num(threshold_map_), interpolation='nearest', vmin=0, vmax=1)
        ax.set_title("prob softmax")
        ax.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        ax.axis("off")

        fig.tight_layout()
        save_kwargs = {} if dpi is None else {"dpi": dpi}
        fig.savefig(output_path, **save_kwargs)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
def fix_image_scale(im):
    """Center an image on a square white canvas with a margin and return it as RGB.

    The canvas side is ``max(height, width) + 20``. The image is placed at
    the centre, and the remaining area is white. Grayscale inputs are
    replicated to three channels. 4-channel (RGBA) inputs are composited
    over white using their alpha channel. Previously, both of these input
    types failed on the broadcast into the RGB canvas.

    Args:
        im: Image convertible with ``np.asarray`` (e.g. a PIL image),
            presumably with 8-bit values in [0, 255] — confirm for float
            inputs.

    Returns:
        A ``PIL.Image`` built from a ``uint8`` HxWx3 array.
    """
    im_np = np.asarray(im, dtype=np.float64) / 255
    if im_np.ndim == 2:
        # Grayscale: replicate the single channel to fit the RGB canvas.
        im_np = np.stack([im_np] * 3, axis=-1)
    elif im_np.ndim == 3 and im_np.shape[2] == 4:
        # RGBA: composite over white so transparent pixels blend into the
        # white background instead of breaking the 3-channel assignment.
        alpha = im_np[..., 3:]
        im_np = im_np[..., :3] * alpha + (1.0 - alpha)

    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    canvas = np.ones((max_len, max_len, 3))
    top = max_len // 2 - height // 2
    left = max_len // 2 - width // 2
    canvas[top:top + height, left:left + width] = im_np

    canvas = canvas / canvas.max() * 255
    # Round rather than truncate: the /255 ... *255 round trip can land just
    # below an integer, and astype(uint8) alone would shift such values down.
    canvas = np.clip(np.rint(canvas), 0, 255).astype(np.uint8)
    return Image.fromarray(canvas)
|
|