| | |
| | |
| | import torch |
| | import numpy as np |
| | from typing import List, Tuple, Union |
| |
|
| | from pytorch3d.renderer import ( |
| | PerspectiveCameras, |
| | MeshRenderer, |
| | MeshRasterizer, |
| | SoftPhongShader, |
| | RasterizationSettings, |
| | PointLights, |
| | TexturesVertex |
| | ) |
| |
|
| | from pytorch3d.structures import Meshes |
| | from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection |
| |
|
def update_intrinsics_from_bbox(
    K_org: torch.Tensor, bbox: torch.Tensor
) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """
    Adjust camera intrinsics for image crops defined by bounding boxes.

    Args:
        K_org (torch.Tensor): Batch of intrinsic matrices, shape (B, 3, 3).
        bbox (torch.Tensor): Boxes of shape (B, 4) in (left, top, right, bottom) order.

    Returns:
        K_new (torch.Tensor): Updated 4x4 intrinsics, shape (B, 4, 4).
        image_sizes (List[Tuple[int, int]]): Per-box (height, width) crop sizes.
    """
    batch = K_org.shape[0]

    # Embed the 3x3 intrinsics into a 4x4 projection layout: row/col 2 carry
    # the homogeneous passthrough terms, the original K occupies the top-left.
    K_new = torch.zeros((batch, 4, 4), device=K_org.device, dtype=K_org.dtype)
    K_new[:, :3, :3] = K_org.clone()
    K_new[:, 2, 2] = 0
    K_new[:, 2, -1] = 1
    K_new[:, -1, 2] = 1

    image_sizes: List[Tuple[int, int]] = []
    for i in range(batch):
        left, top, right, bottom = bbox[i]

        # Clamp degenerate boxes to at least one pixel per side.
        crop_w = max(right - left, 1)
        crop_h = max(bottom - top, 1)

        # Shift the principal point into crop coordinates, then mirror it
        # across the crop extent (matches the downstream rendering convention).
        K_new[i, 0, 2] = crop_w - (K_new[i, 0, 2] - left)
        K_new[i, 1, 2] = crop_h - (K_new[i, 1, 2] - top)

        image_sizes.append((int(crop_h), int(crop_w)))

    return K_new, image_sizes
| |
|
class Renderer():
    """
    Renderer class using PyTorch3D for mesh rendering with Phong shading.

    Attributes:
        width (int): Target image width.
        height (int): Target image height.
        focal_length (Union[float, Tuple[float, float]]): Camera focal length(s);
            a single value is used for both fx and fy.
        device (torch.device): Device to run rendering on.
        renderer (MeshRenderer): PyTorch3D mesh renderer.
        cameras (PerspectiveCameras): Camera object.
        lights (PointLights): Lighting setup for rendering.
    """
    def __init__(
        self,
        width: int,
        height: int,
        focal_length: Union[float, Tuple[float, float]],
        device: torch.device,
        bin_size: int = 512,
        max_faces_per_bin: int = 200000,
    ):
        """
        Args:
            width: Output image width in pixels.
            height: Output image height in pixels.
            focal_length: Focal length in pixels, either a scalar (fx == fy)
                or an (fx, fy) pair.
            device: Torch device used for rendering.
            bin_size: Rasterizer bin size for coarse-to-fine rasterization.
            max_faces_per_bin: Per-bin face capacity of the rasterizer.
        """
        self.width = width
        self.height = height
        self.focal_length = focal_length
        self.device = device

        # Builds self.K, self.K_full, self.image_sizes, and self.cameras.
        self._initialize_camera_params()

        # Mostly-ambient lighting plus a dim point light placed at negative
        # y/z relative to the origin (in front of/above the subject in the
        # camera frame — NOTE(review): depends on camera convention, confirm).
        self.lights = PointLights(
            device=device,
            location = ((0.0, -1.5, -1.5),),
            ambient_color=((0.75, 0.75, 0.75),),
            diffuse_color=((0.25, 0.25, 0.25),),
            specular_color=((0.02, 0.02, 0.02),)
        )

        self._create_renderer(bin_size, max_faces_per_bin)

    def _create_renderer(self, bin_size: int, max_faces_per_bin: int) -> None:
        """
        Create the PyTorch3D MeshRenderer with rasterizer and shader.

        The shader is constructed without fixed cameras; cameras are supplied
        per call in render().
        """
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                raster_settings=RasterizationSettings(
                    # image_size is (height, width) of the full frame computed
                    # in _initialize_camera_params.
                    image_size=self.image_sizes[0],
                    blur_radius=1e-5,
                    bin_size=bin_size,
                    max_faces_per_bin=max_faces_per_bin,
                )
            ),
            shader=SoftPhongShader(
                device=self.device,
                lights=self.lights,
            ),
        )

    def _initialize_camera_params(self) -> None:
        """
        Initialize camera intrinsics and extrinsics.

        Uses identity rotation and zero translation (camera at the world
        origin), with the principal point at the image center.
        """
        self.R = torch.eye(3, device=self.device).unsqueeze(0)
        self.T = torch.zeros(1, 3, device=self.device)

        # Accept either a scalar focal length or a separate (fx, fy) pair.
        if isinstance(self.focal_length, (list, tuple)):
            fx, fy = self.focal_length
        else:
            fx = fy = self.focal_length

        self.K = torch.tensor(
            [[fx, 0, self.width / 2],
             [0, fy, self.height / 2],
             [0, 0, 1]],
            device=self.device,
            dtype=torch.float32,
        ).unsqueeze(0)

        # A single bbox covering the full image, so K_full / image_sizes
        # describe the full frame. update_intrinsics_from_bbox also mirrors
        # the principal point — presumably to match the rendering convention;
        # TODO(review): confirm against the camera conversion used below.
        self.bboxes = torch.tensor([[0, 0, self.width, self.height]], dtype=torch.float32)
        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)

        self.cameras = self._create_camera_from_cv()

    def _create_camera_from_cv(
        self,
        R: torch.Tensor = None,
        T: torch.Tensor = None,
        K: torch.Tensor = None,
        image_size: torch.Tensor = None,
    ) -> PerspectiveCameras:
        """
        Create a PyTorch3D camera from OpenCV-style intrinsics and extrinsics.

        Defaults fall back to the identity extrinsics and the 3x3 OpenCV-style
        intrinsics built in _initialize_camera_params. Relies on the private
        pytorch3d helper _cameras_from_opencv_projection.

        Args:
            R: Rotation matrices, shape (B, 3, 3).
            T: Translation vectors, shape (B, 3).
            K: Intrinsic matrices, shape (B, 3, 3).
            image_size: Per-camera (height, width) tensor.
        """
        if R is None:
            R = self.R
        if T is None:
            T = self.T
        if K is None:
            K = self.K
        if image_size is None:
            image_size = torch.tensor(self.image_sizes, device=self.device)

        cameras = _cameras_from_opencv_projection(R, T, K, image_size)
        return cameras

    def render(
        self,
        verts_list: List[torch.Tensor],
        faces_list: List[torch.Tensor],
        colors_list: List[torch.Tensor],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Render a batch of meshes into an RGB image and mask.

        All meshes are merged into a single Meshes object (face indices are
        shifted by a running vertex offset) so they render together with
        correct mutual occlusion.

        Args:
            verts_list (List[torch.Tensor]): List of per-mesh vertex tensors;
                assumed (V_i, 3) — vertex count is taken from dim 0.
            faces_list (List[torch.Tensor]): List of per-mesh face-index
                tensors, indexing into the corresponding verts tensor.
            colors_list (List[torch.Tensor]): List of per-vertex color tensors;
                assumed RGB in [0, 1] — TODO(review): confirm with callers.

        Returns:
            rend (np.ndarray): Rendered RGB image as uint8 array.
            mask (np.ndarray): Boolean mask of rendered pixels (alpha > 0).
        """
        all_verts = []
        all_faces = []
        all_colors = []
        vertex_offset = 0

        for verts, faces, colors in zip(verts_list, faces_list, colors_list):
            all_verts.append(verts)
            all_colors.append(colors)
            # Shift face indices into the concatenated vertex buffer.
            all_faces.append(faces + vertex_offset)
            vertex_offset += verts.shape[0]

        all_verts = torch.cat(all_verts, dim=0)
        all_faces = torch.cat(all_faces, dim=0)
        all_colors = torch.cat(all_colors, dim=0)

        mesh = Meshes(
            verts=[all_verts],
            faces=[all_faces],
            textures=TexturesVertex(all_colors.unsqueeze(0)),
        )

        images = self.renderer(mesh, cameras=self.cameras, lights=self.lights)

        # Renderer output is float RGBA; scale RGB to uint8 and threshold the
        # alpha channel for the coverage mask.
        rend = np.clip(images[0, ..., :3].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = images[0, ..., -1].cpu().numpy() > 0

        return rend, mask