Spaces:
Running
Running
| """Training data augmentation pipeline for LandmarkDiff. | |
| Provides domain-specific augmentations that maintain landmark consistency: | |
| - Geometric: flip, rotation, affine (landmarks co-transformed) | |
| - Photometric: color jitter, brightness, contrast (applied to images only) | |
| - Skin-tone augmentation: ITA-space perturbation for Fitzpatrick balance | |
| - Conditioning augmentation: noise injection, dropout for robustness | |
| All augmentations preserve the correspondence between: | |
| input_image ↔ conditioning_image ↔ target_image ↔ mask | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import cv2 | |
| import numpy as np | |
@dataclass
class AugmentationConfig:
    """Augmentation parameters.

    Declared as a dataclass so individual settings can be overridden by
    keyword (e.g. ``AugmentationConfig(random_flip=False, seed=0)``) while
    ``AugmentationConfig()`` still yields the defaults below.
    """

    # Geometric
    random_flip: bool = True  # enable the 50% horizontal flip
    random_rotation_deg: float = 5.0  # rotation sampled from [-deg, +deg]
    random_scale: tuple[float, float] = (0.95, 1.05)  # (min, max) scale factor
    random_translate: float = 0.02  # fraction of image size

    # Photometric (images only, not conditioning)
    brightness_range: tuple[float, float] = (0.9, 1.1)  # multiplicative factor
    contrast_range: tuple[float, float] = (0.9, 1.1)  # scaling around the mean
    saturation_range: tuple[float, float] = (0.9, 1.1)  # HSV S-channel factor
    hue_shift_range: float = 5.0  # degrees

    # Conditioning augmentation
    conditioning_dropout_prob: float = 0.1  # chance of zeroing the conditioning
    conditioning_noise_std: float = 0.02  # Gaussian std as a fraction of 255

    # Skin-tone augmentation
    # NOTE(review): not read by augment_training_sample in this module;
    # presumably the std for the ita_delta given to augment_skin_tone - confirm.
    ita_perturbation_std: float = 3.0  # ITA angle noise

    # Seed used to build a generator when the caller does not pass one.
    seed: int | None = None
def augment_training_sample(
    input_image: np.ndarray,
    target_image: np.ndarray,
    conditioning: np.ndarray,
    mask: np.ndarray,
    landmarks_src: np.ndarray | None = None,
    landmarks_dst: np.ndarray | None = None,
    config: AugmentationConfig | None = None,
    rng: np.random.Generator | None = None,
) -> dict[str, np.ndarray]:
    """Apply consistent augmentations to a training sample.

    All spatial transforms are applied to images AND landmarks together
    so correspondence is preserved. Photometric jitter uses the same random
    factors for the input and target images and never touches the
    conditioning image or the mask. The inputs are copied, not modified.

    Args:
        input_image: (H, W, 3) original face image (uint8 BGR).
        target_image: (H, W, 3) target face image (uint8 BGR).
        conditioning: (H, W, 3) conditioning image (uint8).
        mask: (H, W) or (H, W, 1) float32 mask.
        landmarks_src: (N, 2) normalized [0,1] source landmark coords.
        landmarks_dst: (N, 2) normalized [0,1] target landmark coords.
        config: Augmentation parameters. Defaults to ``AugmentationConfig()``.
        rng: Random generator for reproducibility. When omitted, a new
            generator is created from ``config.seed``.

    Returns:
        Dict with keys ``input_image``, ``target_image``, ``conditioning``
        and ``mask``, plus ``landmarks_src`` / ``landmarks_dst`` when the
        corresponding argument was provided.
    """
    if config is None:
        config = AugmentationConfig()
    if rng is None:
        rng = np.random.default_rng(config.seed)

    h, w = input_image.shape[:2]
    out_input = input_image.copy()
    out_target = target_image.copy()
    out_cond = conditioning.copy()
    out_mask = mask.copy()
    out_lm_src = landmarks_src.copy() if landmarks_src is not None else None
    out_lm_dst = landmarks_dst.copy() if landmarks_dst is not None else None

    # --- Geometric augmentations (applied to all) ---

    # Random horizontal flip
    if config.random_flip and rng.random() < 0.5:
        # Reversed views have negative strides; make them contiguous so the
        # OpenCV calls below receive ordinary arrays.
        out_input = np.ascontiguousarray(out_input[:, ::-1])
        out_target = np.ascontiguousarray(out_target[:, ::-1])
        out_cond = np.ascontiguousarray(out_cond[:, ::-1])
        out_mask = np.ascontiguousarray(out_mask[:, ::-1])
        if out_lm_src is not None:
            out_lm_src[:, 0] = 1.0 - out_lm_src[:, 0]
        if out_lm_dst is not None:
            out_lm_dst[:, 0] = 1.0 - out_lm_dst[:, 0]

    # Random rotation + scale + translate.
    # The guard also checks random_translate so that a translation-only
    # configuration (no rotation, unit scale) is not silently skipped.
    wants_affine = (
        config.random_rotation_deg > 0
        or config.random_scale != (1.0, 1.0)
        or config.random_translate > 0
    )
    if wants_affine:
        angle = rng.uniform(-config.random_rotation_deg, config.random_rotation_deg)
        scale = rng.uniform(config.random_scale[0], config.random_scale[1])
        tx = rng.uniform(-config.random_translate, config.random_translate) * w
        ty = rng.uniform(-config.random_translate, config.random_translate) * h

        center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, angle, scale)
        M[0, 2] += tx
        M[1, 2] += ty

        # Photographs are padded by reflection; conditioning and mask are
        # padded with zeros.
        out_input = cv2.warpAffine(out_input, M, (w, h),
                                   borderMode=cv2.BORDER_REFLECT_101)
        out_target = cv2.warpAffine(out_target, M, (w, h),
                                    borderMode=cv2.BORDER_REFLECT_101)
        out_cond = cv2.warpAffine(out_cond, M, (w, h),
                                  borderMode=cv2.BORDER_CONSTANT, borderValue=0)

        # Warp the mask as a 2-D array, then restore a trailing channel axis
        # if the caller supplied (H, W, 1).
        mask_2d = out_mask if out_mask.ndim == 2 else out_mask[:, :, 0]
        mask_2d = cv2.warpAffine(mask_2d, M, (w, h),
                                 borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        out_mask = mask_2d if out_mask.ndim == 2 else mask_2d[:, :, np.newaxis]

        # Transform landmarks with the same matrix
        if out_lm_src is not None:
            out_lm_src = _transform_landmarks(out_lm_src, M, w, h)
        if out_lm_dst is not None:
            out_lm_dst = _transform_landmarks(out_lm_dst, M, w, h)

    # --- Photometric augmentations (images only, not conditioning/mask) ---

    # Brightness
    b_factor = rng.uniform(config.brightness_range[0], config.brightness_range[1])
    out_input = np.clip(out_input.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
    out_target = np.clip(out_target.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)

    # Contrast: scale each image around its own global mean
    c_factor = rng.uniform(config.contrast_range[0], config.contrast_range[1])
    mean_in = out_input.mean()
    mean_tgt = out_target.mean()
    out_input = np.clip(
        (out_input.astype(np.float32) - mean_in) * c_factor + mean_in, 0, 255
    ).astype(np.uint8)
    out_target = np.clip(
        (out_target.astype(np.float32) - mean_tgt) * c_factor + mean_tgt, 0, 255
    ).astype(np.uint8)

    # Saturation (in HSV space)
    s_factor = rng.uniform(config.saturation_range[0], config.saturation_range[1])
    if abs(s_factor - 1.0) > 1e-4:
        out_input = _adjust_saturation(out_input, s_factor)
        out_target = _adjust_saturation(out_target, s_factor)

    # Hue shift
    if config.hue_shift_range > 0:
        hue_delta = rng.uniform(-config.hue_shift_range, config.hue_shift_range)
        if abs(hue_delta) > 0.1:
            out_input = _shift_hue(out_input, hue_delta)
            out_target = _shift_hue(out_target, hue_delta)

    # --- Conditioning augmentation ---

    # Conditioning dropout (replace with zeros to learn unconditional)
    drop_conditioning = (
        config.conditioning_dropout_prob > 0
        and rng.random() < config.conditioning_dropout_prob
    )
    if drop_conditioning:
        # Noise is deliberately NOT added here: a dropped conditioning must
        # stay exactly zero.
        out_cond = np.zeros_like(out_cond)
    elif config.conditioning_noise_std > 0:
        # Conditioning noise, with std expressed as a fraction of 255
        noise = rng.normal(0, config.conditioning_noise_std * 255, out_cond.shape)
        out_cond = np.clip(out_cond.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    result = {
        "input_image": out_input,
        "target_image": out_target,
        "conditioning": out_cond,
        "mask": out_mask,
    }
    if out_lm_src is not None:
        result["landmarks_src"] = out_lm_src
    if out_lm_dst is not None:
        result["landmarks_dst"] = out_lm_dst
    return result
| def _transform_landmarks( | |
| landmarks: np.ndarray, M: np.ndarray, w: int, h: int | |
| ) -> np.ndarray: | |
| """Transform normalized landmarks with an affine matrix.""" | |
| # Convert to pixel coords | |
| px = landmarks.copy() | |
| px[:, 0] *= w | |
| px[:, 1] *= h | |
| # Apply affine transform | |
| ones = np.ones((px.shape[0], 1)) | |
| px_h = np.hstack([px, ones]) # (N, 3) | |
| transformed = (M @ px_h.T).T # (N, 2) | |
| # Back to normalized | |
| transformed[:, 0] /= w | |
| transformed[:, 1] /= h | |
| return np.clip(transformed, 0.0, 1.0) | |
def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
    """Scale the HSV saturation channel of a BGR image by ``factor``.

    The scaled saturation is clipped to [0, 255] before converting back to
    BGR; hue and value are left as produced by the BGR->HSV conversion.
    """
    hsv_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    saturation = hsv_pixels[..., 1]  # view into hsv_pixels
    saturation[...] = np.clip(saturation * factor, 0, 255)
    hsv_u8 = hsv_pixels.astype(np.uint8)
    return cv2.cvtColor(hsv_u8, cv2.COLOR_HSV2BGR)
def _shift_hue(img: np.ndarray, delta_deg: float) -> np.ndarray:
    """Rotate the hue of a BGR image by ``delta_deg`` degrees.

    The shift is halved before being added because the hue channel here is
    handled on a 0-180 scale; the result wraps around modulo 180.
    """
    # OpenCV hue range is [0, 180], i.e. half-degree units
    hue_offset = delta_deg / 2
    hsv_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    hue = hsv_pixels[..., 0]  # view into hsv_pixels
    hue[...] = np.mod(hue + hue_offset, 180)
    hsv_u8 = hsv_pixels.astype(np.uint8)
    return cv2.cvtColor(hsv_u8, cv2.COLOR_HSV2BGR)
def augment_skin_tone(
    image: np.ndarray,
    ita_delta: float = 0.0,
) -> np.ndarray:
    """Simulate a different skin tone by offsetting channels in L*a*b* space.

    Intended to help balance Fitzpatrick representation in training by
    deriving additional skin tones from existing samples.

    Args:
        image: (H, W, 3) BGR uint8 image.
        ita_delta: ITA angle shift (positive = lighter, negative = darker).

    Returns:
        Augmented image with shifted skin tone.
    """
    lab_pixels = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # ITA = arctan((L-50)/b), so shifting ITA shifts L. The lightness offset
    # is an approximate mapping; b is nudged in the opposite direction, by a
    # smaller amount, for more natural tone changes.
    channel_offsets = (
        (0, ita_delta * 0.5),    # L channel (lightness)
        (2, -ita_delta * 0.15),  # b channel
    )
    for channel, offset in channel_offsets:
        lab_pixels[..., channel] = np.clip(lab_pixels[..., channel] + offset, 0, 255)

    lab_u8 = lab_pixels.astype(np.uint8)
    return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2BGR)
class FitzpatrickBalancer:
    """Oversample underrepresented Fitzpatrick types during training.

    Maintains per-type counts and generates sampling weights to ensure
    equitable training across all skin types.
    """

    def __init__(self, target_distribution: dict[str, float] | None = None):
        """Initialize balancer.

        Args:
            target_distribution: Target fraction per type. Defaults to uniform
                over types "I" through "VI" (an empty dict also selects the
                default).
        """
        self.target = target_distribution or {
            "I": 1/6, "II": 1/6, "III": 1/6,
            "IV": 1/6, "V": 1/6, "VI": 1/6,
        }
        self._counts: dict[str, int] = {}

    def register_sample(self, fitz_type: str) -> None:
        """Register a sample's Fitzpatrick type."""
        self._counts[fitz_type] = self._counts.get(fitz_type, 0) + 1

    def get_sampling_weights(self, fitz_types: list[str]) -> np.ndarray:
        """Compute sampling weights for a list of samples.

        Each sample's raw weight is ``target / observed frequency`` for its
        type (capped at 5.0), so underrepresented types get upsampled. Types
        that were never registered are treated as having been seen once, and
        types missing from the target distribution use a target of 1/6.

        Args:
            fitz_types: Fitzpatrick type of each sample, in sample order.

        Returns:
            float64 array with one weight per sample, normalized to sum to 1.
            An empty input yields an empty array. If every raw weight is zero
            (e.g. all targets are 0), uniform weights are returned instead of
            NaN.
        """
        total = sum(self._counts.values()) or 1

        raw_weights = []
        for ft in fitz_types:
            count = self._counts.get(ft, 1)
            freq = count / total
            target_freq = self.target.get(ft, 1/6)
            # Weight = target / actual (capped for stability)
            raw_weights.append(min(target_freq / max(freq, 1e-6), 5.0))

        weights = np.array(raw_weights, dtype=np.float64)
        if weights.size == 0:
            return weights

        weight_sum = weights.sum()
        if weight_sum <= 0:
            # Nothing to normalize by; fall back to uniform sampling rather
            # than dividing by zero and returning NaN probabilities.
            return np.full(weights.shape, 1.0 / weights.size)
        return weights / weight_sum  # normalize to probability distribution