LandmarkDiff / landmarkdiff /augmentation.py
dreamlessx's picture
Update landmarkdiff/augmentation.py to v0.3.2
6071b18 verified
"""Training data augmentation pipeline for LandmarkDiff.
Provides domain-specific augmentations that maintain landmark consistency:
- Geometric: flip, rotation, affine (landmarks co-transformed)
- Photometric: color jitter, brightness, contrast (applied to images only)
- Skin-tone augmentation: ITA-space perturbation for Fitzpatrick balance
- Conditioning augmentation: noise injection, dropout for robustness
All augmentations preserve the correspondence between:
input_image ↔ conditioning_image ↔ target_image ↔ mask
"""
from __future__ import annotations
from dataclasses import dataclass
import cv2
import numpy as np
@dataclass
class AugmentationConfig:
    """Tunable knobs for the training-sample augmentation pipeline.

    Ranges given as ``(low, high)`` tuples are sampled uniformly; scalar
    magnitudes are sampled symmetrically around zero by the consumer.
    """

    # --- Geometric (applied to images, conditioning, mask and landmarks) ---
    # Enable a random horizontal mirror.
    random_flip: bool = True
    # Maximum absolute rotation, in degrees.
    random_rotation_deg: float = 5.0
    # (min, max) isotropic scale factor.
    random_scale: tuple[float, float] = (0.95, 1.05)
    # Maximum absolute shift, as a fraction of the image size.
    random_translate: float = 0.02

    # --- Photometric (input/target images only, never the conditioning) ---
    # (min, max) multiplicative brightness factor.
    brightness_range: tuple[float, float] = (0.9, 1.1)
    # (min, max) contrast factor applied around the image mean.
    contrast_range: tuple[float, float] = (0.9, 1.1)
    # (min, max) multiplicative saturation factor.
    saturation_range: tuple[float, float] = (0.9, 1.1)
    # Maximum absolute hue shift, in degrees.
    hue_shift_range: float = 5.0

    # --- Conditioning robustness ---
    # Probability of replacing the conditioning image with zeros.
    conditioning_dropout_prob: float = 0.1
    # Gaussian noise std for the conditioning, as a fraction of 255.
    conditioning_noise_std: float = 0.02

    # --- Skin tone ---
    # Std of ITA-angle noise. Not read by augment_training_sample.
    ita_perturbation_std: float = 3.0

    # Seed used to build a generator when the caller does not pass one.
    seed: int | None = None
def augment_training_sample(
    input_image: np.ndarray,
    target_image: np.ndarray,
    conditioning: np.ndarray,
    mask: np.ndarray,
    landmarks_src: np.ndarray | None = None,
    landmarks_dst: np.ndarray | None = None,
    config: AugmentationConfig | None = None,
    rng: np.random.Generator | None = None,
) -> dict[str, np.ndarray]:
    """Apply one consistent random augmentation to a training sample.

    Spatial transforms (flip, rotation/scale/translation) are applied with
    the same parameters to every image, the mask and the landmarks, so the
    correspondence between them is preserved. Photometric transforms touch
    only ``input_image`` and ``target_image``. The inputs are copied and
    never modified in place.

    Args:
        input_image: (H, W, 3) original face image (uint8 BGR).
        target_image: (H, W, 3) target face image (uint8 BGR).
        conditioning: (H, W, 3) conditioning image (uint8).
        mask: (H, W) or (H, W, 1) float32 mask; the rank is preserved.
        landmarks_src: (N, 2) normalized [0, 1] source landmark coords.
        landmarks_dst: (N, 2) normalized [0, 1] target landmark coords.
        config: Augmentation parameters; defaults to ``AugmentationConfig()``.
        rng: Random generator. When omitted, a new generator is built from
            ``config.seed`` on every call, so a fixed seed reproduces the
            same augmentation for each call.

    Returns:
        Dict with keys ``"input_image"``, ``"target_image"``,
        ``"conditioning"`` and ``"mask"``, plus ``"landmarks_src"`` /
        ``"landmarks_dst"`` when the corresponding input was provided.
    """
    if config is None:
        config = AugmentationConfig()
    if rng is None:
        rng = np.random.default_rng(config.seed)

    # All spatial transforms are sized from the input image.
    h, w = input_image.shape[:2]

    out_input = input_image.copy()
    out_target = target_image.copy()
    out_cond = conditioning.copy()
    out_mask = mask.copy()
    out_lm_src = landmarks_src.copy() if landmarks_src is not None else None
    out_lm_dst = landmarks_dst.copy() if landmarks_dst is not None else None

    # --- Geometric augmentations (applied to all) ---

    # Random horizontal flip.
    # NOTE(review): only the x coordinate is mirrored; left/right landmark
    # indices are not swapped -- confirm this is what consumers expect.
    if config.random_flip and rng.random() < 0.5:
        out_input = np.ascontiguousarray(out_input[:, ::-1])
        out_target = np.ascontiguousarray(out_target[:, ::-1])
        out_cond = np.ascontiguousarray(out_cond[:, ::-1])
        # Reversing axis 1 works for both (H, W) and (H, W, 1) masks.
        out_mask = np.ascontiguousarray(out_mask[:, ::-1])
        if out_lm_src is not None:
            out_lm_src[:, 0] = 1.0 - out_lm_src[:, 0]
        if out_lm_dst is not None:
            out_lm_dst[:, 0] = 1.0 - out_lm_dst[:, 0]

    # Random rotation + scale + translate. Translation is part of the gate so
    # that a translate-only configuration is not silently skipped.
    needs_affine = (
        config.random_rotation_deg > 0
        or tuple(config.random_scale) != (1.0, 1.0)
        or config.random_translate != 0
    )
    if needs_affine:
        angle = rng.uniform(-config.random_rotation_deg, config.random_rotation_deg)
        scale = rng.uniform(config.random_scale[0], config.random_scale[1])
        tx = rng.uniform(-config.random_translate, config.random_translate) * w
        ty = rng.uniform(-config.random_translate, config.random_translate) * h

        center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, angle, scale)
        M[0, 2] += tx
        M[1, 2] += ty

        # Photographs are reflected at the border; conditioning and mask are
        # padded with zeros instead.
        out_input = cv2.warpAffine(
            out_input, M, (w, h), borderMode=cv2.BORDER_REFLECT_101
        )
        out_target = cv2.warpAffine(
            out_target, M, (w, h), borderMode=cv2.BORDER_REFLECT_101
        )
        out_cond = cv2.warpAffine(
            out_cond, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0
        )

        # Warp the mask as a 2D array, then restore the trailing channel
        # axis if the caller supplied an (H, W, 1) mask.
        mask_had_channel = out_mask.ndim != 2
        mask_2d = out_mask[:, :, 0] if mask_had_channel else out_mask
        mask_2d = cv2.warpAffine(
            mask_2d, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0
        )
        out_mask = mask_2d[:, :, np.newaxis] if mask_had_channel else mask_2d

        # Transform landmarks with the same matrix.
        if out_lm_src is not None:
            out_lm_src = _transform_landmarks(out_lm_src, M, w, h)
        if out_lm_dst is not None:
            out_lm_dst = _transform_landmarks(out_lm_dst, M, w, h)

    # --- Photometric augmentations (images only, not conditioning/mask) ---
    # The same factor is used for input and target so the pair stays aligned.

    # Brightness
    b_factor = rng.uniform(config.brightness_range[0], config.brightness_range[1])
    out_input = np.clip(out_input.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
    out_target = np.clip(out_target.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)

    # Contrast: stretch around each image's own mean intensity.
    c_factor = rng.uniform(config.contrast_range[0], config.contrast_range[1])
    mean_in = out_input.mean()
    mean_tgt = out_target.mean()
    out_input = np.clip(
        (out_input.astype(np.float32) - mean_in) * c_factor + mean_in, 0, 255
    ).astype(np.uint8)
    out_target = np.clip(
        (out_target.astype(np.float32) - mean_tgt) * c_factor + mean_tgt, 0, 255
    ).astype(np.uint8)

    # Saturation (in HSV space); skipped when the factor is effectively 1.
    s_factor = rng.uniform(config.saturation_range[0], config.saturation_range[1])
    if abs(s_factor - 1.0) > 1e-4:
        out_input = _adjust_saturation(out_input, s_factor)
        out_target = _adjust_saturation(out_target, s_factor)

    # Hue shift; skipped for negligible deltas.
    if config.hue_shift_range > 0:
        hue_delta = rng.uniform(-config.hue_shift_range, config.hue_shift_range)
        if abs(hue_delta) > 0.1:
            out_input = _shift_hue(out_input, hue_delta)
            out_target = _shift_hue(out_target, hue_delta)

    # --- Conditioning augmentation ---

    # Conditioning dropout (replace with zeros to learn unconditional).
    cond_dropped = (
        config.conditioning_dropout_prob > 0
        and rng.random() < config.conditioning_dropout_prob
    )
    if cond_dropped:
        # Keep the dropped conditioning exactly zero: adding clipped noise on
        # top would turn the "unconditional" signal into random speckle.
        out_cond = np.zeros_like(out_cond)
    elif config.conditioning_noise_std > 0:
        noise = rng.normal(0, config.conditioning_noise_std * 255, out_cond.shape)
        out_cond = np.clip(out_cond.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    result = {
        "input_image": out_input,
        "target_image": out_target,
        "conditioning": out_cond,
        "mask": out_mask,
    }
    if out_lm_src is not None:
        result["landmarks_src"] = out_lm_src
    if out_lm_dst is not None:
        result["landmarks_dst"] = out_lm_dst
    return result
def _transform_landmarks(
landmarks: np.ndarray, M: np.ndarray, w: int, h: int
) -> np.ndarray:
"""Transform normalized landmarks with an affine matrix."""
# Convert to pixel coords
px = landmarks.copy()
px[:, 0] *= w
px[:, 1] *= h
# Apply affine transform
ones = np.ones((px.shape[0], 1))
px_h = np.hstack([px, ones]) # (N, 3)
transformed = (M @ px_h.T).T # (N, 2)
# Back to normalized
transformed[:, 0] /= w
transformed[:, 1] /= h
return np.clip(transformed, 0.0, 1.0)
def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
    """Scale the saturation of a BGR image by ``factor``.

    The image is converted to HSV, the S channel is multiplied by ``factor``
    and clamped to [0, 255], and the result is converted back to BGR.
    """
    hsv_f = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    scaled_sat = hsv_f[..., 1] * factor
    hsv_f[..., 1] = np.clip(scaled_sat, 0, 255)
    hsv_u8 = hsv_f.astype(np.uint8)
    return cv2.cvtColor(hsv_u8, cv2.COLOR_HSV2BGR)
def _shift_hue(img: np.ndarray, delta_deg: float) -> np.ndarray:
    """Rotate the hue of a BGR image by ``delta_deg`` degrees.

    The image is converted to HSV, the H channel is shifted with wrap-around,
    and the result is converted back to BGR.
    """
    hsv_f = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    # The H channel is stored in half-degrees, so the requested delta is
    # halved and the result wrapped modulo 180.
    shifted_hue = hsv_f[..., 0] + delta_deg / 2
    hsv_f[..., 0] = np.mod(shifted_hue, 180)
    hsv_u8 = hsv_f.astype(np.uint8)
    return cv2.cvtColor(hsv_u8, cv2.COLOR_HSV2BGR)
def augment_skin_tone(
    image: np.ndarray,
    ita_delta: float = 0.0,
) -> np.ndarray:
    """Simulate a different skin tone by offsetting channels in L*a*b* space.

    Intended to help balance Fitzpatrick representation in training by
    deriving lighter or darker variants from existing samples.

    Args:
        image: (H, W, 3) BGR uint8 image.
        ita_delta: ITA angle shift (positive = lighter, negative = darker).

    Returns:
        BGR uint8 image with the shifted tone.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # ITA = arctan((L - 50) / b), so moving ITA mainly moves lightness (L).
    # The 0.5 factor is an approximate mapping; the smaller opposite-signed
    # nudge on b keeps the tone change looking natural.
    channel_offsets = (
        (0, ita_delta * 0.5),    # L channel
        (2, -ita_delta * 0.15),  # b channel
    )
    for channel, offset in channel_offsets:
        lab[..., channel] = np.clip(lab[..., channel] + offset, 0, 255)

    return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
class FitzpatrickBalancer:
    """Oversample underrepresented Fitzpatrick types during training.

    Maintains per-type counts and generates sampling weights to ensure
    equitable training across all skin types.
    """

    # Target fraction assumed for a type missing from the target distribution.
    _DEFAULT_TARGET_FRACTION = 1 / 6
    # Lower bound on the observed frequency, to avoid division by zero.
    _MIN_FREQUENCY = 1e-6
    # Upper bound on a single un-normalized weight, for stability.
    _MAX_RAW_WEIGHT = 5.0

    def __init__(self, target_distribution: dict[str, float] | None = None):
        """Initialize balancer.

        Args:
            target_distribution: Target fraction per type. Defaults to a
                uniform distribution over types "I" through "VI" when
                omitted or empty.
        """
        self.target = target_distribution or {
            "I": 1/6, "II": 1/6, "III": 1/6,
            "IV": 1/6, "V": 1/6, "VI": 1/6,
        }
        self._counts: dict[str, int] = {}

    def register_sample(self, fitz_type: str) -> None:
        """Register a sample's Fitzpatrick type."""
        self._counts[fitz_type] = self._counts.get(fitz_type, 0) + 1

    def get_sampling_weights(self, fitz_types: list[str]) -> np.ndarray:
        """Compute sampling weights for a list of samples.

        Each sample's raw weight is ``target / observed`` frequency of its
        type (capped for stability), so underrepresented types get
        upsampled. Types that were never registered are treated as having
        been seen once.

        Args:
            fitz_types: Fitzpatrick type of each sample, in sample order.

        Returns:
            float64 array of the same length that sums to 1. If every raw
            weight is zero (e.g. all target fractions are 0) the result
            falls back to a uniform distribution instead of NaNs. An empty
            input yields an empty array.
        """
        total = sum(self._counts.values()) or 1

        raw = []
        for ft in fitz_types:
            freq = self._counts.get(ft, 1) / total
            target_freq = self.target.get(ft, self._DEFAULT_TARGET_FRACTION)
            # Weight = target / actual (capped for stability)
            raw.append(
                min(target_freq / max(freq, self._MIN_FREQUENCY), self._MAX_RAW_WEIGHT)
            )

        weights = np.array(raw, dtype=np.float64)
        if weights.size == 0:
            return weights

        weight_sum = weights.sum()
        if not np.isfinite(weight_sum) or weight_sum <= 0:
            # Nothing to normalize by; a uniform distribution is the only
            # valid probability vector that does not favor any sample.
            return np.full(weights.shape, 1.0 / weights.size, dtype=np.float64)
        return weights / weight_sum  # normalize to probability distribution