| import numpy as np |
| import os |
| from torch.utils.data import Dataset |
| import torch |
| from utils import load_normal, load_ssao, load_img, depthToPoint, process_normal, load_depth, Augment_RGB_torch |
| import torch.nn.functional as F |
| import random |
|
|
# Shared augmenter instance used by DataLoaderTrain.__getitem__.
augment = Augment_RGB_torch()
# Names of the augmenter's public callable attributes. The built-in dir()
# returns them sorted by name, so the list order is deterministic.
transforms_aug = [
    attr_name
    for attr_name in dir(augment)
    if callable(getattr(augment, attr_name)) and not attr_name.startswith('_')
]
|
|
| |
class DataLoaderTrain(Dataset):
    """Training dataset of shadow-free / shadowed image pairs plus geometry maps.

    ``rgb_dir`` must contain the sub-directories ``shadow_free`` (targets),
    ``origin`` (inputs), ``depth`` and ``normal``. Files are paired across the
    four directories purely by position in each sorted directory listing, so
    the directories are expected to hold matching filenames.

    Each item is loaded with the project helpers. The depth map is converted
    with ``depthToPoint`` and the normal map is passed through
    ``process_normal``. One randomly chosen augmentation is applied to all four
    maps, and a random ``patch_size`` x ``patch_size`` crop is then taken.
    """

    def __init__(self, rgb_dir, img_options=None, target_transform=None, debug=False):
        """Index the dataset files.

        Args:
            rgb_dir: Root directory holding the four sub-directories.
            img_options: Mapping that must provide ``'patch_size'``. It is only
                read in ``__getitem__``, so ``None`` fails when an item is fetched.
            target_transform: Stored on the instance; not used by this class.
            debug: If True, expose at most 100 samples.
        """
        super().__init__()

        self.target_transform = target_transform

        gt_dir = 'shadow_free'
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        self.clean_filenames = self._sorted_paths(rgb_dir, gt_dir)
        self.noisy_filenames = self._sorted_paths(rgb_dir, input_dir)
        self.depth_filenames = self._sorted_paths(rgb_dir, depth_dir)
        self.normal_filenames = self._sorted_paths(rgb_dir, normal_dir)
        self.img_options = img_options

        n_files = len(self.noisy_filenames)
        # Clamp the debug subset so a directory with fewer than 100 files does
        # not yield out-of-range indices in __getitem__.
        self.tar_size = min(100, n_files) if debug else n_files

    @staticmethod
    def _sorted_paths(rgb_dir, sub_dir):
        """Return the full paths of the entries of ``rgb_dir/sub_dir``, sorted by name."""
        folder = os.path.join(rgb_dir, sub_dir)
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.tar_size

    def __getitem__(self, index):
        """Load, augment and randomly crop one training sample.

        Returns:
            ``(clean, noisy, point, normal, clean_filename, noisy_filename)``.
            The four tensors are channel-first: the last axis of the loaded
            arrays is moved to the front with ``permute(2, 0, 1)``. They are
            cropped to ``patch_size`` when the image is larger than the patch.

        Raises:
            ValueError: if the augmented image is smaller than ``patch_size``
                in either spatial dimension (empty crop range).
        """
        tar_index = index % self.tar_size

        clean = np.float32(load_img(self.clean_filenames[tar_index]))
        noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # The constant 60 is passed through unchanged. It is presumably a field
        # of view for depthToPoint; TODO confirm against utils.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)

        clean = torch.from_numpy(clean)
        noisy = torch.from_numpy(noisy)
        point = torch.from_numpy(point)
        normal = torch.from_numpy(normal)

        # Rescale so that channel 2 of the point map has mean 0.5.
        point = point / (2 * point[:, :, 2].mean())

        clean = clean.permute(2, 0, 1)
        noisy = noisy.permute(2, 0, 1)
        point = point.permute(2, 0, 1)
        normal = normal.permute(2, 0, 1)

        clean_filename = os.path.basename(self.clean_filenames[tar_index])
        noisy_filename = os.path.basename(self.noisy_filenames[tar_index])

        # One rotation value and one transform are drawn per sample, and the
        # same transform is applied to all four maps. `rotate` is presumably
        # read by the augmenter's rotation transform; TODO confirm in utils.
        # Only the first three names in transforms_aug can be selected here.
        augment.rotate = random.randint(-20, 20)
        transform = getattr(augment, transforms_aug[random.randint(0, 2)])
        clean = transform(clean)
        noisy = transform(noisy)
        point = transform(point)
        normal = transform(normal)

        ps = self.img_options['patch_size']
        scale = 1
        scaled_ps = int(scale * ps)

        height = noisy.shape[1]
        width = noisy.shape[2]
        if height != scaled_ps or width != scaled_ps:
            # Use Python's `random` (like the draws above) instead of np.random.
            # PyTorch seeds `random` per DataLoader worker, whereas older
            # versions leave NumPy's state duplicated across workers.
            # randint is inclusive on both ends.
            r = random.randint(0, height - scaled_ps)
            c = random.randint(0, width - scaled_ps)
            clean = clean[:, r:r + scaled_ps, c:c + scaled_ps]
            noisy = noisy[:, r:r + scaled_ps, c:c + scaled_ps]
            point = point[:, r:r + scaled_ps, c:c + scaled_ps]
            normal = normal[:, r:r + scaled_ps, c:c + scaled_ps]

        if scale != 1:
            # Resample the crop back to ps x ps. Geometry maps use nearest
            # neighbour so their values are not blended.
            clean = F.interpolate(clean.unsqueeze(0), size=[ps, ps], mode='bilinear').squeeze(0)
            noisy = F.interpolate(noisy.unsqueeze(0), size=[ps, ps], mode='bilinear').squeeze(0)
            point = F.interpolate(point.unsqueeze(0), size=[ps, ps], mode='nearest').squeeze(0)
            normal = F.interpolate(normal.unsqueeze(0), size=[ps, ps], mode='nearest').squeeze(0)

        # Same 6-tuple layout on every path (the scaled path used to drop
        # clean_filename).
        return clean, noisy, point, normal, clean_filename, noisy_filename
|
|
|
|
| |
class DataLoaderVal(Dataset):
    """Validation dataset of full-size, un-augmented samples.

    ``rgb_dir`` must contain the sub-directories ``shadow_free`` (targets),
    ``origin`` (inputs), ``depth`` and ``normal``. Files are paired across the
    four directories purely by position in each sorted directory listing.
    """

    def __init__(self, rgb_dir, target_transform=None, debug=False):
        """Index the dataset files.

        Args:
            rgb_dir: Root directory holding the four sub-directories.
            target_transform: Stored on the instance; not used by this class.
            debug: If True, expose at most 10 samples.
        """
        super().__init__()

        self.target_transform = target_transform

        gt_dir = 'shadow_free'
        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        self.clean_filenames = self._sorted_paths(rgb_dir, gt_dir)
        self.noisy_filenames = self._sorted_paths(rgb_dir, input_dir)
        self.depth_filenames = self._sorted_paths(rgb_dir, depth_dir)
        self.normal_filenames = self._sorted_paths(rgb_dir, normal_dir)

        n_files = len(self.noisy_filenames)
        # Clamp the debug subset so a directory with fewer than 10 files does
        # not yield out-of-range indices in __getitem__.
        self.tar_size = min(10, n_files) if debug else n_files

    @staticmethod
    def _sorted_paths(rgb_dir, sub_dir):
        """Return the full paths of the entries of ``rgb_dir/sub_dir``, sorted by name."""
        folder = os.path.join(rgb_dir, sub_dir)
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.tar_size

    def __getitem__(self, index):
        """Load one validation sample without augmentation or cropping.

        Returns:
            ``(clean, noisy, point, normal, clean_filename, noisy_filename)``.
            The four tensors are channel-first: the last axis of the loaded
            arrays is moved to the front with ``permute(2, 0, 1)``.
        """
        tar_index = index % self.tar_size
        clean = np.float32(load_img(self.clean_filenames[tar_index]))
        noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # The constant 60 is passed through unchanged. It is presumably a field
        # of view for depthToPoint; TODO confirm against utils.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)
        # Rescale so that channel 2 of the point map has mean 0.5.
        point = point / (2 * point[:, :, 2].mean())

        clean_filename = os.path.basename(self.clean_filenames[tar_index])
        noisy_filename = os.path.basename(self.noisy_filenames[tar_index])

        clean = torch.from_numpy(clean).permute(2, 0, 1)
        noisy = torch.from_numpy(noisy).permute(2, 0, 1)
        point = torch.from_numpy(point).permute(2, 0, 1)
        normal = torch.from_numpy(normal).permute(2, 0, 1)

        return clean, noisy, point, normal, clean_filename, noisy_filename
| |
|
|
|
|
|
|
| |
class DataLoaderTest(Dataset):
    """Test dataset of inputs without ground truth.

    ``rgb_dir`` must contain the sub-directories ``origin`` (inputs),
    ``depth`` and ``normal``. Files are paired across the three directories
    purely by position in each sorted directory listing.
    """

    def __init__(self, rgb_dir, target_transform=None, debug=False):
        """Index the dataset files.

        Args:
            rgb_dir: Root directory holding the three sub-directories.
            target_transform: Stored on the instance; not used by this class.
            debug: If True, expose at most 10 samples.
        """
        super().__init__()

        self.target_transform = target_transform

        input_dir = 'origin'
        depth_dir = 'depth'
        normal_dir = 'normal'

        self.noisy_filenames = self._sorted_paths(rgb_dir, input_dir)
        self.depth_filenames = self._sorted_paths(rgb_dir, depth_dir)
        self.normal_filenames = self._sorted_paths(rgb_dir, normal_dir)

        n_files = len(self.noisy_filenames)
        # Clamp the debug subset so a directory with fewer than 10 files does
        # not yield out-of-range indices in __getitem__.
        self.tar_size = min(10, n_files) if debug else n_files

    @staticmethod
    def _sorted_paths(rgb_dir, sub_dir):
        """Return the full paths of the entries of ``rgb_dir/sub_dir``, sorted by name."""
        folder = os.path.join(rgb_dir, sub_dir)
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return self.tar_size

    def __getitem__(self, index):
        """Load one test sample without augmentation or cropping.

        Returns:
            ``(noisy, noisy, point, normal, noisy_filename, noisy_filename)``.
            No ground truth exists, so ``noisy`` and its filename fill the
            clean-image slots. This keeps the same 6-tuple layout as
            ``DataLoaderVal``. The tensors are channel-first: the last axis of
            the loaded arrays is moved to the front with ``permute(2, 0, 1)``.
        """
        tar_index = index % self.tar_size

        noisy = np.float32(load_img(self.noisy_filenames[tar_index]))
        depth = np.float32(load_depth(self.depth_filenames[tar_index]))
        normal = np.float32(load_normal(self.normal_filenames[tar_index]))

        # The constant 60 is passed through unchanged. It is presumably a field
        # of view for depthToPoint; TODO confirm against utils.
        point = depthToPoint(60, depth)
        normal = process_normal(normal)
        # Rescale so that channel 2 of the point map has mean 0.5.
        point = point / (2 * point[:, :, 2].mean())

        noisy_filename = os.path.basename(self.noisy_filenames[tar_index])

        noisy = torch.from_numpy(noisy).permute(2, 0, 1)
        point = torch.from_numpy(point).permute(2, 0, 1)
        normal = torch.from_numpy(normal).permute(2, 0, 1)

        return noisy, noisy, point, normal, noisy_filename, noisy_filename
|
|
|
|