DiffICM / 3_ControlNet /build_dataset.py
Qiyp's picture
code of stage1 & 3, remove large files
1633fcc
Raw
History Blame Contribute Delete
3.78 kB
import os
import random
import math
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class CustomCocoDataset(Dataset):
    """Caption-free image dataset yielding a random square crop plus a smaller "hint" copy.

    Every file directly inside ``img_folder`` whose extension is ``.jpg``,
    ``.jpeg`` or ``.png`` (case-insensitive) is one sample. Each item is a dict:

    * ``jpg``  -- a random ``img_size`` x ``img_size`` crop (see
      ``random_crop_arr``) converted with ``torchvision``'s ``to_tensor``;
    * ``hint`` -- the same crop bicubically resized to ``hint_size`` x
      ``hint_size``, also converted with ``to_tensor``;
    * ``txt``  -- always the empty string (no captions are used).
    """

    # Accepted file extensions, compared against the lower-cased file name.
    _EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, img_folder, img_size=512, hint_size=448):
        """Index the images in ``img_folder``.

        Args:
            img_folder: Directory scanned (non-recursively) for image files.
            img_size: Side length of the square crop returned as ``jpg``.
            hint_size: Side length of the resized crop returned as ``hint``.
        """
        self.img_folder = img_folder
        self.img_size = img_size
        self.hint_size = hint_size
        # Keep the full file names so each sample is reopened with its real
        # extension; sort so the index -> file mapping does not depend on the
        # arbitrary order of os.listdir.
        self.filenames = sorted(
            f for f in os.listdir(img_folder)
            if f.lower().endswith(self._EXTENSIONS)
        )
        # File stems (name without extension), kept for callers reading ``ids``.
        self.ids = [os.path.splitext(f)[0] for f in self.filenames]

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load image ``index`` and return ``dict(jpg=..., txt="", hint=...)``."""
        img_path = os.path.join(self.img_folder, self.filenames[index])
        # Context manager guarantees the file handle is released.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Random rescale + square crop; random_crop_arr returns a numpy array.
        cropped = random_crop_arr(image, self.img_size, min_crop_frac=0.8, max_crop_frac=1.0)
        # Back to PIL so torchvision's functional resize can operate on it.
        cropped = Image.fromarray(cropped)
        # Full-resolution crop as-is; only the hint is resized.
        jpg_image = transforms.functional.to_tensor(cropped)
        hint_image = transforms.functional.resize(
            cropped,
            (self.hint_size, self.hint_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        hint_image = transforms.functional.to_tensor(hint_image)
        # No captions are used: the prompt is always empty.
        prompt = ""
        return dict(jpg=jpg_image, txt=prompt, hint=hint_image)
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """Randomly rescale ``pil_image`` and cut a square ``image_size`` patch out of it.

    The shorter side is first rescaled to a length drawn uniformly from the
    inclusive range ``[ceil(image_size / max_crop_frac), ceil(image_size / min_crop_frac)]``.
    The ``image_size`` crop therefore spans roughly ``min_crop_frac`` to
    ``max_crop_frac`` of the shorter side. The square window is then placed
    at a uniformly random offset.

    Args:
        pil_image: Source PIL image.
        image_size: Side length, in pixels, of the square crop.
        min_crop_frac: Lower bound on (crop side / shorter side).
        max_crop_frac: Upper bound on (crop side / shorter side).

    Returns:
        A numpy array whose first two axes have length ``image_size``.
    """
    shortest_allowed = math.ceil(image_size / max_crop_frac)
    longest_allowed = math.ceil(image_size / min_crop_frac)
    target_short = random.randrange(shortest_allowed, longest_allowed + 1)

    # Halve with BOX filtering while the image is at least twice the target
    # size. Doing this by hand (instead of Pillow's ``reducing_gap``) gives a
    # cleaner downsample than one large bicubic step.
    img = pil_image
    while min(img.size) >= 2 * target_short:
        halved = tuple(dim // 2 for dim in img.size)
        img = img.resize(halved, resample=Image.BOX)

    # Final bicubic resize so the shorter side becomes target_short.
    factor = target_short / min(img.size)
    rescaled = tuple(round(dim * factor) for dim in img.size)
    img = img.resize(rescaled, resample=Image.BICUBIC)

    pixels = np.array(img)
    height, width = pixels.shape[0], pixels.shape[1]
    top = random.randrange(height - image_size + 1)
    left = random.randrange(width - image_size + 1)
    return pixels[top:top + image_size, left:left + image_size]
if __name__ == "__main__":
    # Smoke test: build the dataset, fetch one batch, and save side-by-side
    # previews of each crop and its hint image.
    import sys

    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader

    # The dataset folder may be given as the first CLI argument; otherwise the
    # previously hard-coded path is used.
    img_folder = (
        sys.argv[1] if len(sys.argv) > 1
        else "/home/t2vg-a100-G4-1/projects/dataset/LSDIR_raw/images/train"
    )
    dataset = CustomCocoDataset(img_folder)
    print(len(dataset))
    print(dataset[0])

    dataloader = DataLoader(
        dataset, batch_size=4, num_workers=2,
        pin_memory=True, drop_last=True)

    # Take a single batch from the DataLoader.
    batch = next(iter(dataloader))
    # Extract the crops, hint images and prompts from the batch.
    jpg_images = batch['jpg']
    hint_images = batch['hint']
    prompts = batch['txt']
    # Print the prompts.
    print(f"Prompt: {prompts}")

    # Titles report the sizes the dataset was actually configured with.
    crop_res = f"{dataset.img_size}x{dataset.img_size}"
    hint_res = f"{dataset.hint_size}x{dataset.hint_size}"
    # Visualize and save every sample of this batch.
    for i, (jpg, hint) in enumerate(zip(jpg_images, hint_images), start=1):
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.title(f"JPG Image {i} ({crop_res})")
        plt.imshow(jpg.permute(1, 2, 0))  # CHW -> HWC, the layout imshow expects
        plt.subplot(1, 2, 2)
        plt.title(f"Hint Image {i} ({hint_res})")
        plt.imshow(hint.permute(1, 2, 0))  # CHW -> HWC, the layout imshow expects
        # Save the figure to a file.
        plt.savefig(f'output_image_{i}.png')
        # Close the current figure to release its memory.
        plt.close()