| import cv2 |
| import numpy as np |
| import torch |
| from jaxtyping import Float |
|
|
|
|
def read_img(path):
    """Load an image from *path* as a float32 ``(H, W, C)`` array in ``[-1, 1]``.

    Three- or four-channel images are converted from OpenCV's BGR order to
    RGB; any channel beyond the third (e.g. alpha) is dropped. Single-channel
    (2-D) images get a trailing channel axis of size 1. Integer pixel values
    are divided by the maximum of their dtype (255 for uint8, 65535 for
    uint16) and then mapped linearly from ``[0, 1]`` to ``[-1, 1]``.

    Args:
        path: File path (``str`` or anything ``str()`` turns into a path,
            such as ``pathlib.Path``).

    Returns:
        ``np.float32`` array of shape ``(H, W, 3)`` for colour images or
        ``(H, W, 1)`` for single-channel images.

    Raises:
        OSError: if OpenCV cannot decode the file (missing, unreadable or
            unsupported format).
        ValueError: if the decoded image does not have an integer dtype
            (e.g. floating-point EXR/HDR data), because it cannot be
            normalized by its dtype's maximum.
    """
    # ANYDEPTH keeps >8-bit data instead of down-converting it; ANYCOLOR keeps
    # the file's own colour format, so grayscale files decode as a 2-D array.
    img = cv2.imread(str(path), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not read image: {path}")
    if not np.issubdtype(img.dtype, np.integer):
        raise ValueError(
            f"Unsupported image dtype {img.dtype} for {path}; "
            "expected an integer pixel type"
        )

    if img.ndim == 3:
        # Drop alpha (if any) before swapping BGR -> RGB.
        img = cv2.cvtColor(img[..., :3], cv2.COLOR_BGR2RGB)
    elif img.ndim == 2:
        img = img[..., np.newaxis]

    dinfo = np.iinfo(img.dtype)
    return (img.astype(np.float32) / dinfo.max) * 2 - 1
|
|
|
|
def write_img(path: str, data: np.ndarray):
    """Write an image with values in ``[-1, 1]`` to *path* as 8-bit data.

    This is the inverse of ``read_img`` for 8-bit outputs. Values are mapped
    from ``[-1, 1]`` to ``[0, 1]``, clipped, scaled to ``[0, 255]``, rounded
    to the nearest integer and stored as ``uint8``. Inputs with 3 or 4
    channels are treated as RGB / RGBA and converted to OpenCV's BGR / BGRA
    order before writing.

    Args:
        path: Destination file. OpenCV picks the encoder from the extension.
            Non-``str`` paths such as ``pathlib.Path`` are accepted.
        data: ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)`` array
            with values nominally in ``[-1, 1]``.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the image was not written.
    """
    data = np.clip(data * 0.5 + 0.5, 0, 1)
    # Round rather than truncate. Truncation biases every value downward, so
    # float32 error in a read_img -> write_img round trip would otherwise turn
    # e.g. 127.99998 into 127.
    data = np.rint(data * 255).astype(np.uint8)

    if data.ndim == 3 and data.shape[-1] in (3, 4):
        # Swap channels on the uint8 data: cv2.cvtColor rejects float64 input.
        if data.shape[-1] == 3:
            code = cv2.COLOR_RGB2BGR
        else:
            code = cv2.COLOR_RGBA2BGRA
        data = cv2.cvtColor(data, code)

    # cv2.imwrite signals failure through its boolean return value, not an exception.
    if not cv2.imwrite(str(path), data):
        raise OSError(f"cv2.imwrite failed to write image: {path}")
|
|
|
|
def to_torch(img: Float[np.ndarray, "H W C"]) -> Float[torch.Tensor, "C H W"]:
    """Wrap a channels-last NumPy image as a channels-first torch tensor.

    No data is copied. ``torch.from_numpy`` shares memory with *img* and
    ``permute`` returns a view, so writes to either object are visible in the
    other. ``torch.from_numpy`` rejects arrays with negative strides (e.g.
    the result of ``arr[..., ::-1]``); pass a contiguous copy in that case.
    """
    shared = torch.from_numpy(img)
    # (H, W, C) -> (C, H, W)
    chw = shared.permute(2, 0, 1)
    return chw
|
|
|
|
def from_torch(img: Float[torch.Tensor, "C H W"]) -> Float[np.ndarray, "H W C"]:
    """Convert a channels-first torch tensor to a channels-last float32 array.

    The tensor is detached from autograd, moved to the CPU and cast to
    float32 before conversion. ``Tensor.numpy()`` refuses tensors that
    require grad, tensors that live on a non-CPU device, and dtypes NumPy
    lacks (such as bfloat16).
    """
    hwc = img.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    host_f32 = hwc.detach().cpu().float()
    return host_f32.numpy()
|
|