# DiffICM / 3_ControlNet / test_dataset.py
# Source commit 1633fcc (Qiyp): "code of stage1 & 3, remove large files"
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
import os
from PIL import Image
from torchvision import transforms
class CustomCocoDataset(Dataset):
    """COCO-captions dataset returning one crop at two resolutions plus its text.

    Each item is a dict with three keys:

    * ``jpg``  -- the randomly cropped image resized with size ``jpg_size``.
    * ``hint`` -- the *same* crop resized with size ``hint_size``.
    * ``txt``  -- every caption of the image, newlines replaced by spaces,
      joined with single spaces.

    The indexed image ids are the keys of ``COCO.imgToAnns`` (presumably the
    images that carry at least one annotation -- confirm against pycocotools).

    Args:
        json_file: Path to the COCO captions annotation file.
        img_folder: Directory that contains the image files named in the
            annotation file.
        common_transform: Optional callable applied independently to both the
            ``jpg`` and the ``hint`` image after resizing.
        crop_scale: ``(min, max)`` area fraction passed as ``scale`` to
            ``RandomResizedCrop.get_params``.
        jpg_size: ``size`` argument used when resizing the ``jpg`` image.
        hint_size: ``size`` argument used when resizing the ``hint`` image.
    """

    def __init__(self, json_file, img_folder, common_transform=None, *,
                 crop_scale=(0.95, 1.0), jpg_size=512, hint_size=448):
        self.coco = COCO(json_file)
        self.img_folder = img_folder
        self.ids = list(self.coco.imgToAnns.keys())
        self.common_transform = common_transform
        self.crop_scale = crop_scale
        self.jpg_size = jpg_size
        self.hint_size = hint_size

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.ids)

    @staticmethod
    def _combine_captions(anns):
        """Join the ``caption`` of every annotation into one single-line string."""
        # Newlines inside a caption are replaced (not stripped) so the text
        # stays on one line; the resulting spacing is kept as-is.
        return ' '.join(ann['caption'].replace('\n', ' ') for ann in anns)

    def __getitem__(self, index):
        """Load image ``index`` and return ``dict(jpg=..., txt=..., hint=...)``."""
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_folder, img_info['file_name'])

        # Open inside a context manager so the underlying file is released
        # deterministically; ``convert`` returns an independent in-memory image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Sample the crop box once so that both outputs show the same region.
        # The aspect ratio is pinned to 1:1.
        top, left, height, width = transforms.RandomResizedCrop.get_params(
            image, scale=self.crop_scale, ratio=(1.0, 1.0))
        cropped_image = transforms.functional.crop(image, top, left, height, width)

        # Resize the shared crop to the two target resolutions.
        bicubic = transforms.InterpolationMode.BICUBIC
        jpg_image = transforms.functional.resize(
            cropped_image, self.jpg_size, interpolation=bicubic)
        hint_image = transforms.functional.resize(
            cropped_image, self.hint_size, interpolation=bicubic)

        # Apply the caller-supplied transform to each output separately.
        if self.common_transform is not None:
            jpg_image = self.common_transform(jpg_image)
            hint_image = self.common_transform(hint_image)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        combined_caption = self._combine_captions(anns)

        return dict(jpg=jpg_image, txt=combined_caption, hint=hint_image)
def main(
    json_file='/home/t2vg-a100-G4-1/projects/dataset/annotations/captions_train2017.json',
    img_folder='/home/t2vg-a100-G4-1/projects/dataset/train2017',
    max_batches=None,
):
    """Print the min and max value of the ``jpg`` tensor for dataset batches.

    Args:
        json_file: Path to the COCO captions annotation file.
        img_folder: Directory holding the images referenced by ``json_file``.
        max_batches: Stop after this many batches. ``None`` (the default)
            iterates over the whole dataset, which is what the loop did
            before, since its early ``break`` was commented out.
    """
    # Local import keeps this function self-contained.
    from itertools import islice

    # ToTensor converts to a tensor scaled to [0, 1]. No Normalize step is
    # applied here, so values are not remapped to [-1, 1].
    common_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Instantiate the dataset
    dataset = CustomCocoDataset(
        json_file=json_file,
        img_folder=img_folder,
        common_transform=common_transform,
    )

    # Create the DataLoader
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # islice with stop=None consumes the loader until it is exhausted.
    for batch in islice(dataloader, max_batches):
        jpg_image = batch['jpg']
        # Print the min and max values in the image tensor
        print(f'JPG Image Min Value: {jpg_image.min().item()}')
        print(f'JPG Image Max Value: {jpg_image.max().item()}')
# Script entry point: run the check only when executed directly, not on import.
if __name__ == "__main__":
    main()