Vedant Jigarbhai Mehta committed
Commit b25c087 · 0 parent(s)
Initial scaffolding for military base change detection project
Add complete project structure with 3 model architectures (Siamese CNN,
UNet++, ChangeFormer), dataset pipeline, training/evaluation/inference
scripts, Gradio demo, Colab setup, and config with all hyperparameters.
- .gitignore +42 -0
- README.md +93 -0
- app.py +193 -0
- configs/config.yaml +143 -0
- data/dataset.py +153 -0
- data/download.py +132 -0
- evaluate.py +135 -0
- inference.py +176 -0
- models/__init__.py +39 -0
- models/changeformer.py +358 -0
- models/siamese_cnn.py +85 -0
- models/unet_pp.py +78 -0
- requirements.txt +16 -0
- setup_colab.py +172 -0
- train.py +418 -0
- utils/__init__.py +0 -0
- utils/losses.py +139 -0
- utils/metrics.py +226 -0
- utils/visualization.py +141 -0
.gitignore
ADDED
@@ -0,0 +1,42 @@
+# Checkpoints & model weights
+checkpoints/
+*.pth
+*.pt
+
+# Logs
+logs/
+
+# Outputs
+outputs/
+
+# Data
+raw_data/
+processed_data/
+
+# Python
+__pycache__/
+*.pyc
+*.pyo
+*.egg-info/
+dist/
+build/
+.eggs/
+
+# Environment
+.env
+.venv/
+venv/
+env/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Jupyter
+.ipynb_checkpoints/
README.md
ADDED
@@ -0,0 +1,93 @@
+# Military Base Construction Monitoring — Change Detection
+
+Deep learning system for detecting new structures and infrastructure changes between satellite image pairs. Targets defense applications: military base expansion, runway construction, and infrastructure development monitoring.
+
+## Models
+
+| Model | Backbone | Role | Paper |
+|---|---|---|---|
+| Siamese CNN | ResNet18 (shared) | Baseline | — |
+| UNet++ | ResNet34 (shared) | Mid-tier | [arXiv:1807.10165](https://arxiv.org/abs/1807.10165) |
+| ChangeFormer | MiT-B1 (shared) | SOTA | [arXiv:2201.01293](https://arxiv.org/abs/2201.01293) |
+
+## Dataset
+
+**LEVIR-CD** — 637 image pairs at 1024×1024, cropped to 256×256 non-overlapping patches. Contains building change annotations across urban areas.
+
+## Quick Start (Google Colab)
+
+```python
+# 1. Setup
+from setup_colab import setup
+dirs = setup()
+
+# 2. Train
+!python train.py --config configs/config.yaml --model siamese_cnn
+
+# 3. Evaluate
+!python evaluate.py --config configs/config.yaml --checkpoint checkpoints/siamese_cnn_best.pth
+
+# 4. Resume after disconnect
+!python train.py --config configs/config.yaml --model changeformer \
+    --resume /content/drive/MyDrive/change-detection/checkpoints/changeformer_last.pth
+```
+
+## Local Usage
+
+```bash
+# Preprocess dataset
+python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data
+
+# Train
+python train.py --config configs/config.yaml --model unet_pp
+
+# Evaluate
+python evaluate.py --config configs/config.yaml --checkpoint checkpoints/unet_pp_best.pth
+
+# Inference on new image pair
+python inference.py --before path/to/before.png --after path/to/after.png \
+    --model changeformer --checkpoint checkpoints/changeformer_best.pth
+
+# Gradio demo
+python app.py
+```
+
+## GPU Batch Sizes (Auto-Detected)
+
+| Model | T4 (16GB) | V100 (16GB) | LR |
+|---|---|---|---|
+| Siamese CNN | 16 | 16 | 1e-3 |
+| UNet++ | 8 | 12 | 1e-4 |
+| ChangeFormer | 4 | 6 | 6e-5 |
+
+## Evaluation Metrics
+
+- **F1-Score** (primary, used for model selection and early stopping)
+- IoU / Jaccard
+- Precision, Recall
+- Overall Accuracy
+
+## Project Structure
+
+```
+military-base-change-detection/
+├── configs/config.yaml          # All hyperparameters and paths
+├── data/
+│   ├── download.py              # Dataset download & patch cropping
+│   └── dataset.py               # PyTorch Dataset with synced augmentations
+├── models/
+│   ├── __init__.py              # get_model() factory
+│   ├── siamese_cnn.py           # Siamese CNN baseline
+│   ├── unet_pp.py               # UNet++ change detection
+│   └── changeformer.py          # ChangeFormer transformer
+├── utils/
+│   ├── metrics.py               # F1, IoU, Precision, Recall, OA
+│   ├── losses.py                # BCEDiceLoss, FocalLoss
+│   └── visualization.py         # Plotting utilities
+├── train.py                     # Training with AMP, early stopping, resume
+├── evaluate.py                  # Test set evaluation
+├── inference.py                 # Inference on new image pairs
+├── app.py                       # Gradio demo
+├── setup_colab.py               # Colab environment setup
+└── requirements.txt             # Pinned dependencies
+```
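The README lists F1 as the primary metric, but the contents of `utils/metrics.py` are not shown in this excerpt. As a rough illustration only, the listed metrics can all be derived from confusion counts over flattened 0/1 masks. The function name `binary_change_metrics` and the `eps` smoothing term are assumptions made for this sketch, not the repository's implementation.

```python
from typing import Dict, Sequence


def binary_change_metrics(pred: Sequence[int], target: Sequence[int]) -> Dict[str, float]:
    """Compute F1, IoU, precision, recall and overall accuracy.

    Both inputs are flat sequences containing only 0 (no change) and 1 (change).
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")

    # Confusion counts for the positive ("change") class.
    tp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(pred, target) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(pred, target) if p == 0 and t == 1)
    tn = len(pred) - tp - fp - fn

    eps = 1e-7  # assumed smoothing term; avoids division by zero on empty masks
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)
    oa = (tp + tn) / max(len(pred), 1)

    return {"f1": f1, "iou": iou, "precision": precision, "recall": recall, "oa": oa}


# Example: six pixels, two true positives, one false positive, one false negative.
example = binary_change_metrics([1, 1, 0, 0, 1, 0], [1, 0, 0, 1, 1, 0])
```

Because changed pixels are usually a small fraction of a scene, overall accuracy can stay high even when few changes are found, which is one plausible reason the README selects models on F1 instead.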
app.py
ADDED
@@ -0,0 +1,193 @@
+"""Gradio web demo for change detection inference.
+
+Provides an interactive interface to upload before/after satellite image pairs
+and visualize predicted change masks with overlays.
+
+Usage:
+    python app.py
+"""
+
+import logging
+from pathlib import Path
+from typing import Optional, Tuple
+
+import cv2
+import gradio as gr
+import numpy as np
+import torch
+import yaml
+
+from data.dataset import IMAGENET_MEAN, IMAGENET_STD
+from inference import preprocess_image, sliding_window_inference
+from models import get_model
+from utils.visualization import denormalize, overlay_changes
+
+logger = logging.getLogger(__name__)
+
+# Global model cache
+_model: Optional[torch.nn.Module] = None
+_model_name: Optional[str] = None
+_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+_config = None
+
+
+def load_config() -> dict:
+    """Load project config from YAML.
+
+    Returns:
+        Config dictionary.
+    """
+    config_path = Path("configs/config.yaml")
+    with open(config_path, "r") as f:
+        return yaml.safe_load(f)
+
+
+def load_model(model_name: str, checkpoint_path: str) -> torch.nn.Module:
+    """Load a change detection model with caching.
+
+    Args:
+        model_name: Name of the model architecture.
+        checkpoint_path: Path to the model checkpoint.
+
+    Returns:
+        Loaded model in eval mode.
+    """
+    global _model, _model_name, _config
+
+    if _config is None:
+        _config = load_config()
+
+    if _model is not None and _model_name == model_name:
+        return _model
+
+    model = get_model(model_name, _config).to(_device)
+    ckpt = torch.load(checkpoint_path, map_location=_device)
+    model.load_state_dict(ckpt["model_state_dict"])
+    model.eval()
+
+    _model = model
+    _model_name = model_name
+    logger.info("Loaded model: %s from %s", model_name, checkpoint_path)
+    return model
+
+
+def predict(
+    before_image: np.ndarray,
+    after_image: np.ndarray,
+    model_name: str,
+    checkpoint_path: str,
+    threshold: float,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Run change detection on a pair of images.
+
+    Args:
+        before_image: Before image as numpy array (RGB, uint8).
+        after_image: After image as numpy array (RGB, uint8).
+        model_name: Model architecture name.
+        checkpoint_path: Path to model weights.
+        threshold: Binarization threshold.
+
+    Returns:
+        Tuple of (binary change mask, overlay visualization).
+    """
+    model = load_model(model_name, checkpoint_path)
+    patch_size = 256
+
+    # Preprocess both images
+    def _to_tensor(img: np.ndarray) -> torch.Tensor:
+        h, w = img.shape[:2]
+        pad_h = (patch_size - h % patch_size) % patch_size
+        pad_w = (patch_size - w % patch_size) % patch_size
+        if pad_h > 0 or pad_w > 0:
+            img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
+        img_f = img.astype(np.float32) / 255.0
+        mean = np.array(IMAGENET_MEAN, dtype=np.float32)
+        std = np.array(IMAGENET_STD, dtype=np.float32)
+        img_f = (img_f - mean) / std
+        return torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float()
+
+    orig_h, orig_w = before_image.shape[:2]
+    tensor_a = _to_tensor(before_image)
+    tensor_b = _to_tensor(after_image)
+
+    # Run inference
+    prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device)
+    prob_map = prob_map[:, :, :orig_h, :orig_w]
+
+    # Binary mask
+    mask_np = prob_map.squeeze().numpy()
+    binary_mask = (mask_np > threshold).astype(np.uint8) * 255
+
+    # Overlay on after image
+    overlay = after_image.copy().astype(np.float32) / 255.0
+    change_pixels = mask_np > threshold
+    overlay[change_pixels, 0] = np.clip(overlay[change_pixels, 0] * 0.6 + 0.4, 0, 1)
+    overlay[change_pixels, 1] = overlay[change_pixels, 1] * 0.6
+    overlay[change_pixels, 2] = overlay[change_pixels, 2] * 0.6
+    overlay = (overlay * 255).astype(np.uint8)
+
+    return binary_mask, overlay
+
+
+def build_demo() -> gr.Blocks:
+    """Build the Gradio demo interface.
+
+    Returns:
+        Gradio Blocks application.
+    """
+    config = load_config()
+    gradio_cfg = config.get("gradio", {})
+
+    with gr.Blocks(title="Military Base Change Detection") as demo:
+        gr.Markdown("# Military Base Change Detection")
+        gr.Markdown("Upload before/after satellite image pairs to detect construction and infrastructure changes.")
+
+        with gr.Row():
+            with gr.Column():
+                before_img = gr.Image(label="Before Image", type="numpy")
+                after_img = gr.Image(label="After Image", type="numpy")
+            with gr.Column():
+                change_mask = gr.Image(label="Change Mask")
+                overlay_img = gr.Image(label="Overlay")
+
+        with gr.Row():
+            model_dropdown = gr.Dropdown(
+                choices=["siamese_cnn", "unet_pp", "changeformer"],
+                value=gradio_cfg.get("default_model", "unet_pp"),
+                label="Model",
+            )
+            checkpoint_input = gr.Textbox(
+                value=gradio_cfg.get("default_checkpoint", "checkpoints/unet_pp_best.pth"),
+                label="Checkpoint Path",
+            )
+            threshold_slider = gr.Slider(
+                minimum=0.1, maximum=0.9, value=0.5, step=0.05,
+                label="Detection Threshold",
+            )
+
+        detect_btn = gr.Button("Detect Changes", variant="primary")
+        detect_btn.click(
+            fn=predict,
+            inputs=[before_img, after_img, model_dropdown, checkpoint_input, threshold_slider],
+            outputs=[change_mask, overlay_img],
+        )
+
+    return demo
+
+
+def main() -> None:
+    """Launch the Gradio demo."""
+    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
+    config = load_config()
+    gradio_cfg = config.get("gradio", {})
+
+    demo = build_demo()
+    demo.launch(
+        server_port=gradio_cfg.get("server_port", 7860),
+        share=gradio_cfg.get("share", False),
+    )
+
+
+if __name__ == "__main__":
+    main()
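The `_to_tensor` helper inside `predict` pads each side up to the next multiple of `patch_size`, and the overlay step tints changed pixels by scaling every channel by 0.6 and adding 0.4 to red. A minimal sketch of just that arithmetic in plain Python follows; `pad_amount` and `tint_red` are hypothetical helper names introduced here and do not appear in `app.py`.

```python
from typing import Tuple


def pad_amount(size: int, patch_size: int = 256) -> int:
    """Pixels to append so that `size` becomes a multiple of `patch_size`."""
    # The outer modulo maps an already aligned side to 0 instead of patch_size.
    return (patch_size - size % patch_size) % patch_size


def tint_red(pixel: Tuple[float, float, float], alpha: float = 0.4) -> Tuple[float, float, float]:
    """Blend an RGB pixel (floats in [0, 1]) toward pure red by `alpha`."""
    r, g, b = pixel
    keep = 1.0 - alpha
    # Red is clamped to 1.0, mirroring the np.clip call in the demo.
    return (min(r * keep + alpha, 1.0), g * keep, b * keep)


if __name__ == "__main__":
    for side in (256, 300, 1000):
        print(side, "->", side + pad_amount(side))
    print(tint_red((0.5, 0.5, 0.5)))
```

Note that `predict` crops the probability map using the height and width of the before image only, so this sketch, like the demo, assumes both uploads share the same dimensions.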
configs/config.yaml
ADDED
@@ -0,0 +1,143 @@
+# =============================================================================
+# Military Base Change Detection — Master Configuration
+# =============================================================================
+
+# --- Project paths ---
+project:
+  name: "military-base-change-detection"
+  seed: 42
+
+# --- Colab / runtime settings ---
+colab:
+  enabled: true
+  drive_root: "/content/drive/MyDrive/change-detection"
+  checkpoint_dir: "/content/drive/MyDrive/change-detection/checkpoints"
+  log_dir: "/content/drive/MyDrive/change-detection/logs"
+  output_dir: "/content/drive/MyDrive/change-detection/outputs"
+  data_dir: "/content/drive/MyDrive/change-detection/processed_data"
+
+# --- Local paths (used when colab.enabled is false) ---
+paths:
+  raw_data: "./raw_data"
+  processed_data: "./processed_data"
+  checkpoint_dir: "./checkpoints"
+  log_dir: "./logs"
+  output_dir: "./outputs"
+
+# --- Dataset ---
+dataset:
+  name: "levir-cd"  # levir-cd | whu-cd
+  original_size: 1024
+  patch_size: 256
+  num_workers: 4
+  pin_memory: true
+  # ImageNet normalization
+  mean: [0.485, 0.456, 0.406]
+  std: [0.229, 0.224, 0.225]
+
+# --- Augmentation (train only) ---
+augmentation:
+  enabled: true
+  horizontal_flip: 0.5
+  vertical_flip: 0.5
+  random_rotate_90: 0.5
+  color_jitter:
+    brightness: 0.2
+    contrast: 0.2
+    saturation: 0.1
+    hue: 0.05
+
+# --- Model selection ---
+model:
+  name: "unet_pp"  # siamese_cnn | unet_pp | changeformer
+
+# --- Model-specific configs ---
+siamese_cnn:
+  backbone: "resnet18"
+  pretrained: true
+
+unet_pp:
+  encoder_name: "resnet34"
+  pretrained: true
+  deep_supervision: false
+
+changeformer:
+  embed_dims: [64, 128, 320, 512]  # MiT-B1 style
+  num_heads: [1, 2, 5, 8]
+  mlp_ratios: [8, 8, 4, 4]
+  depths: [2, 2, 2, 2]
+  pretrained_backbone: true
+
+# --- Training ---
+training:
+  epochs: 100  # 200 for changeformer
+  optimizer: "adamw"
+  learning_rate: 1.0e-4
+  weight_decay: 0.01
+  scheduler: "cosine"
+  warmup_epochs: 5
+  grad_clip_max_norm: 1.0
+  gradient_accumulation_steps: 1  # set to 2 for changeformer on T4
+  amp: true  # mixed precision
+  early_stopping:
+    enabled: true
+    patience: 15
+    metric: "f1"
+    mode: "max"
+  log_interval: 10  # log every N batches
+  vis_interval: 5  # visualize predictions every N epochs
+
+# --- Loss ---
+loss:
+  name: "bce_dice"  # bce_dice | focal
+  bce_dice:
+    bce_weight: 0.5
+    dice_weight: 0.5
+  focal:
+    alpha: 0.25
+    gamma: 2.0
+
+# --- Evaluation ---
+evaluation:
+  threshold: 0.5
+  metrics:
+    - f1
+    - iou
+    - precision
+    - recall
+    - oa
+
+# --- GPU-specific batch sizes (auto-detected on Colab) ---
+# model_name -> { gpu_type -> batch_size }
+batch_sizes:
+  siamese_cnn:
+    T4: 16
+    V100: 16
+    default: 8
+  unet_pp:
+    T4: 8
+    V100: 12
+    default: 4
+  changeformer:
+    T4: 4
+    V100: 6
+    default: 2
+
+# --- Per-model learning rates ---
+learning_rates:
+  siamese_cnn: 1.0e-3
+  unet_pp: 1.0e-4
+  changeformer: 6.0e-5
+
+# --- Per-model epoch counts ---
+epoch_counts:
+  siamese_cnn: 100
+  unet_pp: 100
+  changeformer: 200
+
+# --- Gradio demo ---
+gradio:
+  server_port: 7860
+  share: false
+  default_model: "unet_pp"
+  default_checkpoint: "checkpoints/unet_pp_best.pth"
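The `batch_sizes` section maps each model to per-GPU values with a `default` fallback. How `train.py` performs the detection is not part of this excerpt, so the sketch below simply assumes a substring match against a device name such as the string returned by `torch.cuda.get_device_name()`. The names `resolve_batch_size` and `effective_batch_size` are hypothetical, and the dictionary is copied by hand from the config instead of being parsed from YAML.

```python
from typing import Dict

# Hand-copied from the batch_sizes block of configs/config.yaml.
BATCH_SIZES: Dict[str, Dict[str, int]] = {
    "siamese_cnn": {"T4": 16, "V100": 16, "default": 8},
    "unet_pp": {"T4": 8, "V100": 12, "default": 4},
    "changeformer": {"T4": 4, "V100": 6, "default": 2},
}


def resolve_batch_size(batch_sizes: Dict[str, Dict[str, int]], model_name: str, gpu_name: str) -> int:
    """Pick the batch size whose GPU key occurs in `gpu_name`, else the default."""
    per_model = batch_sizes[model_name]
    for gpu_type, size in per_model.items():
        # Assumed matching rule: the config key is a substring of the device name.
        if gpu_type != "default" and gpu_type in gpu_name:
            return size
    return per_model["default"]


def effective_batch_size(batch_size: int, accumulation_steps: int) -> int:
    """Samples contributing to one optimizer step under gradient accumulation."""
    return batch_size * accumulation_steps


if __name__ == "__main__":
    bs = resolve_batch_size(BATCH_SIZES, "changeformer", "Tesla T4")
    # The config comment suggests 2 accumulation steps for changeformer on a T4.
    print(bs, effective_batch_size(bs, 2))
```

Under this reading, ChangeFormer on a T4 with two accumulation steps would see the same number of samples per optimizer step as UNet++ on a T4 with one.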
data/dataset.py
ADDED
@@ -0,0 +1,153 @@
+"""PyTorch Dataset for change detection tasks.
+
+Loads pre-cropped 256x256 image patches (before/after) and binary change masks.
+Supports synchronized augmentations via albumentations.ReplayCompose.
+"""
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+import albumentations as A
+import cv2
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+logger = logging.getLogger(__name__)
+
+# ImageNet normalization constants
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def get_train_transforms(config: Dict[str, Any]) -> A.ReplayCompose:
+    """Build training augmentation pipeline with synchronized transforms.
+
+    Args:
+        config: Augmentation config dict from config.yaml.
+
+    Returns:
+        ReplayCompose that applies identical spatial transforms to A, B, and mask.
+    """
+    aug_cfg = config.get("augmentation", {})
+    transforms = []
+
+    if aug_cfg.get("horizontal_flip", 0) > 0:
+        transforms.append(A.HorizontalFlip(p=aug_cfg["horizontal_flip"]))
+
+    if aug_cfg.get("vertical_flip", 0) > 0:
+        transforms.append(A.VerticalFlip(p=aug_cfg["vertical_flip"]))
+
+    if aug_cfg.get("random_rotate_90", 0) > 0:
+        transforms.append(A.RandomRotate90(p=aug_cfg["random_rotate_90"]))
+
+    transforms.append(A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
+
+    return A.ReplayCompose(
+        transforms,
+        additional_targets={"image_b": "image", "mask": "mask"},
+    )
+
+
+def get_val_transforms() -> A.Compose:
+    """Build validation/test transform pipeline (normalize only).
+
+    Returns:
+        Compose with ImageNet normalization only.
+    """
+    return A.Compose(
+        [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)],
+        additional_targets={"image_b": "image"},
+    )
+
+
+class ChangeDetectionDataset(Dataset):
+    """Dataset for loading change detection image pairs and masks.
+
+    Expects directory structure:
+        root/
+        ├── A/         # before images
+        ├── B/         # after images
+        └── label/     # binary change masks (0=no change, 255=change)
+
+    Args:
+        root: Path to the split directory (e.g., processed_data/train).
+        split: One of 'train', 'val', 'test'.
+        config: Full config dict for augmentation settings.
+        transform: Optional override for the transform pipeline.
+    """
+
+    def __init__(
+        self,
+        root: Path,
+        split: str = "train",
+        config: Optional[Dict[str, Any]] = None,
+        transform: Optional[Any] = None,
+    ) -> None:
+        self.root = Path(root)
+        self.split = split
+
+        self.dir_a = self.root / "A"
+        self.dir_b = self.root / "B"
+        self.dir_label = self.root / "label"
+
+        # Collect sorted file lists
+        self.filenames = sorted([f.name for f in self.dir_a.iterdir() if f.suffix in (".png", ".jpg", ".tif")])
+        logger.info("Loaded %d samples for split '%s' from %s", len(self.filenames), split, root)
+
+        # Set up transforms
+        if transform is not None:
+            self.transform = transform
+        elif split == "train" and config is not None:
+            self.transform = get_train_transforms(config)
+        else:
+            self.transform = get_val_transforms()
+
+    def __len__(self) -> int:
+        return len(self.filenames)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        """Load a single sample.
+
+        Args:
+            idx: Sample index.
+
+        Returns:
+            Dict with keys 'A', 'B', 'mask', 'filename'.
+            - A: before image tensor [3, H, W]
+            - B: after image tensor [3, H, W]
+            - mask: binary change mask tensor [1, H, W] (float, 0 or 1)
+            - filename: original filename string
+        """
+        fname = self.filenames[idx]
+
+        # Lazy load — read from disk each time (no RAM caching)
+        img_a = cv2.imread(str(self.dir_a / fname), cv2.IMREAD_COLOR)
+        img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
+
+        img_b = cv2.imread(str(self.dir_b / fname), cv2.IMREAD_COLOR)
+        img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)
+
+        mask = cv2.imread(str(self.dir_label / fname), cv2.IMREAD_GRAYSCALE)
+        # Normalize 0/255 -> 0/1
+        mask = (mask / 255.0).astype(np.float32)
+
+        # Apply synchronized augmentations
+        if isinstance(self.transform, A.ReplayCompose):
+            transformed = self.transform(image=img_a, image_b=img_b, mask=mask)
+            img_a = transformed["image"]
+            img_b = transformed["image_b"]
+            mask = transformed["mask"]
+        else:
+            transformed = self.transform(image=img_a, image_b=img_b)
+            img_a = transformed["image"]
+            img_b = transformed["image_b"]
+            # Normalize only applied to images, mask stays as-is
+
+        # HWC -> CHW for images, add channel dim for mask
+        img_a = torch.from_numpy(img_a).permute(2, 0, 1).float()
+        img_b = torch.from_numpy(img_b).permute(2, 0, 1).float()
+        mask = torch.from_numpy(mask).unsqueeze(0).float()
+
+        return {"A": img_a, "B": img_b, "mask": mask, "filename": fname}
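`get_train_transforms` relies on albumentations `additional_targets` so that one sampled flip or rotation is applied to image A, image B, and the mask together. The idea can be illustrated without the library: sample the operations once, then apply that same list to all three arrays. Everything below (nested lists standing in for images, `sample_ops`, `apply_synced`) is an illustrative stand-in and not the repository's code. It also rotates by a single quarter turn, whereas `RandomRotate90` picks a random number of quarter turns.

```python
import random
from typing import Callable, List, Tuple

Grid = List[List[int]]


def hflip(grid: Grid) -> Grid:
    """Mirror left to right."""
    return [row[::-1] for row in grid]


def vflip(grid: Grid) -> Grid:
    """Mirror top to bottom."""
    return [list(row) for row in grid[::-1]]


def rot90(grid: Grid) -> Grid:
    """Rotate counterclockwise by one quarter turn."""
    return [list(col) for col in zip(*grid)][::-1]


def sample_ops(rng: random.Random, p_h: float = 0.5, p_v: float = 0.5, p_r: float = 0.5) -> List[Callable[[Grid], Grid]]:
    """Draw the random decisions once so they can be shared by all targets."""
    ops: List[Callable[[Grid], Grid]] = []
    if rng.random() < p_h:
        ops.append(hflip)
    if rng.random() < p_v:
        ops.append(vflip)
    if rng.random() < p_r:
        ops.append(rot90)
    return ops


def apply_synced(img_a: Grid, img_b: Grid, mask: Grid, rng: random.Random) -> Tuple[Grid, Grid, Grid]:
    """Apply one sampled list of spatial ops to both images and the mask."""
    for op in sample_ops(rng):
        img_a, img_b, mask = op(img_a), op(img_b), op(mask)
    return img_a, img_b, mask
```

Sampling independently per array would misalign the mask with the images, which is the failure the shared draw avoids. Separately, the `color_jitter` settings in `config.yaml` are not read by `get_train_transforms` as it appears in this commit.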
data/download.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download and preprocess change detection datasets.
|
| 2 |
+
|
| 3 |
+
Supports LEVIR-CD and WHU-CD datasets. Downloads raw data, crops 1024x1024
|
| 4 |
+
images into 256x256 non-overlapping patches, and organizes into train/val/test
|
| 5 |
+
splits.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
|
| 16 |
+
import cv2
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def download_levir_cd(raw_dir: Path) -> None:
|
| 23 |
+
"""Download the LEVIR-CD dataset.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
raw_dir: Directory to save the raw downloaded files.
|
| 27 |
+
"""
|
| 28 |
+
# TODO: Implement download via gdown or direct URL
|
| 29 |
+
raise NotImplementedError("LEVIR-CD download not yet implemented")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def download_whu_cd(raw_dir: Path) -> None:
|
| 33 |
+
"""Download the WHU-CD dataset.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
raw_dir: Directory to save the raw downloaded files.
|
| 37 |
+
"""
|
| 38 |
+
# TODO: Implement download
|
| 39 |
+
raise NotImplementedError("WHU-CD download not yet implemented")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def crop_to_patches(
|
| 43 |
+
image: np.ndarray,
|
| 44 |
+
patch_size: int = 256,
|
| 45 |
+
) -> list[np.ndarray]:
|
| 46 |
+
"""Crop an image into non-overlapping patches.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
image: Input image of shape (H, W) or (H, W, C).
|
| 50 |
+
patch_size: Size of each square patch.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
List of cropped patches.
|
| 54 |
+
"""
|
| 55 |
+
h, w = image.shape[:2]
|
| 56 |
+
patches = []
|
| 57 |
+
for y in range(0, h - patch_size + 1, patch_size):
|
| 58 |
+
for x in range(0, w - patch_size + 1, patch_size):
|
| 59 |
+
patch = image[y : y + patch_size, x : x + patch_size]
|
| 60 |
+
patches.append(patch)
|
| 61 |
+
return patches
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def process_split(
|
| 65 |
+
raw_dir: Path,
|
| 66 |
+
out_dir: Path,
|
| 67 |
+
split: str,
|
| 68 |
+
patch_size: int = 256,
|
| 69 |
+
) -> int:
|
| 70 |
+
"""Process a single dataset split (train/val/test).
|
| 71 |
+
|
| 72 |
+
Reads image pairs and masks from raw_dir, crops into patches, and
|
| 73 |
+
saves to out_dir.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
raw_dir: Root directory of the raw dataset.
|
| 77 |
+
out_dir: Output directory for processed patches.
|
| 78 |
+
split: One of 'train', 'val', 'test'.
|
| 79 |
+
patch_size: Size of each square patch.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Number of patch triplets generated.
|
| 83 |
+
"""
|
| 84 |
+
# TODO: Implement processing pipeline
|
| 85 |
+
raise NotImplementedError("Split processing not yet implemented")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def preprocess_dataset(
|
| 89 |
+
dataset: str,
|
| 90 |
+
raw_dir: Path,
|
| 91 |
+
out_dir: Path,
|
| 92 |
+
patch_size: int = 256,
|
| 93 |
+
) -> None:
|
| 94 |
+
"""Run full preprocessing pipeline for a dataset.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
dataset: Dataset name ('levir-cd' or 'whu-cd').
|
| 98 |
+
raw_dir: Directory containing raw downloaded data.
|
| 99 |
+
out_dir: Output directory for processed patches.
|
| 100 |
+
patch_size: Size of each square patch.
|
| 101 |
+
"""
|
| 102 |
+
logger.info("Preprocessing %s: %s -> %s", dataset, raw_dir, out_dir)
|
| 103 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 104 |
+
|
| 105 |
+
for split in ["train", "val", "test"]:
|
| 106 |
+
count = process_split(raw_dir, out_dir, split, patch_size)
|
| 107 |
+
logger.info(" %s: %d patch triplets", split, count)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def main() -> None:
|
| 111 |
+
"""CLI entry point for dataset download and preprocessing."""
|
| 112 |
+
parser = argparse.ArgumentParser(description="Download and preprocess change detection datasets")
|
| 113 |
+
parser.add_argument("--dataset", type=str, default="levir-cd", choices=["levir-cd", "whu-cd"])
|
| 114 |
+
parser.add_argument("--raw_dir", type=Path, default=Path("./raw_data"))
|
| 115 |
+
parser.add_argument("--out_dir", type=Path, default=Path("./processed_data"))
|
| 116 |
+
parser.add_argument("--patch_size", type=int, default=256)
|
| 117 |
+
parser.add_argument("--skip_download", action="store_true", help="Skip download, only preprocess")
|
| 118 |
+
args = parser.parse_args()
|
| 119 |
+
|
| 120 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 121 |
+
|
| 122 |
+
if not args.skip_download:
|
| 123 |
+
if args.dataset == "levir-cd":
|
| 124 |
+
download_levir_cd(args.raw_dir)
|
| 125 |
+
elif args.dataset == "whu-cd":
|
| 126 |
+
download_whu_cd(args.raw_dir)
|
| 127 |
+
|
| 128 |
+
preprocess_dataset(args.dataset, args.raw_dir, args.out_dir, args.patch_size)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
main()
|
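A note on the tiling in `crop_patches` above: the loop bounds `range(0, h - patch_size + 1, patch_size)` only visit origins where a full patch fits, so any remainder along the bottom or right edge is skipped. The sketch below is a minimal, dependency-free illustration of that arithmetic. `patch_origins` is a hypothetical helper introduced here for illustration and is not part of the repository.

```python
from typing import List, Tuple


def patch_origins(h: int, w: int, patch_size: int = 256) -> List[Tuple[int, int]]:
    """Return the (y, x) top-left corners visited by a non-overlapping crop.

    Hypothetical helper mirroring the loop bounds used in crop_patches:
    only positions where a full patch_size x patch_size tile fits are kept.
    """
    return [
        (y, x)
        for y in range(0, h - patch_size + 1, patch_size)
        for x in range(0, w - patch_size + 1, patch_size)
    ]


# A 1024x1024 image splits into a 4x4 grid of 256x256 patches.
full = patch_origins(1024, 1024)

# A 1000x600 image keeps only 3 rows and 2 columns of patches; the last
# 232 pixel rows and 88 pixel columns are never covered.
ragged = patch_origins(1000, 600)
```

If edge coverage matters for a dataset whose images are not multiples of the patch size, padding before cropping (as `inference.py` does) would be one option.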
evaluate.py
ADDED
|
@@ -0,0 +1,135 @@
|
| 1 |
+
"""Evaluation script for change detection models.
|
| 2 |
+
|
| 3 |
+
Runs a trained model on the test set, computes all metrics, and generates
|
| 4 |
+
visualization outputs.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python evaluate.py --config configs/config.yaml --checkpoint checkpoints/unet_pp_best.pth
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import yaml
|
| 20 |
+
|
| 21 |
+
from data.dataset import ChangeDetectionDataset
|
| 22 |
+
from models import get_model
|
| 23 |
+
from utils.metrics import ConfusionMatrix
|
| 24 |
+
from utils.visualization import plot_prediction
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def evaluate(
|
| 30 |
+
model: nn.Module,
|
| 31 |
+
loader: DataLoader,
|
| 32 |
+
device: torch.device,
|
| 33 |
+
threshold: float = 0.5,
|
| 34 |
+
output_dir: Path = Path("./outputs"),
|
| 35 |
+
max_vis: int = 20,
|
| 36 |
+
) -> Dict[str, float]:
|
| 37 |
+
"""Evaluate model on the full test set.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model: Trained change detection model.
|
| 41 |
+
loader: Test DataLoader.
|
| 42 |
+
device: Target device.
|
| 43 |
+
threshold: Binarization threshold for predictions.
|
| 44 |
+
output_dir: Directory to save visualization outputs.
|
| 45 |
+
max_vis: Maximum number of sample predictions to save.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Dict of metric name -> value.
|
| 49 |
+
"""
|
| 50 |
+
model.eval()
|
| 51 |
+
cm = ConfusionMatrix()
|
| 52 |
+
vis_dir = output_dir / "visualizations"
|
| 53 |
+
vis_dir.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
vis_count = 0
|
| 55 |
+
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
for batch in tqdm(loader, desc="Evaluating"):
|
| 58 |
+
img_a = batch["A"].to(device)
|
| 59 |
+
img_b = batch["B"].to(device)
|
| 60 |
+
mask = batch["mask"].to(device)
|
| 61 |
+
|
| 62 |
+
logits = model(img_a, img_b)
|
| 63 |
+
preds = (torch.sigmoid(logits) > threshold).float()
|
| 64 |
+
cm.update(preds, mask)
|
| 65 |
+
|
| 66 |
+
# Save sample visualizations
|
| 67 |
+
if vis_count < max_vis:
|
| 68 |
+
for i in range(min(img_a.size(0), max_vis - vis_count)):
|
| 69 |
+
plot_prediction(
|
| 70 |
+
img_a[i], img_b[i], mask[i], preds[i],
|
| 71 |
+
save_path=vis_dir / f"sample_{vis_count:04d}.png",
|
| 72 |
+
)
|
| 73 |
+
vis_count += 1
|
| 74 |
+
|
| 75 |
+
metrics = cm.compute()
|
| 76 |
+
return metrics
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
|
| 80 |
+
"""Main evaluation entry point."""
|
| 81 |
+
parser = argparse.ArgumentParser(description="Evaluate change detection model")
|
| 82 |
+
parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
|
| 83 |
+
parser.add_argument("--checkpoint", type=Path, required=True)
|
| 84 |
+
parser.add_argument("--model", type=str, default=None, help="Override model name")
|
| 85 |
+
parser.add_argument("--threshold", type=float, default=None)
|
| 86 |
+
args = parser.parse_args()
|
| 87 |
+
|
| 88 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 89 |
+
|
| 90 |
+
with open(args.config, "r") as f:
|
| 91 |
+
config = yaml.safe_load(f)
|
| 92 |
+
|
| 93 |
+
model_name = args.model or config["model"]["name"]
|
| 94 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 95 |
+
threshold = args.threshold if args.threshold is not None else config.get("evaluation", {}).get("threshold", 0.5)
|
| 96 |
+
|
| 97 |
+
# Resolve paths
|
| 98 |
+
colab = config.get("colab", {})
|
| 99 |
+
if colab.get("enabled", False):
|
| 100 |
+
data_dir = Path(colab["data_dir"])
|
| 101 |
+
output_dir = Path(colab["output_dir"])
|
| 102 |
+
else:
|
| 103 |
+
data_dir = Path(config["paths"]["processed_data"])
|
| 104 |
+
output_dir = Path(config["paths"]["output_dir"])
|
| 105 |
+
|
| 106 |
+
# Model
|
| 107 |
+
model = get_model(model_name, config).to(device)
|
| 108 |
+
ckpt = torch.load(args.checkpoint, map_location=device)
|
| 109 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 110 |
+
logger.info("Loaded checkpoint: %s (epoch %d, F1 %.4f)",
|
| 111 |
+
args.checkpoint, ckpt.get("epoch", -1), ckpt.get("best_f1", -1))
|
| 112 |
+
|
| 113 |
+
# Test data
|
| 114 |
+
ds_cfg = config.get("dataset", {})
|
| 115 |
+
test_ds = ChangeDetectionDataset(data_dir / "test", split="test", config=config)
|
| 116 |
+
test_loader = DataLoader(
|
| 117 |
+
test_ds, batch_size=8, shuffle=False,
|
| 118 |
+
num_workers=ds_cfg.get("num_workers", 4),
|
| 119 |
+
pin_memory=ds_cfg.get("pin_memory", True),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Evaluate
|
| 123 |
+
metrics = evaluate(model, test_loader, device, threshold, output_dir)
|
| 124 |
+
|
| 125 |
+
# Print results
|
| 126 |
+
logger.info("=" * 50)
|
| 127 |
+
logger.info("TEST SET RESULTS — %s", model_name)
|
| 128 |
+
logger.info("=" * 50)
|
| 129 |
+
for name, value in metrics.items():
|
| 130 |
+
logger.info(" %-12s: %.4f", name.upper(), value)
|
| 131 |
+
logger.info("=" * 50)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
|
| 135 |
+
main()
|
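The evaluation loop above accumulates predictions into `ConfusionMatrix` and reads the metrics back with `compute()`. That class lives in `utils/metrics.py`, which is not shown here, so the following is only a sketch of how the usual binary change-detection metrics can be derived from the four confusion counts. The function name `binary_metrics`, the returned keys, and the `eps` smoothing term are assumptions made for illustration.

```python
from typing import Dict


def binary_metrics(tp: int, fp: int, fn: int, tn: int, eps: float = 1e-8) -> Dict[str, float]:
    """Derive common metrics from binary confusion counts.

    Hypothetical helper; eps guards against division by zero when a class
    is absent (for example, a test set with no changed pixels).
    """
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)
    accuracy = (tp + tn) / (tp + fp + fn + tn + eps)
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": iou,
        "accuracy": accuracy,
    }


# Example: 1000 pixels, 100 of which truly changed.
scores = binary_metrics(tp=80, fp=20, fn=20, tn=880)
```

Accuracy is high here largely because unchanged pixels dominate, which is why F1 and IoU on the change class are usually the more informative numbers for this task.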
inference.py
ADDED
|
@@ -0,0 +1,176 @@
|
| 1 |
+
"""Run inference on arbitrary before/after image pairs.
|
| 2 |
+
|
| 3 |
+
Loads a trained change detection model and produces binary change masks
|
| 4 |
+
for new satellite image pairs.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python inference.py --before path/to/before.png --after path/to/after.png \
|
| 8 |
+
--model changeformer --checkpoint checkpoints/changeformer_best.pth
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
|
| 16 |
+
import cv2
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
from data.dataset import IMAGENET_MEAN, IMAGENET_STD
|
| 24 |
+
from models import get_model
|
| 25 |
+
from utils.visualization import overlay_changes, plot_prediction
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def preprocess_image(
|
| 31 |
+
image_path: Path,
|
| 32 |
+
patch_size: int = 256,
|
| 33 |
+
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
| 34 |
+
"""Load and preprocess a single image for inference.
|
| 35 |
+
|
| 36 |
+
Reads the image, pads to a multiple of patch_size, and applies
|
| 37 |
+
ImageNet normalization.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
image_path: Path to the input image.
|
| 41 |
+
patch_size: Patch size the model expects.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Tuple of (preprocessed tensor [1, 3, H, W], original (H, W)).
|
| 45 |
+
"""
|
| 46 |
+
img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
|
| 47 |
+
if img is None:
|
| 48 |
+
raise FileNotFoundError(f"Could not read image: {image_path}")
|
| 49 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 50 |
+
orig_h, orig_w = img.shape[:2]
|
| 51 |
+
|
| 52 |
+
# Pad to multiple of patch_size
|
| 53 |
+
pad_h = (patch_size - orig_h % patch_size) % patch_size
|
| 54 |
+
pad_w = (patch_size - orig_w % patch_size) % patch_size
|
| 55 |
+
if pad_h > 0 or pad_w > 0:
|
| 56 |
+
img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
|
| 57 |
+
|
| 58 |
+
# Normalize
|
| 59 |
+
img = img.astype(np.float32) / 255.0
|
| 60 |
+
mean = np.array(IMAGENET_MEAN, dtype=np.float32)
|
| 61 |
+
std = np.array(IMAGENET_STD, dtype=np.float32)
|
| 62 |
+
img = (img - mean) / std
|
| 63 |
+
|
| 64 |
+
# HWC -> CHW, add batch dim
|
| 65 |
+
tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()
|
| 66 |
+
return tensor, (orig_h, orig_w)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def sliding_window_inference(
|
| 70 |
+
model: nn.Module,
|
| 71 |
+
img_a: torch.Tensor,
|
| 72 |
+
img_b: torch.Tensor,
|
| 73 |
+
patch_size: int = 256,
|
| 74 |
+
device: torch.device = torch.device("cpu"),
|
| 75 |
+
) -> torch.Tensor:
|
| 76 |
+
"""Run inference using sliding window for large images.
|
| 77 |
+
|
| 78 |
+
Splits images into non-overlapping patches, runs model on each,
|
| 79 |
+
and stitches results back together.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
model: Trained change detection model.
|
| 83 |
+
img_a: Before image tensor [1, 3, H, W].
|
| 84 |
+
img_b: After image tensor [1, 3, H, W].
|
| 85 |
+
patch_size: Size of each patch.
|
| 86 |
+
device: Inference device.
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Probability map [1, 1, H, W] (after sigmoid).
|
| 90 |
+
"""
|
| 91 |
+
_, _, h, w = img_a.shape
|
| 92 |
+
output = torch.zeros(1, 1, h, w, device="cpu")
|
| 93 |
+
|
| 94 |
+
model.eval()
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
for y in range(0, h, patch_size):
|
| 97 |
+
for x in range(0, w, patch_size):
|
| 98 |
+
patch_a = img_a[:, :, y:y + patch_size, x:x + patch_size].to(device)
|
| 99 |
+
patch_b = img_b[:, :, y:y + patch_size, x:x + patch_size].to(device)
|
| 100 |
+
|
| 101 |
+
logits = model(patch_a, patch_b)
|
| 102 |
+
probs = torch.sigmoid(logits).cpu()
|
| 103 |
+
output[:, :, y:y + patch_size, x:x + patch_size] = probs
|
| 104 |
+
|
| 105 |
+
return output
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def save_change_mask(
|
| 109 |
+
mask: np.ndarray,
|
| 110 |
+
save_path: Path,
|
| 111 |
+
threshold: float = 0.5,
|
| 112 |
+
) -> None:
|
| 113 |
+
"""Save binary change mask as an image.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
mask: Probability map [H, W] with values in [0, 1].
|
| 117 |
+
save_path: Output file path.
|
| 118 |
+
threshold: Binarization threshold.
|
| 119 |
+
"""
|
| 120 |
+
binary = (mask > threshold).astype(np.uint8) * 255
|
| 121 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
| 122 |
+
cv2.imwrite(str(save_path), binary)
|
| 123 |
+
logger.info("Saved change mask: %s", save_path)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main() -> None:
|
| 127 |
+
"""Main inference entry point."""
|
| 128 |
+
parser = argparse.ArgumentParser(description="Run change detection inference")
|
| 129 |
+
parser.add_argument("--before", type=Path, required=True, help="Path to before image")
|
| 130 |
+
parser.add_argument("--after", type=Path, required=True, help="Path to after image")
|
| 131 |
+
parser.add_argument("--model", type=str, default=None, help="Model name")
|
| 132 |
+
parser.add_argument("--checkpoint", type=Path, required=True, help="Path to model checkpoint")
|
| 133 |
+
parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
|
| 134 |
+
parser.add_argument("--output", type=Path, default=Path("outputs/inference"))
|
| 135 |
+
parser.add_argument("--threshold", type=float, default=0.5)
|
| 136 |
+
args = parser.parse_args()
|
| 137 |
+
|
| 138 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 139 |
+
|
| 140 |
+
# Load config
|
| 141 |
+
with open(args.config, "r") as f:
|
| 142 |
+
config = yaml.safe_load(f)
|
| 143 |
+
|
| 144 |
+
model_name = args.model or config["model"]["name"]
|
| 145 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 146 |
+
patch_size = config.get("dataset", {}).get("patch_size", 256)
|
| 147 |
+
|
| 148 |
+
# Load model
|
| 149 |
+
model = get_model(model_name, config).to(device)
|
| 150 |
+
ckpt = torch.load(args.checkpoint, map_location=device)
|
| 151 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 152 |
+
logger.info("Loaded model '%s' from %s", model_name, args.checkpoint)
|
| 153 |
+
|
| 154 |
+
# Preprocess images
|
| 155 |
+
img_a, (orig_h, orig_w) = preprocess_image(args.before, patch_size)
|
| 156 |
+
img_b, _ = preprocess_image(args.after, patch_size)
|
| 157 |
+
|
| 158 |
+
# Run inference
|
| 159 |
+
prob_map = sliding_window_inference(model, img_a, img_b, patch_size, device)
|
| 160 |
+
|
| 161 |
+
# Crop back to original size and save
|
| 162 |
+
prob_map = prob_map[:, :, :orig_h, :orig_w]
|
| 163 |
+
mask_np = prob_map.squeeze().numpy()
|
| 164 |
+
|
| 165 |
+
args.output.mkdir(parents=True, exist_ok=True)
|
| 166 |
+
save_change_mask(mask_np, args.output / "change_mask.png", args.threshold)
|
| 167 |
+
|
| 168 |
+
# Save overlay visualization
|
| 169 |
+
overlay = overlay_changes(img_b.squeeze()[:, :orig_h, :orig_w], prob_map.squeeze(0))
|
| 170 |
+
overlay_uint8 = (overlay * 255).astype(np.uint8)
|
| 171 |
+
cv2.imwrite(str(args.output / "overlay.png"), cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR))
|
| 172 |
+
logger.info("Saved overlay: %s", args.output / "overlay.png")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
if __name__ == "__main__":
|
| 176 |
+
main()
|
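The padding step in `preprocess_image` uses `(patch_size - size % patch_size) % patch_size`. The outer modulo is what keeps already-aligned sizes from receiving a full extra patch of padding. The sketch below isolates that arithmetic; `pad_amount` and `num_windows` are hypothetical names used only for this illustration.

```python
def pad_amount(size: int, patch_size: int = 256) -> int:
    """Pixels to append so that size becomes a multiple of patch_size."""
    return (patch_size - size % patch_size) % patch_size


def num_windows(size: int, patch_size: int = 256) -> int:
    """Number of non-overlapping windows along one axis after padding."""
    return (size + pad_amount(size, patch_size)) // patch_size


# 1000 px needs 24 px of padding to reach 1024, giving 4 windows.
pad_1000 = pad_amount(1000)
windows_1000 = num_windows(1000)

# 1024 px is already aligned, so no padding is added.
pad_1024 = pad_amount(1024)
```

Because every padded dimension is an exact multiple of the patch size, each window visited by `sliding_window_inference` is a full patch, and the final crop to `(orig_h, orig_w)` discards the padded border.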
models/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
| 1 |
+
"""Model factory for change detection models.
|
| 2 |
+
|
| 3 |
+
Provides a unified interface to instantiate any supported model by name.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Any, Dict
|
| 7 |
+
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from .changeformer import ChangeFormer
|
| 11 |
+
from .siamese_cnn import SiameseCNN
|
| 12 |
+
from .unet_pp import UNetPPChangeDetection
|
| 13 |
+
|
| 14 |
+
_MODEL_REGISTRY: Dict[str, type] = {
|
| 15 |
+
"siamese_cnn": SiameseCNN,
|
| 16 |
+
"unet_pp": UNetPPChangeDetection,
|
| 17 |
+
"changeformer": ChangeFormer,
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_model(model_name: str, config: Dict[str, Any]) -> nn.Module:
|
| 22 |
+
"""Instantiate a change detection model by name.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
model_name: One of 'siamese_cnn', 'unet_pp', 'changeformer'.
|
| 26 |
+
config: Full config dict; model-specific section is extracted internally.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Initialized model (nn.Module).
|
| 30 |
+
|
| 31 |
+
Raises:
|
| 32 |
+
ValueError: If model_name is not recognized.
|
| 33 |
+
"""
|
| 34 |
+
if model_name not in _MODEL_REGISTRY:
|
| 35 |
+
raise ValueError(f"Unknown model '{model_name}'. Choose from: {list(_MODEL_REGISTRY.keys())}")
|
| 36 |
+
|
| 37 |
+
model_cls = _MODEL_REGISTRY[model_name]
|
| 38 |
+
model_config = config.get(model_name, {})
|
| 39 |
+
return model_cls(**model_config)
|
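`get_model` looks up a class by name and unpacks the config section with the same name as keyword arguments. The following sketch reproduces that pattern with a stand-in class so it runs without PyTorch. `TinyModel`, `build`, and the config layout shown are assumptions for illustration; the real keys depend on `configs/config.yaml`.

```python
from typing import Any, Dict


class TinyModel:
    """Stand-in for a real model class (hypothetical)."""

    def __init__(self, width: int = 8) -> None:
        self.width = width


_REGISTRY: Dict[str, type] = {"tiny": TinyModel}


def build(name: str, config: Dict[str, Any]) -> Any:
    """Instantiate a registered class using config[name] as kwargs."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown model '{name}'. Choose from: {list(_REGISTRY)}")
    # A missing section falls back to the class defaults.
    return _REGISTRY[name](**config.get(name, {}))


configured = build("tiny", {"tiny": {"width": 16}})
defaulted = build("tiny", {})
```

One consequence of unpacking with `**` is that a config key which does not match a constructor parameter raises `TypeError` at build time, so config sections and constructor signatures need to stay in sync.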
models/changeformer.py
ADDED
|
@@ -0,0 +1,358 @@
|
| 1 |
+
"""ChangeFormer — Transformer-based change detection model.
|
| 2 |
+
|
| 3 |
+
Implements a hierarchical vision transformer (MiT-B1 style) with shared-weight
|
| 4 |
+
Siamese encoder and MLP decoder for change detection. Based on:
|
| 5 |
+
"A Transformer-Based Siamese Network for Change Detection" (arXiv:2201.01293).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class OverlapPatchEmbed(nn.Module):
|
| 17 |
+
"""Overlapping patch embedding for hierarchical feature extraction.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
in_channels: Number of input channels.
|
| 21 |
+
embed_dim: Embedding dimension.
|
| 22 |
+
patch_size: Patch size for convolution.
|
| 23 |
+
stride: Stride for convolution.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
in_channels: int = 3,
|
| 29 |
+
embed_dim: int = 64,
|
| 30 |
+
patch_size: int = 7,
|
| 31 |
+
stride: int = 4,
|
| 32 |
+
) -> None:
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.proj = nn.Conv2d(
|
| 35 |
+
in_channels, embed_dim,
|
| 36 |
+
kernel_size=patch_size, stride=stride,
|
| 37 |
+
padding=patch_size // 2,
|
| 38 |
+
)
|
| 39 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 40 |
+
|
| 41 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
|
| 42 |
+
"""Forward pass.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
x: Input tensor [B, C, H, W].
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Tuple of (tokens [B, N, D], height, width).
|
| 49 |
+
"""
|
| 50 |
+
x = self.proj(x)
|
| 51 |
+
_, _, h, w = x.shape
|
| 52 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
| 53 |
+
x = self.norm(x)
|
| 54 |
+
return x, h, w
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class EfficientSelfAttention(nn.Module):
|
| 58 |
+
"""Efficient self-attention with spatial reduction.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
dim: Input dimension.
|
| 62 |
+
num_heads: Number of attention heads.
|
| 63 |
+
sr_ratio: Spatial reduction ratio.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, dim: int, num_heads: int = 1, sr_ratio: int = 8) -> None:
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.num_heads = num_heads
|
| 69 |
+
self.head_dim = dim // num_heads
|
| 70 |
+
self.scale = self.head_dim ** -0.5
|
| 71 |
+
|
| 72 |
+
self.q = nn.Linear(dim, dim)
|
| 73 |
+
self.kv = nn.Linear(dim, dim * 2)
|
| 74 |
+
self.proj = nn.Linear(dim, dim)
|
| 75 |
+
|
| 76 |
+
# Spatial reduction
|
| 77 |
+
self.sr_ratio = sr_ratio
|
| 78 |
+
if sr_ratio > 1:
|
| 79 |
+
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
| 80 |
+
self.sr_norm = nn.LayerNorm(dim)
|
| 81 |
+
|
| 82 |
+
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
| 83 |
+
"""Forward pass.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
x: Input tokens [B, N, C].
|
| 87 |
+
h: Feature map height.
|
| 88 |
+
w: Feature map width.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Output tokens [B, N, C].
|
| 92 |
+
"""
|
| 93 |
+
b, n, c = x.shape
|
| 94 |
+
q = self.q(x).reshape(b, n, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 95 |
+
|
| 96 |
+
if self.sr_ratio > 1:
|
| 97 |
+
x_ = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
| 98 |
+
x_ = self.sr(x_)
|
| 99 |
+
x_ = rearrange(x_, "b c h w -> b (h w) c")
|
| 100 |
+
x_ = self.sr_norm(x_)
|
| 101 |
+
else:
|
| 102 |
+
x_ = x
|
| 103 |
+
|
| 104 |
+
kv = self.kv(x_).reshape(b, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 105 |
+
k, v = kv[0], kv[1]
|
| 106 |
+
|
| 107 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 108 |
+
attn = attn.softmax(dim=-1)
|
| 109 |
+
out = (attn @ v).transpose(1, 2).reshape(b, n, c)
|
| 110 |
+
out = self.proj(out)
|
| 111 |
+
return out
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class MixFFN(nn.Module):
|
| 115 |
+
"""Mix Feed-Forward Network with depthwise convolution.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
dim: Input/output dimension.
|
| 119 |
+
mlp_ratio: Expansion ratio for hidden dimension.
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def __init__(self, dim: int, mlp_ratio: int = 4) -> None:
|
| 123 |
+
super().__init__()
|
| 124 |
+
hidden = dim * mlp_ratio
|
| 125 |
+
self.fc1 = nn.Linear(dim, hidden)
|
| 126 |
+
self.dwconv = nn.Conv2d(hidden, hidden, 3, 1, 1, groups=hidden)
|
| 127 |
+
self.fc2 = nn.Linear(hidden, dim)
|
| 128 |
+
self.act = nn.GELU()
|
| 129 |
+
|
| 130 |
+
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
| 131 |
+
"""Forward pass.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
x: Input tokens [B, N, C].
|
| 135 |
+
h: Feature map height.
|
| 136 |
+
w: Feature map width.
|
| 137 |
+
|
| 138 |
+
Returns:
|
| 139 |
+
Output tokens [B, N, C].
|
| 140 |
+
"""
|
| 141 |
+
x = self.fc1(x)
|
| 142 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
| 143 |
+
x = self.act(self.dwconv(x))
|
| 144 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
| 145 |
+
x = self.fc2(x)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class TransformerBlock(nn.Module):
|
| 150 |
+
"""Single transformer block with efficient attention and MixFFN.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
dim: Feature dimension.
|
| 154 |
+
num_heads: Number of attention heads.
|
| 155 |
+
mlp_ratio: MLP expansion ratio.
|
| 156 |
+
sr_ratio: Spatial reduction ratio for attention.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
dim: int,
|
| 162 |
+
num_heads: int = 1,
|
| 163 |
+
mlp_ratio: int = 4,
|
| 164 |
+
sr_ratio: int = 8,
|
| 165 |
+
) -> None:
|
| 166 |
+
super().__init__()
|
| 167 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 168 |
+
self.attn = EfficientSelfAttention(dim, num_heads, sr_ratio)
|
| 169 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 170 |
+
self.ffn = MixFFN(dim, mlp_ratio)
|
| 171 |
+
|
| 172 |
+
def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
| 173 |
+
"""Forward pass.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
x: Input tokens [B, N, C].
|
| 177 |
+
h: Feature map height.
|
| 178 |
+
w: Feature map width.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
Output tokens [B, N, C].
|
| 182 |
+
"""
|
| 183 |
+
x = x + self.attn(self.norm1(x), h, w)
|
| 184 |
+
x = x + self.ffn(self.norm2(x), h, w)
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class MiTEncoder(nn.Module):
|
| 189 |
+
"""Mix Transformer (MiT) encoder — hierarchical vision transformer.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
embed_dims: Embedding dimensions at each stage.
|
| 193 |
+
num_heads: Number of attention heads at each stage.
|
| 194 |
+
mlp_ratios: MLP expansion ratios at each stage.
|
| 195 |
+
depths: Number of transformer blocks at each stage.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(
|
| 199 |
+
self,
|
| 200 |
+
embed_dims: List[int] = [64, 128, 320, 512],
|
| 201 |
+
num_heads: List[int] = [1, 2, 5, 8],
|
| 202 |
+
mlp_ratios: List[int] = [8, 8, 4, 4],
|
| 203 |
+
depths: List[int] = [2, 2, 2, 2],
|
| 204 |
+
) -> None:
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.num_stages = len(embed_dims)
|
| 207 |
+
|
| 208 |
+
sr_ratios = [8, 4, 2, 1]
|
| 209 |
+
patch_sizes = [7, 3, 3, 3]
|
| 210 |
+
strides = [4, 2, 2, 2]
|
| 211 |
+
|
| 212 |
+
self.patch_embeds = nn.ModuleList()
|
| 213 |
+
self.blocks = nn.ModuleList()
|
| 214 |
+
self.norms = nn.ModuleList()
|
| 215 |
+
|
| 216 |
+
for i in range(self.num_stages):
|
| 217 |
+
in_ch = 3 if i == 0 else embed_dims[i - 1]
|
| 218 |
+
self.patch_embeds.append(
|
| 219 |
+
OverlapPatchEmbed(in_ch, embed_dims[i], patch_sizes[i], strides[i])
|
| 220 |
+
)
|
| 221 |
+
self.blocks.append(
|
| 222 |
+
nn.ModuleList([
|
| 223 |
+
TransformerBlock(embed_dims[i], num_heads[i], mlp_ratios[i], sr_ratios[i])
|
| 224 |
+
for _ in range(depths[i])
|
| 225 |
+
])
|
| 226 |
+
)
|
| 227 |
+
self.norms.append(nn.LayerNorm(embed_dims[i]))
|
| 228 |
+
|
| 229 |
+
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
| 230 |
+
"""Extract hierarchical features.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
x: Input image [B, 3, H, W].
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
List of feature maps at each stage [B, C_i, H_i, W_i].
|
| 237 |
+
"""
|
| 238 |
+
features = []
|
| 239 |
+
for i in range(self.num_stages):
|
| 240 |
+
x, h, w = self.patch_embeds[i](x)
|
| 241 |
+
for blk in self.blocks[i]:
|
| 242 |
+
x = blk(x, h, w)
|
| 243 |
+
x = self.norms[i](x)
|
| 244 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
| 245 |
+
features.append(x)
|
| 246 |
+
return features
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class MLPDecoder(nn.Module):
|
| 250 |
+
"""MLP-based decoder that fuses multi-scale difference features.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
embed_dims: Embedding dimensions from each encoder stage.
|
| 254 |
+
out_channels: Number of output channels (1 for binary change mask).
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
def __init__(
|
| 258 |
+
self,
|
| 259 |
+
embed_dims: List[int] = [64, 128, 320, 512],
|
| 260 |
+
out_channels: int = 1,
|
| 261 |
+
) -> None:
|
| 262 |
+
super().__init__()
|
| 263 |
+
unified_dim = embed_dims[0]
|
| 264 |
+
|
| 265 |
+
self.linear_projections = nn.ModuleList([
|
| 266 |
+
nn.Conv2d(dim, unified_dim, kernel_size=1)
|
| 267 |
+
for dim in embed_dims
|
| 268 |
+
])
|
| 269 |
+
|
| 270 |
+
self.fuse = nn.Sequential(
|
| 271 |
+
nn.Conv2d(unified_dim * len(embed_dims), unified_dim, kernel_size=1),
|
| 272 |
+
nn.BatchNorm2d(unified_dim),
|
| 273 |
+
nn.ReLU(inplace=True),
|
| 274 |
+
)
|
| 275 |
+
self.head = nn.Conv2d(unified_dim, out_channels, kernel_size=1)
|
| 276 |
+
|
| 277 |
+
def forward(self, features: List[torch.Tensor], target_size: Tuple[int, int]) -> torch.Tensor:
|
| 278 |
+
"""Forward pass.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
features: List of difference feature maps.
|
| 282 |
+
target_size: (H, W) of the desired output.
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
Logits [B, 1, H, W].
|
| 286 |
+
"""
|
| 287 |
+
projected = []
|
| 288 |
+
for feat, proj in zip(features, self.linear_projections):
|
| 289 |
+
p = proj(feat)
|
| 290 |
+
p = F.interpolate(p, size=target_size, mode="bilinear", align_corners=False)
|
| 291 |
+
projected.append(p)
|
| 292 |
+
|
| 293 |
+
fused = self.fuse(torch.cat(projected, dim=1))
|
| 294 |
+
out = self.head(fused)
|
| 295 |
+
return out
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class ChangeFormer(nn.Module):
|
| 299 |
+
"""ChangeFormer: Transformer-based Siamese network for change detection.
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
embed_dims: Embedding dims at each hierarchical stage.
|
| 303 |
+
num_heads: Attention heads at each stage.
|
| 304 |
+
mlp_ratios: MLP expansion ratios at each stage.
|
| 305 |
+
depths: Transformer block counts at each stage.
|
| 306 |
+
pretrained_backbone: Whether to load pretrained MiT weights (not yet implemented; the flag is currently ignored).
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
def __init__(
|
| 310 |
+
self,
|
| 311 |
+
embed_dims: List[int] = [64, 128, 320, 512],
|
| 312 |
+
num_heads: List[int] = [1, 2, 5, 8],
|
| 313 |
+
mlp_ratios: List[int] = [8, 8, 4, 4],
|
| 314 |
+
depths: List[int] = [2, 2, 2, 2],
|
| 315 |
+
pretrained_backbone: bool = True,
|
| 316 |
+
) -> None:
|
| 317 |
+
super().__init__()
|
| 318 |
+
|
| 319 |
+
# Shared Siamese encoder
|
| 320 |
+
self.encoder = MiTEncoder(embed_dims, num_heads, mlp_ratios, depths)
|
| 321 |
+
|
| 322 |
+
# MLP decoder
|
| 323 |
+
self.decoder = MLPDecoder(embed_dims, out_channels=1)
|
| 324 |
+
|
| 325 |
+
# TODO: Load pretrained MiT-B1 weights if pretrained_backbone is True
|
| 326 |
+
|
| 327 |
+
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
| 328 |
+
"""Forward pass.
|
| 329 |
+
|
| 330 |
+
Args:
|
| 331 |
+
x1: Before image [B, 3, 256, 256].
|
| 332 |
+
x2: After image [B, 3, 256, 256].
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Raw logits [B, 1, 256, 256].
|
| 336 |
+
"""
|
| 337 |
+
# Extract hierarchical features
|
| 338 |
+
feats_1 = self.encoder(x1)
|
| 339 |
+
feats_2 = self.encoder(x2)
|
| 340 |
+
|
| 341 |
+
# Compute difference at each scale
|
| 342 |
+
diff_feats = [torch.abs(f1 - f2) for f1, f2 in zip(feats_1, feats_2)]
|
| 343 |
+
|
| 344 |
+
# Decode to change mask
|
| 345 |
+
target_size = (x1.shape[2], x1.shape[3])
|
| 346 |
+
out = self.decoder(diff_feats, target_size)
|
| 347 |
+
return out
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if __name__ == "__main__":
|
| 351 |
+
# Quick sanity check
|
| 352 |
+
model = ChangeFormer(pretrained_backbone=False)
|
| 353 |
+
x1 = torch.randn(1, 3, 256, 256)
|
| 354 |
+
x2 = torch.randn(1, 3, 256, 256)
|
| 355 |
+
out = model(x1, x2)
|
| 356 |
+
print(f"Input: {x1.shape}, Output: {out.shape}")
|
| 357 |
+
assert out.shape == (1, 1, 256, 256), f"Unexpected shape: {out.shape}"
|
| 358 |
+
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
|
models/siamese_cnn.py
ADDED
|
@@ -0,0 +1,85 @@
|
| 1 |
+
"""Siamese CNN baseline for change detection.
|
| 2 |
+
|
| 3 |
+
Architecture: Shared-weight ResNet18 backbone extracts features from both
|
| 4 |
+
images. Feature difference is decoded via transposed convolutions to produce
|
| 5 |
+
a binary change mask.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torchvision.models as models
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SiameseCNN(nn.Module):
|
| 14 |
+
"""Siamese CNN with shared ResNet18 encoder and transposed-conv decoder.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
backbone: Backbone architecture name (default: 'resnet18').
|
| 18 |
+
pretrained: Whether to use ImageNet-pretrained weights.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, backbone: str = "resnet18", pretrained: bool = True) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
|
| 24 |
+
# Shared encoder
|
| 25 |
+
resnet = getattr(models, backbone)(
|
| 26 |
+
weights="DEFAULT" if pretrained else None  # resolves per backbone; decoder still assumes 512 output channels
|
| 27 |
+
)
|
| 28 |
+
# Remove avgpool and fc — keep feature extraction layers
|
| 29 |
+
self.encoder = nn.Sequential(
|
| 30 |
+
resnet.conv1,
|
| 31 |
+
resnet.bn1,
|
| 32 |
+
resnet.relu,
|
| 33 |
+
resnet.maxpool,
|
| 34 |
+
resnet.layer1, # 64 channels
|
| 35 |
+
resnet.layer2, # 128 channels
|
| 36 |
+
resnet.layer3, # 256 channels
|
| 37 |
+
resnet.layer4, # 512 channels
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Decoder: upsample difference features back to input resolution
|
| 41 |
+
self.decoder = nn.Sequential(
|
| 42 |
+
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
|
| 43 |
+
nn.BatchNorm2d(256),
|
| 44 |
+
nn.ReLU(inplace=True),
|
| 45 |
+
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
|
| 46 |
+
nn.BatchNorm2d(128),
|
| 47 |
+
nn.ReLU(inplace=True),
|
| 48 |
+
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
|
| 49 |
+
nn.BatchNorm2d(64),
|
| 50 |
+
nn.ReLU(inplace=True),
|
| 51 |
+
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
|
| 52 |
+
nn.BatchNorm2d(32),
|
| 53 |
+
nn.ReLU(inplace=True),
|
| 54 |
+
nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
"""Forward pass.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
x1: Before image [B, 3, 256, 256].
|
| 62 |
+
x2: After image [B, 3, 256, 256].
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Raw logits [B, 1, 256, 256].
|
| 66 |
+
"""
|
| 67 |
+
f1 = self.encoder(x1)
|
| 68 |
+
f2 = self.encoder(x2)
|
| 69 |
+
|
| 70 |
+
# Feature difference
|
| 71 |
+
diff = torch.abs(f1 - f2)
|
| 72 |
+
|
| 73 |
+
# Decode to change mask
|
| 74 |
+
out = self.decoder(diff)
|
| 75 |
+
return out
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
# Quick sanity check
|
| 80 |
+
model = SiameseCNN(pretrained=False)
|
| 81 |
+
x1 = torch.randn(2, 3, 256, 256)
|
| 82 |
+
x2 = torch.randn(2, 3, 256, 256)
|
| 83 |
+
out = model(x1, x2)
|
| 84 |
+
print(f"Input: {x1.shape}, Output: {out.shape}")
|
| 85 |
+
assert out.shape == (2, 1, 256, 256), f"Unexpected shape: {out.shape}"
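The shape assertion above holds because each of the decoder's five transposed convolutions (kernel_size=4, stride=2, padding=1) exactly doubles the spatial size. A minimal arithmetic sketch, assuming the standard ResNet18 total stride of 32 (so a 256-pixel input reaches the decoder at 8 pixels) and the PyTorch `ConvTranspose2d` output-size formula with dilation 1 and no output padding. The helper name `conv_transpose_out` is hypothetical and not part of the repository:

```python
def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Output length of ConvTranspose2d along one axis (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# Assumption: ResNet18 through layer4 downsamples by a factor of 32, so 256 -> 8.
size = 256 // 32
sizes = [size]
for _ in range(5):  # five ConvTranspose2d layers in the decoder
    size = conv_transpose_out(size)
    sizes.append(size)

print(sizes)  # [8, 16, 32, 64, 128, 256]
```

With other input sizes the same arithmetic applies, but inputs that are not multiples of 32 would likely not round-trip to the original resolution.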
|
models/unet_pp.py
ADDED
|
@@ -0,0 +1,78 @@
|
| 1 |
+
"""UNet++ (Nested U-Net) for change detection.
|
| 2 |
+
|
| 3 |
+
Uses a shared ResNet34 encoder from segmentation-models-pytorch. Features from
|
| 4 |
+
both temporal images are differenced and decoded through nested skip connections.
|
| 5 |
+
A deep-supervision flag is accepted but is not yet used in the forward pass.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import segmentation_models_pytorch as smp
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class UNetPPChangeDetection(nn.Module):
|
| 14 |
+
"""UNet++ adapted for bitemporal change detection.
|
| 15 |
+
|
| 16 |
+
A shared encoder processes both images. The absolute difference of
|
| 17 |
+
encoder features is fed into the UNet++ decoder.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
encoder_name: Encoder backbone (default: 'resnet34').
|
| 21 |
+
pretrained: Use ImageNet-pretrained encoder weights.
|
| 22 |
+
deep_supervision: Enable deep supervision outputs (stored, but not yet used in forward).
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
encoder_name: str = "resnet34",
|
| 28 |
+
pretrained: bool = True,
|
| 29 |
+
deep_supervision: bool = False,
|
| 30 |
+
) -> None:
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.deep_supervision = deep_supervision
|
| 33 |
+
|
| 34 |
+
# Shared encoder via SMP
|
| 35 |
+
encoder_weights = "imagenet" if pretrained else None
|
| 36 |
+
self.base_model = smp.UnetPlusPlus(
|
| 37 |
+
encoder_name=encoder_name,
|
| 38 |
+
encoder_weights=encoder_weights,
|
| 39 |
+
in_channels=3,
|
| 40 |
+
classes=1,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# We'll use the encoder and decoder separately
|
| 44 |
+
self.encoder = self.base_model.encoder
|
| 45 |
+
self.decoder = self.base_model.decoder
|
| 46 |
+
self.segmentation_head = self.base_model.segmentation_head
|
| 47 |
+
|
| 48 |
+
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
|
| 49 |
+
"""Forward pass.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
x1: Before image [B, 3, 256, 256].
|
| 53 |
+
x2: After image [B, 3, 256, 256].
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Raw logits [B, 1, 256, 256].
|
| 57 |
+
"""
|
| 58 |
+
# Extract multi-scale features from both images
|
| 59 |
+
features_1 = self.encoder(x1)
|
| 60 |
+
features_2 = self.encoder(x2)
|
| 61 |
+
|
| 62 |
+
# Compute absolute difference at each scale
|
| 63 |
+
diff_features = [torch.abs(f1 - f2) for f1, f2 in zip(features_1, features_2)]
|
| 64 |
+
|
| 65 |
+
# Decode
|
| 66 |
+
decoder_output = self.decoder(*diff_features)
|
| 67 |
+
out = self.segmentation_head(decoder_output)
|
| 68 |
+
return out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
# Quick sanity check
|
| 73 |
+
model = UNetPPChangeDetection(pretrained=False)
|
| 74 |
+
x1 = torch.randn(2, 3, 256, 256)
|
| 75 |
+
x2 = torch.randn(2, 3, 256, 256)
|
| 76 |
+
out = model(x1, x2)
|
| 77 |
+
print(f"Input: {x1.shape}, Output: {out.shape}")
|
| 78 |
+
assert out.shape == (2, 1, 256, 256), f"Unexpected shape: {out.shape}"
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
| 1 |
+
torch==2.1.2
|
| 2 |
+
torchvision==0.16.2
|
| 3 |
+
segmentation-models-pytorch==0.3.3
|
| 4 |
+
timm==0.9.12
|
| 5 |
+
einops==0.7.0
|
| 6 |
+
albumentations==1.3.1
|
| 7 |
+
opencv-python-headless==4.9.0.80
|
| 8 |
+
scikit-learn==1.4.0
|
| 9 |
+
matplotlib==3.8.2
|
| 10 |
+
numpy==1.26.3
|
| 11 |
+
Pillow==10.2.0
|
| 12 |
+
PyYAML==6.0.1
|
| 13 |
+
tqdm==4.66.1
|
| 14 |
+
tensorboard==2.15.1
|
| 15 |
+
gradio==4.14.0
|
| 16 |
+
gdown==5.1.0
|
setup_colab.py
ADDED
|
@@ -0,0 +1,172 @@
|
| 1 |
+
"""Google Colab setup script.
|
| 2 |
+
|
| 3 |
+
Handles Drive mounting, GPU verification, dependency installation,
|
| 4 |
+
and path configuration. Run this at the start of every Colab session.
|
| 5 |
+
|
| 6 |
+
Usage (in Colab cell):
|
| 7 |
+
!python setup_colab.py
|
| 8 |
+
# Or import and call:
|
| 9 |
+
from setup_colab import setup
|
| 10 |
+
setup()
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
import subprocess
|
| 16 |
+
import sys
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Dict, Optional
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def mount_drive() -> None:
|
| 24 |
+
"""Mount Google Drive at /content/drive.
|
| 25 |
+
|
| 26 |
+
Skips if not running in Colab or already mounted.
|
| 27 |
+
"""
|
| 28 |
+
if not is_colab():
|
| 29 |
+
logger.info("Not running in Colab — skipping Drive mount.")
|
| 30 |
+
return
|
| 31 |
+
|
| 32 |
+
if Path("/content/drive/MyDrive").exists():
|
| 33 |
+
logger.info("Google Drive already mounted.")
|
| 34 |
+
return
|
| 35 |
+
|
| 36 |
+
from google.colab import drive
|
| 37 |
+
drive.mount("/content/drive")
|
| 38 |
+
logger.info("Google Drive mounted successfully.")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_colab() -> bool:
|
| 42 |
+
"""Check if running inside Google Colab.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
True if running in Colab environment.
|
| 46 |
+
"""
|
| 47 |
+
try:
|
| 48 |
+
import google.colab # noqa: F401
|
| 49 |
+
return True
|
| 50 |
+
except ImportError:
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def check_gpu() -> Optional[str]:
|
| 55 |
+
"""Check GPU availability and print device info.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
GPU name string, or None if no GPU available.
|
| 59 |
+
"""
|
| 60 |
+
import torch
|
| 61 |
+
|
| 62 |
+
if not torch.cuda.is_available():
|
| 63 |
+
logger.warning("No GPU detected! Training will be very slow.")
|
| 64 |
+
return None
|
| 65 |
+
|
| 66 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 67 |
+
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 68 |
+
logger.info("GPU: %s (%.1f GB VRAM)", gpu_name, vram_gb)
|
| 69 |
+
return gpu_name
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def detect_gpu_type() -> str:
|
| 73 |
+
"""Detect GPU type for batch size selection.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
One of 'T4', 'V100', or 'default'.
|
| 77 |
+
"""
|
| 78 |
+
import torch
|
| 79 |
+
|
| 80 |
+
if not torch.cuda.is_available():
|
| 81 |
+
return "default"
|
| 82 |
+
|
| 83 |
+
name = torch.cuda.get_device_name(0).upper()
|
| 84 |
+
if "T4" in name:
|
| 85 |
+
return "T4"
|
| 86 |
+
elif "V100" in name:
|
| 87 |
+
return "V100"
|
| 88 |
+
return "default"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def install_requirements() -> None:
|
| 92 |
+
"""Install project dependencies from requirements.txt."""
|
| 93 |
+
req_path = Path("requirements.txt")
|
| 94 |
+
if not req_path.exists():
|
| 95 |
+
logger.warning("requirements.txt not found in %s", Path.cwd())
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
logger.info("Installing dependencies...")
|
| 99 |
+
subprocess.check_call([
|
| 100 |
+
sys.executable, "-m", "pip", "install", "-q", "-r", str(req_path)
|
| 101 |
+
])
|
| 102 |
+
logger.info("Dependencies installed.")
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def create_drive_dirs(drive_root: str = "/content/drive/MyDrive/change-detection") -> Dict[str, Path]:
|
| 106 |
+
"""Create project directories on Google Drive.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
drive_root: Root directory on Drive for this project.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Dict mapping directory names to their paths.
|
| 113 |
+
"""
|
| 114 |
+
dirs = {
|
| 115 |
+
"root": Path(drive_root),
|
| 116 |
+
"checkpoints": Path(drive_root) / "checkpoints",
|
| 117 |
+
"logs": Path(drive_root) / "logs",
|
| 118 |
+
"outputs": Path(drive_root) / "outputs",
|
| 119 |
+
"data": Path(drive_root) / "processed_data",
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
for name, path in dirs.items():
|
| 123 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 124 |
+
logger.info(" %s: %s", name, path)
|
| 125 |
+
|
| 126 |
+
return dirs
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def setup(
|
| 130 |
+
drive_root: str = "/content/drive/MyDrive/change-detection",
|
| 131 |
+
install_deps: bool = True,
|
| 132 |
+
) -> Dict[str, Path]:
|
| 133 |
+
"""Run full Colab setup.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
drive_root: Root directory on Google Drive.
|
| 137 |
+
install_deps: Whether to install pip dependencies.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
Dict of project directory paths.
|
| 141 |
+
"""
|
| 142 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 143 |
+
|
| 144 |
+
logger.info("=" * 60)
|
| 145 |
+
logger.info("Military Base Change Detection — Colab Setup")
|
| 146 |
+
logger.info("=" * 60)
|
| 147 |
+
|
| 148 |
+
# 1. Mount Drive
|
| 149 |
+
mount_drive()
|
| 150 |
+
|
| 151 |
+
# 2. Check GPU
|
| 152 |
+
gpu_name = check_gpu()
|
| 153 |
+
gpu_type = detect_gpu_type()
|
| 154 |
+
logger.info("GPU type for batch sizing: %s", gpu_type)
|
| 155 |
+
|
| 156 |
+
# 3. Install dependencies
|
| 157 |
+
if install_deps:
|
| 158 |
+
install_requirements()
|
| 159 |
+
|
| 160 |
+
# 4. Create Drive directories
|
| 161 |
+
logger.info("Creating project directories on Drive...")
|
| 162 |
+
dirs = create_drive_dirs(drive_root)
|
| 163 |
+
|
| 164 |
+
logger.info("=" * 60)
|
| 165 |
+
logger.info("Setup complete! Ready to train.")
|
| 166 |
+
logger.info("=" * 60)
|
| 167 |
+
|
| 168 |
+
return dirs
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
if __name__ == "__main__":
|
| 172 |
+
setup()
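The directory-creation step can be tried outside Colab by pointing it at a temporary root instead of Drive. The following is a minimal sketch under that assumption; `make_project_dirs` is a hypothetical stand-in that mirrors the layout used by `create_drive_dirs` and is not part of the repository:

```python
import tempfile
from pathlib import Path
from typing import Dict


def make_project_dirs(root: Path) -> Dict[str, Path]:
    """Create the project sub-directories under root and return their paths."""
    dirs = {
        "root": root,
        "checkpoints": root / "checkpoints",
        "logs": root / "logs",
        "outputs": root / "outputs",
        "data": root / "processed_data",
    }
    for path in dirs.values():
        # parents=True creates missing ancestors; exist_ok=True makes reruns safe.
        path.mkdir(parents=True, exist_ok=True)
    return dirs


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "change-detection"
    first = make_project_dirs(root)
    second = make_project_dirs(root)  # second call finds everything in place
    created = sorted(p.name for p in root.iterdir())

print(created)
```

Because `exist_ok=True` is set, calling the function at the start of every session should be harmless even when the directories already exist.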
|
train.py
ADDED
|
@@ -0,0 +1,418 @@
|
| 1 |
+
"""Main training script for change detection models.
|
| 2 |
+
|
| 3 |
+
Supports AMP, gradient clipping, early stopping, checkpoint saving to Google
|
| 4 |
+
Drive, and resume from checkpoint after Colab disconnects.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python train.py --config configs/config.yaml --model unet_pp
|
| 8 |
+
python train.py --config configs/config.yaml --model changeformer --resume checkpoints/changeformer_last.pth
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
import random
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Any, Dict, Tuple
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from torch.cuda.amp import GradScaler, autocast
|
| 21 |
+
from torch.optim import AdamW
|
| 22 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
import yaml
|
| 27 |
+
|
| 28 |
+
from data.dataset import ChangeDetectionDataset
|
| 29 |
+
from models import get_model
|
| 30 |
+
from utils.losses import get_loss
|
| 31 |
+
from utils.metrics import ConfusionMatrix
|
| 32 |
+
from utils.visualization import plot_prediction
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def set_seed(seed: int) -> None:
|
| 38 |
+
"""Set random seeds for reproducibility.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
seed: Random seed value.
|
| 42 |
+
"""
|
| 43 |
+
random.seed(seed)
|
| 44 |
+
np.random.seed(seed)
|
| 45 |
+
torch.manual_seed(seed)
|
| 46 |
+
torch.cuda.manual_seed_all(seed)
|
| 47 |
+
torch.backends.cudnn.deterministic = True
|
| 48 |
+
torch.backends.cudnn.benchmark = False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def detect_gpu_type() -> str:
|
| 52 |
+
"""Detect the current GPU type for batch size selection.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
GPU type string ('T4', 'V100', or 'default').
|
| 56 |
+
"""
|
| 57 |
+
if not torch.cuda.is_available():
|
| 58 |
+
return "default"
|
| 59 |
+
name = torch.cuda.get_device_name(0).upper()
|
| 60 |
+
if "T4" in name:
|
| 61 |
+
return "T4"
|
| 62 |
+
elif "V100" in name:
|
| 63 |
+
return "V100"
|
| 64 |
+
return "default"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def get_batch_size(config: Dict[str, Any], model_name: str) -> int:
|
| 68 |
+
"""Get appropriate batch size based on GPU and model.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
config: Full config dict.
|
| 72 |
+
model_name: Model name string.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Batch size integer.
|
| 76 |
+
"""
|
| 77 |
+
gpu_type = detect_gpu_type()
|
| 78 |
+
batch_sizes = config.get("batch_sizes", {}).get(model_name, {})
|
| 79 |
+
return batch_sizes.get(gpu_type, batch_sizes.get("default", 4))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_paths(config: Dict[str, Any]) -> Dict[str, Path]:
|
| 83 |
+
"""Resolve paths based on whether running on Colab or locally.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
config: Full config dict.
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Dict with keys: 'data', 'checkpoints', 'logs', 'outputs'.
|
| 90 |
+
"""
|
| 91 |
+
if config.get("colab", {}).get("enabled", False):
|
| 92 |
+
colab = config["colab"]
|
| 93 |
+
return {
|
| 94 |
+
"data": Path(colab["data_dir"]),
|
| 95 |
+
"checkpoints": Path(colab["checkpoint_dir"]),
|
| 96 |
+
"logs": Path(colab["log_dir"]),
|
| 97 |
+
"outputs": Path(colab["output_dir"]),
|
| 98 |
+
}
|
| 99 |
+
else:
|
| 100 |
+
paths = config.get("paths", {})
|
| 101 |
+
return {
|
| 102 |
+
"data": Path(paths.get("processed_data", "./processed_data")),
|
| 103 |
+
"checkpoints": Path(paths.get("checkpoint_dir", "./checkpoints")),
|
| 104 |
+
"logs": Path(paths.get("log_dir", "./logs")),
|
| 105 |
+
"outputs": Path(paths.get("output_dir", "./outputs")),
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def build_dataloaders(
|
| 110 |
+
config: Dict[str, Any],
|
| 111 |
+
data_dir: Path,
|
| 112 |
+
batch_size: int,
|
| 113 |
+
) -> Tuple[DataLoader, DataLoader]:
|
| 114 |
+
"""Create train and validation DataLoaders.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
config: Full config dict.
|
| 118 |
+
data_dir: Path to processed dataset root.
|
| 119 |
+
batch_size: Batch size.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Tuple of (train_loader, val_loader).
|
| 123 |
+
"""
|
| 124 |
+
ds_cfg = config.get("dataset", {})
|
| 125 |
+
num_workers = ds_cfg.get("num_workers", 4)
|
| 126 |
+
pin_memory = ds_cfg.get("pin_memory", True)
|
| 127 |
+
|
| 128 |
+
train_ds = ChangeDetectionDataset(data_dir / "train", split="train", config=config)
|
| 129 |
+
val_ds = ChangeDetectionDataset(data_dir / "val", split="val", config=config)
|
| 130 |
+
|
| 131 |
+
train_loader = DataLoader(
|
| 132 |
+
train_ds, batch_size=batch_size, shuffle=True,
|
| 133 |
+
num_workers=num_workers, pin_memory=pin_memory, drop_last=True,
|
| 134 |
+
)
|
| 135 |
+
val_loader = DataLoader(
|
| 136 |
+
val_ds, batch_size=batch_size, shuffle=False,
|
| 137 |
+
num_workers=num_workers, pin_memory=pin_memory,
|
| 138 |
+
)
|
| 139 |
+
return train_loader, val_loader
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def train_one_epoch(
|
| 143 |
+
model: nn.Module,
|
| 144 |
+
loader: DataLoader,
|
| 145 |
+
criterion: nn.Module,
|
| 146 |
+
optimizer: torch.optim.Optimizer,
|
| 147 |
+
scaler: GradScaler,
|
| 148 |
+
device: torch.device,
|
| 149 |
+
config: Dict[str, Any],
|
| 150 |
+
) -> Tuple[float, Dict[str, float]]:
|
| 151 |
+
"""Run one training epoch.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
model: The change detection model.
|
| 155 |
+
loader: Training DataLoader.
|
| 156 |
+
criterion: Loss function.
|
| 157 |
+
optimizer: Optimizer.
|
| 158 |
+
scaler: GradScaler for AMP.
|
| 159 |
+
device: Target device.
|
| 160 |
+
config: Full config dict.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Tuple of (average loss, metrics dict).
|
| 164 |
+
"""
|
| 165 |
+
model.train()
|
| 166 |
+
running_loss = 0.0
|
| 167 |
+
cm = ConfusionMatrix()
|
| 168 |
+
train_cfg = config.get("training", {})
|
| 169 |
+
accum_steps = train_cfg.get("gradient_accumulation_steps", 1)
|
| 170 |
+
grad_clip = train_cfg.get("grad_clip_max_norm", 1.0)
|
| 171 |
+
threshold = config.get("evaluation", {}).get("threshold", 0.5)
|
| 172 |
+
|
| 173 |
+
optimizer.zero_grad()
|
| 174 |
+
|
| 175 |
+
for step, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
|
| 176 |
+
img_a = batch["A"].to(device)
|
| 177 |
+
img_b = batch["B"].to(device)
|
| 178 |
+
mask = batch["mask"].to(device)
|
| 179 |
+
|
| 180 |
+
with autocast():
|
| 181 |
+
logits = model(img_a, img_b)
|
| 182 |
+
loss = criterion(logits, mask) / accum_steps
|
| 183 |
+
|
| 184 |
+
scaler.scale(loss).backward()
|
| 185 |
+
|
| 186 |
+
if (step + 1) % accum_steps == 0 or (step + 1) == len(loader):  # also flush leftover gradients at epoch end
|
| 187 |
+
scaler.unscale_(optimizer)
|
| 188 |
+
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 189 |
+
scaler.step(optimizer)
|
| 190 |
+
scaler.update()
|
| 191 |
+
optimizer.zero_grad()
|
| 192 |
+
|
| 193 |
+
running_loss += loss.item() * accum_steps
|
| 194 |
+
|
| 195 |
+
# Metrics
|
| 196 |
+
with torch.no_grad():
|
| 197 |
+
preds = (torch.sigmoid(logits) > threshold).float()
|
| 198 |
+
cm.update(preds, mask)
|
| 199 |
+
|
| 200 |
+
avg_loss = running_loss / len(loader)
|
| 201 |
+
metrics = cm.compute()
|
| 202 |
+
return avg_loss, metrics
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@torch.no_grad()
|
| 206 |
+
def validate(
|
| 207 |
+
model: nn.Module,
|
| 208 |
+
loader: DataLoader,
|
| 209 |
+
criterion: nn.Module,
|
| 210 |
+
device: torch.device,
|
| 211 |
+
threshold: float = 0.5,
|
| 212 |
+
) -> Tuple[float, Dict[str, float]]:
|
| 213 |
+
"""Run validation.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
model: The change detection model.
|
| 217 |
+
loader: Validation DataLoader.
|
| 218 |
+
criterion: Loss function.
|
| 219 |
+
device: Target device.
|
| 220 |
+
threshold: Binarization threshold.
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
Tuple of (average loss, metrics dict).
|
| 224 |
+
"""
|
| 225 |
+
model.eval()
|
| 226 |
+
running_loss = 0.0
|
| 227 |
+
cm = ConfusionMatrix()
|
| 228 |
+
|
| 229 |
+
for batch in tqdm(loader, desc="Val", leave=False):
|
| 230 |
+
img_a = batch["A"].to(device)
|
| 231 |
+
img_b = batch["B"].to(device)
|
| 232 |
+
mask = batch["mask"].to(device)
|
| 233 |
+
|
| 234 |
+
logits = model(img_a, img_b)
|
| 235 |
+
loss = criterion(logits, mask)
|
| 236 |
+
running_loss += loss.item()
|
| 237 |
+
|
| 238 |
+
preds = (torch.sigmoid(logits) > threshold).float()
|
| 239 |
+
cm.update(preds, mask)
|
| 240 |
+
|
| 241 |
+
avg_loss = running_loss / len(loader)
|
| 242 |
+
metrics = cm.compute()
|
| 243 |
+
return avg_loss, metrics
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def save_checkpoint(
|
| 247 |
+
model: nn.Module,
|
| 248 |
+
optimizer: torch.optim.Optimizer,
|
| 249 |
+
scheduler: Any,
|
| 250 |
+
scaler: GradScaler,
|
| 251 |
+
epoch: int,
|
| 252 |
+
best_f1: float,
|
| 253 |
+
save_path: Path,
|
| 254 |
+
) -> None:
|
| 255 |
+
"""Save a training checkpoint.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
model: Model to save.
|
| 259 |
+
optimizer: Optimizer state.
|
| 260 |
+
scheduler: LR scheduler state.
|
| 261 |
+
scaler: GradScaler state.
|
| 262 |
+
epoch: Current epoch number.
|
| 263 |
+
best_f1: Best validation F1 so far.
|
| 264 |
+
save_path: Path to save the checkpoint.
|
| 265 |
+
"""
|
| 266 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
| 267 |
+
torch.save({
|
| 268 |
+
"epoch": epoch,
|
| 269 |
+
"model_state_dict": model.state_dict(),
|
| 270 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 271 |
+
"scheduler_state_dict": scheduler.state_dict(),
|
| 272 |
+
"scaler_state_dict": scaler.state_dict(),
|
| 273 |
+
"best_f1": best_f1,
|
| 274 |
+
}, save_path)
|
| 275 |
+
logger.info("Saved checkpoint: %s", save_path)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def load_checkpoint(
|
| 279 |
+
path: Path,
|
| 280 |
+
model: nn.Module,
|
| 281 |
+
optimizer: torch.optim.Optimizer,
|
| 282 |
+
scheduler: Any,
|
| 283 |
+
scaler: GradScaler,
|
| 284 |
+
device: torch.device,
|
| 285 |
+
) -> Tuple[int, float]:
|
| 286 |
+
"""Load a training checkpoint for resume.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
path: Path to the checkpoint file.
|
| 290 |
+
model: Model to load weights into.
|
| 291 |
+
optimizer: Optimizer to load state into.
|
| 292 |
+
scheduler: Scheduler to load state into.
|
| 293 |
+
scaler: GradScaler to load state into.
|
| 294 |
+
device: Target device.
|
| 295 |
+
|
| 296 |
+
Returns:
|
| 297 |
+
Tuple of (start_epoch, best_f1).
|
| 298 |
+
"""
|
| 299 |
+
ckpt = torch.load(path, map_location=device)
|
| 300 |
+
model.load_state_dict(ckpt["model_state_dict"])
|
| 301 |
+
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
| 302 |
+
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
| 303 |
+
scaler.load_state_dict(ckpt["scaler_state_dict"])
|
| 304 |
+
logger.info("Resumed from epoch %d (best F1: %.4f)", ckpt["epoch"], ckpt["best_f1"])
|
| 305 |
+
return ckpt["epoch"], ckpt["best_f1"]
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def main() -> None:
|
| 309 |
+
"""Main training entry point."""
|
| 310 |
+
parser = argparse.ArgumentParser(description="Train change detection model")
|
| 311 |
+
parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
|
| 312 |
+
parser.add_argument("--model", type=str, default=None, help="Override model name from config")
|
| 313 |
+
parser.add_argument("--resume", type=Path, default=None, help="Path to checkpoint for resume")
|
| 314 |
+
args = parser.parse_args()
|
| 315 |
+
|
| 316 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
| 317 |
+
|
| 318 |
+
# Load config
|
| 319 |
+
with open(args.config, "r") as f:
|
| 320 |
+
config = yaml.safe_load(f)
|
| 321 |
+
|
| 322 |
+
model_name = args.model or config["model"]["name"]
|
| 323 |
+
seed = config.get("project", {}).get("seed", 42)
|
| 324 |
+
set_seed(seed)
|
| 325 |
+
|
| 326 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 327 |
+
logger.info("Device: %s", device)
|
| 328 |
+
|
| 329 |
+
# Resolve paths
|
| 330 |
+
paths = get_paths(config)
|
| 331 |
+
for p in paths.values():
|
| 332 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 333 |
+
|
| 334 |
+
# Model
|
| 335 |
+
model = get_model(model_name, config).to(device)
|
| 336 |
+
logger.info("Model: %s (%.1fM params)", model_name,
|
| 337 |
+
sum(p.numel() for p in model.parameters()) / 1e6)
|
| 338 |
+
|
| 339 |
+
# Data
|
| 340 |
+
batch_size = get_batch_size(config, model_name)
|
| 341 |
+
train_loader, val_loader = build_dataloaders(config, paths["data"], batch_size)
|
| 342 |
+
|
| 343 |
+
# Loss, optimizer, scheduler
|
| 344 |
+
criterion = get_loss(config)
|
| 345 |
+
lr = config.get("learning_rates", {}).get(model_name, config["training"]["learning_rate"])
|
| 346 |
+
epochs = config.get("epoch_counts", {}).get(model_name, config["training"]["epochs"])
|
| 347 |
+
|
| 348 |
+
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=config["training"]["weight_decay"])
|
| 349 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
|
| 350 |
+
scaler = GradScaler()
|
| 351 |
+
|
| 352 |
+
# TensorBoard
|
| 353 |
+
writer = SummaryWriter(log_dir=str(paths["logs"] / model_name))
|
| 354 |
+
|
| 355 |
+
# Resume
|
| 356 |
+
start_epoch = 0
|
| 357 |
+
best_f1 = 0.0
|
| 358 |
+
if args.resume and args.resume.exists():
|
| 359 |
+
start_epoch, best_f1 = load_checkpoint(
|
| 360 |
+
args.resume, model, optimizer, scheduler, scaler, device
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
# Early stopping state
|
| 364 |
+
es_cfg = config["training"]["early_stopping"]
|
| 365 |
+
patience = es_cfg.get("patience", 15)
|
| 366 |
+
patience_counter = 0
|
| 367 |
+
threshold = config.get("evaluation", {}).get("threshold", 0.5)
|
| 368 |
+
|
| 369 |
+
# Training loop
|
| 370 |
+
for epoch in range(start_epoch, epochs):
|
| 371 |
+
logger.info("Epoch %d/%d", epoch + 1, epochs)
|
| 372 |
+
|
| 373 |
+
train_loss, train_metrics = train_one_epoch(
|
| 374 |
+
model, train_loader, criterion, optimizer, scaler, device, config
|
| 375 |
+
)
|
| 376 |
+
val_loss, val_metrics = validate(model, val_loader, criterion, device, threshold)
|
| 377 |
+
scheduler.step()
|
| 378 |
+
|
| 379 |
+
# Log
|
| 380 |
+
writer.add_scalar("Loss/train", train_loss, epoch)
|
| 381 |
+
writer.add_scalar("Loss/val", val_loss, epoch)
|
| 382 |
+
for k, v in val_metrics.items():
|
| 383 |
+
writer.add_scalar(f"Val/{k}", v, epoch)
|
| 384 |
+
|
| 385 |
+
logger.info(
|
| 386 |
+
" Train Loss: %.4f | Val Loss: %.4f | Val F1: %.4f | Val IoU: %.4f",
|
| 387 |
+
train_loss, val_loss, val_metrics["f1"], val_metrics["iou"],
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# Save last checkpoint (always)
|
| 391 |
+
save_checkpoint(
|
| 392 |
+
model, optimizer, scheduler, scaler, epoch + 1, max(best_f1, val_metrics["f1"]),
|
| 393 |
+
paths["checkpoints"] / f"{model_name}_last.pth",
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# Save best checkpoint
|
| 397 |
+
if val_metrics["f1"] > best_f1:
|
| 398 |
+
best_f1 = val_metrics["f1"]
|
| 399 |
+
patience_counter = 0
|
| 400 |
+
save_checkpoint(
|
| 401 |
+
model, optimizer, scheduler, scaler, epoch + 1, best_f1,
|
| 402 |
+
paths["checkpoints"] / f"{model_name}_best.pth",
|
| 403 |
+
)
|
| 404 |
+
logger.info(" New best F1: %.4f", best_f1)
|
| 405 |
+
else:
|
| 406 |
+
patience_counter += 1
|
| 407 |
+
|
| 408 |
+
# Early stopping
|
| 409 |
+
if es_cfg.get("enabled", True) and patience_counter >= patience:
|
| 410 |
+
logger.info("Early stopping triggered at epoch %d", epoch + 1)
|
| 411 |
+
break
|
| 412 |
+
|
| 413 |
+
writer.close()
|
| 414 |
+
logger.info("Training complete. Best F1: %.4f", best_f1)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
if __name__ == "__main__":
|
| 418 |
+
main()
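The patience logic in the training loop above can be looked at in isolation. The following is a minimal sketch, not the project's code: `run_early_stopping` is a hypothetical helper, and it is fed a plain list of validation F1 values instead of calling `validate`. It assumes the same rule as the loop: reset the counter on a strict improvement, otherwise increment it, and stop once the counter reaches `patience`.

```python
from typing import List, Tuple


def run_early_stopping(f1_per_epoch: List[float], patience: int) -> Tuple[int, float]:
    """Return (epochs actually run, best F1) under patience-based early stopping.

    Hypothetical helper for illustration; mirrors the counter logic in main().
    """
    best_f1 = 0.0
    patience_counter = 0
    epochs_run = 0
    for f1 in f1_per_epoch:
        epochs_run += 1
        if f1 > best_f1:
            # Strict improvement: remember it and reset the counter.
            best_f1 = f1
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= patience:
            # No improvement for `patience` consecutive epochs.
            break
    return epochs_run, best_f1


history = [0.50, 0.62, 0.61, 0.60, 0.59, 0.70]
print(run_early_stopping(history, patience=3))
```

With `patience=3` the run stops after the fifth epoch and never sees the later 0.70. This is the usual trade-off when choosing a small patience value.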
|
utils/__init__.py
ADDED
|
File without changes
|
utils/losses.py
ADDED
|
@@ -0,0 +1,139 @@
|
| 1 |
+
"""Loss functions for binary change detection.
|
| 2 |
+
|
| 3 |
+
Provides BCEDiceLoss (default) and FocalLoss, both operating on raw logits.
|
| 4 |
+
A factory function ``get_loss`` reads the project config and returns the
|
| 5 |
+
selected loss module.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BCEDiceLoss(nn.Module):
|
| 16 |
+
"""Combined Binary Cross-Entropy and Dice Loss.
|
| 17 |
+
|
| 18 |
+
Both components operate on raw logits — sigmoid is applied internally so
|
| 19 |
+
the caller should **not** pre-apply it.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
bce_weight: Scalar weight for the BCE component.
|
| 23 |
+
dice_weight: Scalar weight for the Dice component.
|
| 24 |
+
smooth: Smoothing constant for Dice to avoid division by zero.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
bce_weight: float = 0.5,
|
| 30 |
+
dice_weight: float = 0.5,
|
| 31 |
+
smooth: float = 1.0,
|
| 32 |
+
) -> None:
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.bce_weight = bce_weight
|
| 35 |
+
self.dice_weight = dice_weight
|
| 36 |
+
self.smooth = smooth
|
| 37 |
+
|
| 38 |
+
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
"""Compute the combined BCE + Dice loss.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
logits: Raw model output of shape ``[B, 1, H, W]``.
|
| 43 |
+
targets: Binary ground-truth masks of shape ``[B, 1, H, W]``
|
| 44 |
+
with values in {0, 1}.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Scalar loss tensor on the same device as the inputs.
|
| 48 |
+
"""
|
| 49 |
+
# --- BCE component (numerically stable, operates on logits) ---
|
| 50 |
+
bce_loss = F.binary_cross_entropy_with_logits(logits, targets)
|
| 51 |
+
|
| 52 |
+
# --- Dice component ---
|
| 53 |
+
probs = torch.sigmoid(logits)
|
| 54 |
+
# Flatten spatial dims per sample for stable dice computation
|
| 55 |
+
probs_flat = probs.reshape(probs.size(0), -1)
|
| 56 |
+
targets_flat = targets.reshape(targets.size(0), -1)
|
| 57 |
+
|
| 58 |
+
intersection = (probs_flat * targets_flat).sum(dim=1)
|
| 59 |
+
union = probs_flat.sum(dim=1) + targets_flat.sum(dim=1)
|
| 60 |
+
dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth)
|
| 61 |
+
dice_loss = 1.0 - dice_score.mean()
|
| 62 |
+
|
| 63 |
+
return self.bce_weight * bce_loss + self.dice_weight * dice_loss
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class FocalLoss(nn.Module):
|
| 67 |
+
"""Focal Loss for addressing class imbalance in change detection.
|
| 68 |
+
|
| 69 |
+
Down-weights well-classified (easy) pixels so the model focuses on hard
|
| 70 |
+
examples near the decision boundary. Operates on raw logits.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
alpha: Balancing factor for the positive class (1 − alpha for negative).
|
| 74 |
+
gamma: Focusing exponent — higher values down-weight easy examples more.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(self, alpha: float = 0.25, gamma: float = 2.0) -> None:
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.alpha = alpha
|
| 80 |
+
self.gamma = gamma
|
| 81 |
+
|
| 82 |
+
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 83 |
+
"""Compute focal loss.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
logits: Raw model output of shape ``[B, 1, H, W]``.
|
| 87 |
+
targets: Binary ground-truth masks of shape ``[B, 1, H, W]``
|
| 88 |
+
with values in {0, 1}.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Scalar loss tensor on the same device as the inputs.
|
| 92 |
+
"""
|
| 93 |
+
# Per-pixel BCE (unreduced)
|
| 94 |
+
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
|
| 95 |
+
|
| 96 |
+
probs = torch.sigmoid(logits)
|
| 97 |
+
# p_t = probability of the true class
|
| 98 |
+
p_t = probs * targets + (1.0 - probs) * (1.0 - targets)
|
| 99 |
+
# alpha_t = alpha for positives, (1-alpha) for negatives
|
| 100 |
+
alpha_t = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets)
|
| 101 |
+
|
| 102 |
+
focal_weight = alpha_t * (1.0 - p_t) ** self.gamma
|
| 103 |
+
return (focal_weight * bce).mean()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_loss(config: Dict[str, Any]) -> nn.Module:
|
| 107 |
+
"""Factory function — instantiate a loss module from the project config.
|
| 108 |
+
|
| 109 |
+
Reads ``config["loss"]["name"]`` to select the loss type and extracts
|
| 110 |
+
the matching sub-key for constructor arguments.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
config: Full project config dict (as loaded from ``config.yaml``).
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
An ``nn.Module`` loss function ready for ``loss(logits, targets)``.
|
| 117 |
+
|
| 118 |
+
Raises:
|
| 119 |
+
ValueError: If the requested loss name is not recognised.
|
| 120 |
+
"""
|
| 121 |
+
loss_cfg = config.get("loss", {})
|
| 122 |
+
name = loss_cfg.get("name", "bce_dice")
|
| 123 |
+
|
| 124 |
+
if name == "bce_dice":
|
| 125 |
+
params = loss_cfg.get("bce_dice", {})
|
| 126 |
+
return BCEDiceLoss(
|
| 127 |
+
bce_weight=params.get("bce_weight", 0.5),
|
| 128 |
+
dice_weight=params.get("dice_weight", 0.5),
|
| 129 |
+
)
|
| 130 |
+
elif name == "focal":
|
| 131 |
+
params = loss_cfg.get("focal", {})
|
| 132 |
+
return FocalLoss(
|
| 133 |
+
alpha=params.get("alpha", 0.25),
|
| 134 |
+
gamma=params.get("gamma", 2.0),
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError(
|
| 138 |
+
f"Unknown loss '{name}'. Choose from: bce_dice, focal"
|
| 139 |
+
)
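To make the arithmetic in `BCEDiceLoss` and `FocalLoss` easier to check by hand, here is a small pure-Python sketch of the Dice term and the focal weight for a single sample. It is an illustration under stated assumptions, not the module's implementation: `sigmoid`, `dice_loss`, and `focal_weight` are hypothetical names, and flat lists of floats stand in for `[B, 1, H, W]` tensors.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    """Logistic function, as applied to the raw logits."""
    return 1.0 / (1.0 + math.exp(-x))


def dice_loss(logits: List[float], targets: List[float], smooth: float = 1.0) -> float:
    """Dice loss for one flattened sample, using the same formula as above."""
    probs = [sigmoid(z) for z in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (union + smooth)


def focal_weight(logit: float, target: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Per-pixel factor alpha_t * (1 - p_t) ** gamma that multiplies the BCE term."""
    p = sigmoid(logit)
    p_t = p * target + (1.0 - p) * (1.0 - target)
    alpha_t = alpha * target + (1.0 - alpha) * (1.0 - target)
    return alpha_t * (1.0 - p_t) ** gamma


# A confident, correct prediction gives a Dice loss close to zero.
print(dice_loss([50.0, 50.0, -50.0, -50.0], [1.0, 1.0, 0.0, 0.0]))
# An uncertain positive pixel (logit 0) keeps more weight than an easy one (logit 10).
print(focal_weight(0.0, 1.0), focal_weight(10.0, 1.0))
```

Note that the `smooth` term makes the Dice score equal to 1 when both the prediction and the mask are empty, so patches with no change add essentially no Dice loss.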
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,226 @@
|
| 1 |
+
"""Evaluation metrics for binary change detection.
|
| 2 |
+
|
| 3 |
+
Provides a ``ConfusionMatrix`` accumulator, standalone metric functions, and a
|
| 4 |
+
high-level ``MetricTracker`` that accepts raw logits and handles sigmoid +
|
| 5 |
+
thresholding internally.
|
| 6 |
+
|
| 7 |
+
All boolean logic stays on the input tensors' device; only the four count
|
| 8 |
+
scalars are moved to the CPU, via ``.item()`` calls inside ``update()``.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Dict
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# Small constant to prevent division-by-zero in metric formulas.
|
| 16 |
+
_EPS: float = 1e-7
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Low-level confusion-matrix accumulator
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class ConfusionMatrix:
|
| 24 |
+
"""Accumulates TP / FP / FN / TN counts across batches.
|
| 25 |
+
|
| 26 |
+
Counts are kept as plain Python ints (each counter is moved off the GPU
|
| 27 |
+
with its own ``.item()`` call in ``update``), so accumulated values never
|
| 28 |
+
overflow a fixed-width GPU scalar.
|
| 29 |
+
|
| 30 |
+
Example::
|
| 31 |
+
|
| 32 |
+
cm = ConfusionMatrix()
|
| 33 |
+
for preds, targets in loader:
|
| 34 |
+
cm.update(preds, targets)
|
| 35 |
+
metrics = cm.compute()
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self) -> None:
|
| 39 |
+
self.reset()
|
| 40 |
+
|
| 41 |
+
def reset(self) -> None:
|
| 42 |
+
"""Reset all counters to zero."""
|
| 43 |
+
self.tp: int = 0
|
| 44 |
+
self.fp: int = 0
|
| 45 |
+
self.fn: int = 0
|
| 46 |
+
self.tn: int = 0
|
| 47 |
+
|
| 48 |
+
def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
|
| 49 |
+
"""Accumulate one batch of binary predictions.
|
| 50 |
+
|
| 51 |
+
All boolean logic runs on whatever device the tensors live on; only
|
| 52 |
+
the four resulting scalars are moved to CPU via ``.item()``.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
preds: Binary predictions ``[B, 1, H, W]`` with values in {0, 1}.
|
| 56 |
+
targets: Ground-truth masks ``[B, 1, H, W]`` with values in {0, 1}.
|
| 57 |
+
"""
|
| 58 |
+
p = preds.bool().flatten()
|
| 59 |
+
t = targets.bool().flatten()
|
| 60 |
+
|
| 61 |
+
self.tp += (p & t).sum().item()
|
| 62 |
+
self.fp += (p & ~t).sum().item()
|
| 63 |
+
self.fn += (~p & t).sum().item()
|
| 64 |
+
self.tn += (~p & ~t).sum().item()
|
| 65 |
+
|
| 66 |
+
def compute(self) -> Dict[str, float]:
|
| 67 |
+
"""Derive all metrics from the accumulated counts.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Dict with keys ``'f1'``, ``'iou'``, ``'precision'``, ``'recall'``,
|
| 71 |
+
``'oa'`` — each a plain Python float.
|
| 72 |
+
"""
|
| 73 |
+
precision = self.tp / (self.tp + self.fp + _EPS)
|
| 74 |
+
recall = self.tp / (self.tp + self.fn + _EPS)
|
| 75 |
+
f1 = 2.0 * precision * recall / (precision + recall + _EPS)
|
| 76 |
+
iou = self.tp / (self.tp + self.fp + self.fn + _EPS)
|
| 77 |
+
oa = (self.tp + self.tn) / (self.tp + self.fp + self.fn + self.tn + _EPS)
|
| 78 |
+
|
| 79 |
+
return {
|
| 80 |
+
"f1": f1,
|
| 81 |
+
"iou": iou,
|
| 82 |
+
"precision": precision,
|
| 83 |
+
"recall": recall,
|
| 84 |
+
"oa": oa,
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
# Standalone convenience functions (single-batch, binary inputs)
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def _quick_cm(preds: torch.Tensor, targets: torch.Tensor) -> ConfusionMatrix:
|
| 93 |
+
"""Create and populate a ConfusionMatrix from a single batch.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
preds: Binary predictions ``[B, 1, H, W]``.
|
| 97 |
+
targets: Ground-truth masks ``[B, 1, H, W]``.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Populated ``ConfusionMatrix`` instance.
|
| 101 |
+
"""
|
| 102 |
+
cm = ConfusionMatrix()
|
| 103 |
+
cm.update(preds, targets)
|
| 104 |
+
return cm
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def compute_f1(preds: torch.Tensor, targets: torch.Tensor) -> float:
|
| 108 |
+
"""Compute F1 score for a single batch.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
preds: Binary predictions ``[B, 1, H, W]``.
|
| 112 |
+
targets: Ground-truth masks ``[B, 1, H, W]``.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
F1 score as a float in [0, 1].
|
| 116 |
+
"""
|
| 117 |
+
return _quick_cm(preds, targets).compute()["f1"]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def compute_iou(preds: torch.Tensor, targets: torch.Tensor) -> float:
|
| 121 |
+
"""Compute IoU (Jaccard index) for a single batch.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
preds: Binary predictions ``[B, 1, H, W]``.
|
| 125 |
+
targets: Ground-truth masks ``[B, 1, H, W]``.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
IoU score as a float in [0, 1].
|
| 129 |
+
"""
|
| 130 |
+
return _quick_cm(preds, targets).compute()["iou"]
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def compute_precision(preds: torch.Tensor, targets: torch.Tensor) -> float:
|
| 134 |
+
"""Compute precision for a single batch.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
preds: Binary predictions ``[B, 1, H, W]``.
|
| 138 |
+
targets: Ground-truth masks ``[B, 1, H, W]``.
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
Precision score as a float in [0, 1].
|
| 142 |
+
"""
|
| 143 |
+
return _quick_cm(preds, targets).compute()["precision"]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def compute_recall(preds: torch.Tensor, targets: torch.Tensor) -> float:
|
| 147 |
+
"""Compute recall for a single batch.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
preds: Binary predictions ``[B, 1, H, W]``.
|
| 151 |
+
targets: Ground-truth masks ``[B, 1, H, W]``.
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Recall score as a float in [0, 1].
|
| 155 |
+
"""
|
| 156 |
+
return _quick_cm(preds, targets).compute()["recall"]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def compute_oa(preds: torch.Tensor, targets: torch.Tensor) -> float:
|
| 160 |
+
"""Compute overall accuracy for a single batch.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
preds: Binary predictions ``[B, 1, H, W]``.
|
| 164 |
+
targets: Ground-truth masks ``[B, 1, H, W]``.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Overall accuracy as a float in [0, 1].
|
| 168 |
+
"""
|
| 169 |
+
return _quick_cm(preds, targets).compute()["oa"]
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
# High-level tracker (accepts raw logits)
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
class MetricTracker:
|
| 177 |
+
"""End-to-end metric tracker for training / validation loops.
|
| 178 |
+
|
| 179 |
+
Wraps a ``ConfusionMatrix`` and transparently applies sigmoid +
|
| 180 |
+
thresholding to raw model logits before accumulating counts.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
threshold: Decision threshold applied after sigmoid (default 0.5).
|
| 184 |
+
|
| 185 |
+
Example::
|
| 186 |
+
|
| 187 |
+
tracker = MetricTracker(threshold=0.5)
|
| 188 |
+
for batch in val_loader:
|
| 189 |
+
logits = model(batch["A"], batch["B"])
|
| 190 |
+
tracker.update(logits, batch["mask"])
|
| 191 |
+
results = tracker.compute() # {"f1": ..., "iou": ..., ...}
|
| 192 |
+
tracker.reset()
|
| 193 |
+
"""
|
| 194 |
+
|
| 195 |
+
def __init__(self, threshold: float = 0.5) -> None:
|
| 196 |
+
self.threshold = threshold
|
| 197 |
+
self.cm = ConfusionMatrix()
|
| 198 |
+
|
| 199 |
+
def reset(self) -> None:
|
| 200 |
+
"""Reset the internal confusion matrix."""
|
| 201 |
+
self.cm.reset()
|
| 202 |
+
|
| 203 |
+
@torch.no_grad()
|
| 204 |
+
def update(self, logits: torch.Tensor, targets: torch.Tensor) -> None:
|
| 205 |
+
"""Apply sigmoid + threshold and accumulate counts.
|
| 206 |
+
|
| 207 |
+
This method is wrapped with ``@torch.no_grad()`` so it can be
|
| 208 |
+
called safely inside a validation loop without affecting autograd.
|
| 209 |
+
All operations run on the input tensor's device.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
logits: Raw model output ``[B, 1, H, W]`` (pre-sigmoid).
|
| 213 |
+
targets: Binary ground-truth masks ``[B, 1, H, W]`` with
|
| 214 |
+
values in {0, 1}.
|
| 215 |
+
"""
|
| 216 |
+
preds = (torch.sigmoid(logits) >= self.threshold).float()
|
| 217 |
+
self.cm.update(preds, targets)
|
| 218 |
+
|
| 219 |
+
def compute(self) -> Dict[str, float]:
|
| 220 |
+
"""Compute all metrics from accumulated counts.
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
Dict with keys ``'f1'``, ``'iou'``, ``'precision'``, ``'recall'``,
|
| 224 |
+
``'oa'``.
|
| 225 |
+
"""
|
| 226 |
+
return self.cm.compute()
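The formulas in `ConfusionMatrix.compute` can be checked on a small worked example. The sketch below is a standalone illustration: `metrics_from_counts` is a hypothetical function that takes the four counts directly, and `EPS` is assumed to match the `_EPS` constant defined above.

```python
from typing import Dict

EPS = 1e-7  # assumed to match _EPS in the module above


def metrics_from_counts(tp: int, fp: int, fn: int, tn: int) -> Dict[str, float]:
    """Derive F1, IoU, precision, recall and overall accuracy from raw counts."""
    precision = tp / (tp + fp + EPS)
    recall = tp / (tp + fn + EPS)
    f1 = 2.0 * precision * recall / (precision + recall + EPS)
    iou = tp / (tp + fp + fn + EPS)
    oa = (tp + tn) / (tp + fp + fn + tn + EPS)
    return {"f1": f1, "iou": iou, "precision": precision, "recall": recall, "oa": oa}


# 100 pixels: 8 true positives, 2 false positives, 2 false negatives, 88 true negatives.
print(metrics_from_counts(tp=8, fp=2, fn=2, tn=88))
# Edge case: no positives in either prediction or ground truth.
print(metrics_from_counts(tp=0, fp=0, fn=0, tn=100))
```

The second call shows a consequence of the epsilon-guarded formulas: when a split contains no changed pixels and the model predicts none, F1 and IoU evaluate to 0 rather than 1, even though overall accuracy is close to 1.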
|
utils/visualization.py
ADDED
|
@@ -0,0 +1,141 @@
|
| 1 |
+
"""Visualization utilities for change detection results.
|
| 2 |
+
|
| 3 |
+
Provides functions to plot predictions, overlay change maps, and track
|
| 4 |
+
training metrics over time.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def denormalize(
|
| 16 |
+
img: np.ndarray,
|
| 17 |
+
mean: tuple = (0.485, 0.456, 0.406),
|
| 18 |
+
std: tuple = (0.229, 0.224, 0.225),
|
| 19 |
+
) -> np.ndarray:
|
| 20 |
+
"""Reverse ImageNet normalization for display.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
img: Normalized image array [H, W, 3].
|
| 24 |
+
mean: Channel means used for normalization.
|
| 25 |
+
std: Channel stds used for normalization.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
Denormalized image clipped to [0, 1].
|
| 29 |
+
"""
|
| 30 |
+
img = img * np.array(std) + np.array(mean)
|
| 31 |
+
return np.clip(img, 0, 1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def plot_prediction(
|
| 35 |
+
img_a: torch.Tensor,
|
| 36 |
+
img_b: torch.Tensor,
|
| 37 |
+
mask_gt: torch.Tensor,
|
| 38 |
+
mask_pred: torch.Tensor,
|
| 39 |
+
save_path: Optional[Path] = None,
|
| 40 |
+
) -> plt.Figure:
|
| 41 |
+
"""Plot a single change detection prediction.
|
| 42 |
+
|
| 43 |
+
Shows: Before | After | Ground Truth | Prediction in a 1x4 grid.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
img_a: Before image tensor [3, H, W] (normalized).
|
| 47 |
+
img_b: After image tensor [3, H, W] (normalized).
|
| 48 |
+
mask_gt: Ground truth mask [1, H, W] (binary).
|
| 49 |
+
mask_pred: Predicted mask [1, H, W] (binary or probability).
|
| 50 |
+
save_path: Optional path to save the figure.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Matplotlib figure.
|
| 54 |
+
"""
|
| 55 |
+
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
|
| 56 |
+
|
| 57 |
+
# Convert tensors to numpy
|
| 58 |
+
a = denormalize(img_a.permute(1, 2, 0).cpu().numpy())
|
| 59 |
+
b = denormalize(img_b.permute(1, 2, 0).cpu().numpy())
|
| 60 |
+
gt = mask_gt.squeeze(0).cpu().numpy()
|
| 61 |
+
pred = mask_pred.squeeze(0).cpu().numpy()
|
| 62 |
+
|
| 63 |
+
titles = ["Before (A)", "After (B)", "Ground Truth", "Prediction"]
|
| 64 |
+
images = [a, b, gt, pred]
|
| 65 |
+
cmaps = [None, None, "gray", "gray"]
|
| 66 |
+
|
| 67 |
+
for ax, img, title, cmap in zip(axes, images, titles, cmaps):
|
| 68 |
+
ax.imshow(img, cmap=cmap, vmin=0, vmax=1)
|
| 69 |
+
ax.set_title(title)
|
| 70 |
+
ax.axis("off")
|
| 71 |
+
|
| 72 |
+
plt.tight_layout()
|
| 73 |
+
|
| 74 |
+
if save_path is not None:
|
| 75 |
+
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 76 |
+
|
| 77 |
+
return fig
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def overlay_changes(
|
| 81 |
+
img_b: torch.Tensor,
|
| 82 |
+
mask_pred: torch.Tensor,
|
| 83 |
+
alpha: float = 0.4,
|
| 84 |
+
color: tuple = (1.0, 0.0, 0.0),
|
| 85 |
+
) -> np.ndarray:
|
| 86 |
+
"""Overlay predicted change mask on the after image.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
img_b: After image tensor [3, H, W] (normalized).
|
| 90 |
+
mask_pred: Predicted binary mask [1, H, W].
|
| 91 |
+
alpha: Overlay transparency.
|
| 92 |
+
color: RGB color for the overlay (default: red).
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
Overlaid image as numpy array [H, W, 3].
|
| 96 |
+
"""
|
| 97 |
+
b = denormalize(img_b.permute(1, 2, 0).cpu().numpy())
|
| 98 |
+
mask = mask_pred.squeeze(0).cpu().numpy()
|
| 99 |
+
|
| 100 |
+
overlay = b.copy()
|
| 101 |
+
for c in range(3):
|
| 102 |
+
overlay[:, :, c] = np.where(
|
| 103 |
+
mask > 0.5,
|
| 104 |
+
b[:, :, c] * (1 - alpha) + color[c] * alpha,
|
| 105 |
+
b[:, :, c],
|
| 106 |
+
)
|
| 107 |
+
return overlay
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def plot_metrics_history(
|
| 111 |
+
history: Dict[str, List[float]],
|
| 112 |
+
save_path: Optional[Path] = None,
|
| 113 |
+
) -> plt.Figure:
|
| 114 |
+
"""Plot training metric curves over epochs.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
history: Dict mapping metric names to lists of per-epoch values.
|
| 118 |
+
save_path: Optional path to save the figure.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
Matplotlib figure.
|
| 122 |
+
"""
|
| 123 |
+
n_metrics = len(history)
|
| 124 |
+
fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4))
|
| 125 |
+
|
| 126 |
+
if n_metrics == 1:
|
| 127 |
+
axes = [axes]
|
| 128 |
+
|
| 129 |
+
for ax, (name, values) in zip(axes, history.items()):
|
| 130 |
+
ax.plot(values, marker="o", markersize=2)
|
| 131 |
+
ax.set_title(name)
|
| 132 |
+
ax.set_xlabel("Epoch")
|
| 133 |
+
ax.set_ylabel(name)
|
| 134 |
+
ax.grid(True, alpha=0.3)
|
| 135 |
+
|
| 136 |
+
plt.tight_layout()
|
| 137 |
+
|
| 138 |
+
if save_path is not None:
|
| 139 |
+
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 140 |
+
|
| 141 |
+
return fig
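The blending rule in `overlay_changes` is a per-pixel alpha blend applied only where the mask exceeds 0.5. A minimal sketch of that rule for one RGB pixel follows. `blend_pixel` is a hypothetical helper written in plain Python, whereas the function above applies the same arithmetic to whole arrays with `np.where`.

```python
from typing import Tuple

Pixel = Tuple[float, float, float]


def blend_pixel(
    pixel: Pixel,
    mask_value: float,
    alpha: float = 0.4,
    color: Pixel = (1.0, 0.0, 0.0),
) -> Pixel:
    """Blend `color` into `pixel` when the mask marks it as changed."""
    if mask_value <= 0.5:
        # Unchanged pixel: keep the original image value.
        return pixel
    r, g, b = (c * (1.0 - alpha) + k * alpha for c, k in zip(pixel, color))
    return (r, g, b)


print(blend_pixel((0.5, 0.5, 0.5), mask_value=1.0))  # tinted towards red
print(blend_pixel((0.5, 0.5, 0.5), mask_value=0.0))  # left as is
```

Because the comparison is `mask > 0.5`, a probability map passed as `mask_pred` is implicitly thresholded at 0.5, regardless of the evaluation threshold set in the config.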
|