Vedant Jigarbhai Mehta committed on
Commit
b25c087
·
0 Parent(s):

Initial scaffolding for military base change detection project


Add complete project structure with 3 model architectures (Siamese CNN,
UNet++, ChangeFormer), dataset pipeline, training/evaluation/inference
scripts, Gradio demo, Colab setup, and config with all hyperparameters.

.gitignore ADDED
@@ -0,0 +1,42 @@
1
+ # Checkpoints & model weights
2
+ checkpoints/
3
+ *.pth
4
+ *.pt
5
+
6
+ # Logs
7
+ logs/
8
+
9
+ # Outputs
10
+ outputs/
11
+
12
+ # Data
13
+ raw_data/
14
+ processed_data/
15
+
16
+ # Python
17
+ __pycache__/
18
+ *.pyc
19
+ *.pyo
20
+ *.egg-info/
21
+ dist/
22
+ build/
23
+ .eggs/
24
+
25
+ # Environment
26
+ .env
27
+ .venv/
28
+ venv/
29
+ env/
30
+
31
+ # IDE
32
+ .vscode/
33
+ .idea/
34
+ *.swp
35
+ *.swo
36
+
37
+ # OS
38
+ .DS_Store
39
+ Thumbs.db
40
+
41
+ # Jupyter
42
+ .ipynb_checkpoints/
README.md ADDED
@@ -0,0 +1,93 @@
1
+ # Military Base Construction Monitoring — Change Detection
2
+
3
+ Deep learning system for detecting new structures and infrastructure changes between satellite image pairs. Targets defense applications: military base expansion, runway construction, and infrastructure development monitoring.
4
+
5
+ ## Models
6
+
7
+ | Model | Backbone | Role | Paper |
8
+ |---|---|---|---|
9
+ | Siamese CNN | ResNet18 (shared) | Baseline | — |
10
+ | UNet++ | ResNet34 (shared) | Mid-tier | [arXiv:1807.10165](https://arxiv.org/abs/1807.10165) |
11
+ | ChangeFormer | MiT-B1 (shared) | SOTA | [arXiv:2201.01293](https://arxiv.org/abs/2201.01293) |
12
+
13
+ ## Dataset
14
+
15
+ **LEVIR-CD** — 637 image pairs at 1024×1024, cropped to 256×256 non-overlapping patches. Contains building change annotations across urban areas.
16
+
17
+ ## Quick Start (Google Colab)
18
+
19
+ ```python
20
+ # 1. Setup
21
+ from setup_colab import setup
22
+ dirs = setup()
23
+
24
+ # 2. Train
25
+ !python train.py --config configs/config.yaml --model siamese_cnn
26
+
27
+ # 3. Evaluate
28
+ !python evaluate.py --config configs/config.yaml --checkpoint checkpoints/siamese_cnn_best.pth
29
+
30
+ # 4. Resume after disconnect
31
+ !python train.py --config configs/config.yaml --model changeformer \
32
+ --resume /content/drive/MyDrive/change-detection/checkpoints/changeformer_last.pth
33
+ ```
34
+
35
+ ## Local Usage
36
+
37
+ ```bash
38
+ # Preprocess dataset
39
+ python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data
40
+
41
+ # Train
42
+ python train.py --config configs/config.yaml --model unet_pp
43
+
44
+ # Evaluate
45
+ python evaluate.py --config configs/config.yaml --checkpoint checkpoints/unet_pp_best.pth
46
+
47
+ # Inference on new image pair
48
+ python inference.py --before path/to/before.png --after path/to/after.png \
49
+ --model changeformer --checkpoint checkpoints/changeformer_best.pth
50
+
51
+ # Gradio demo
52
+ python app.py
53
+ ```
54
+
55
+ ## GPU Batch Sizes (Auto-Detected)
56
+
57
+ | Model | Batch size on T4 (16 GB) | Batch size on V100 (16 GB) | Learning rate |
58
+ |---|---|---|---|
59
+ | Siamese CNN | 16 | 16 | 1e-3 |
60
+ | UNet++ | 8 | 12 | 1e-4 |
61
+ | ChangeFormer | 4 | 6 | 6e-5 |
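The table above can be read as a lookup keyed by model and GPU. Below is a minimal sketch of that lookup, assuming the GPU name arrives as a plain string (for example from `torch.cuda.get_device_name(0)`). The helper name `pick_batch_size` and the substring matching are illustrative assumptions, not the project's actual detection code; the `default` values mirror the `batch_sizes` block in `configs/config.yaml`.

```python
from typing import Dict

# Values copied from the table and from `batch_sizes` in configs/config.yaml.
BATCH_SIZES: Dict[str, Dict[str, int]] = {
    "siamese_cnn": {"T4": 16, "V100": 16, "default": 8},
    "unet_pp": {"T4": 8, "V100": 12, "default": 4},
    "changeformer": {"T4": 4, "V100": 6, "default": 2},
}


def pick_batch_size(model_name: str, gpu_name: str) -> int:
    """Return the batch size for a model on the named GPU (hypothetical helper)."""
    per_model = BATCH_SIZES[model_name]
    for gpu_key, batch_size in per_model.items():
        # Assumption: the GPU type appears as a substring of the device name.
        if gpu_key != "default" and gpu_key.lower() in gpu_name.lower():
            return batch_size
    # Unknown GPU: fall back to the conservative default.
    return per_model["default"]


print(pick_batch_size("unet_pp", "Tesla V100-SXM2-16GB"))        # 12
print(pick_batch_size("changeformer", "NVIDIA A100-SXM4-40GB"))  # 2
```

Substring matching is deliberately loose, so a stricter mapping may be preferable if more GPU types are added.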
62
+
63
+ ## Evaluation Metrics
64
+
65
+ - **F1-Score** (primary, used for model selection and early stopping)
66
+ - IoU / Jaccard
67
+ - Precision, Recall
68
+ - Overall Accuracy
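All of the metrics listed above can be derived from the four pixel-level confusion counts for the "change" class. The sketch below shows the standard formulas. It is not the contents of `utils/metrics.py`, which is not shown here; the function name and the `eps` guard against division by zero are assumptions.

```python
from typing import Dict


def metrics_from_counts(tp: int, fp: int, fn: int, tn: int, eps: float = 1e-7) -> Dict[str, float]:
    """Compute binary change-detection metrics from confusion counts."""
    precision = tp / (tp + fp + eps)          # of predicted change, how much is real
    recall = tp / (tp + fn + eps)             # of real change, how much was found
    f1 = 2 * precision * recall / (precision + recall + eps)
    iou = tp / (tp + fp + fn + eps)           # Jaccard index of the change class
    oa = (tp + tn) / (tp + fp + fn + tn + eps)  # overall pixel accuracy
    return {"f1": f1, "iou": iou, "precision": precision, "recall": recall, "oa": oa}


m = metrics_from_counts(tp=80, fp=20, fn=20, tn=880)
print({k: round(v, 4) for k, v in m.items()})
```

With heavily imbalanced masks (few changed pixels), overall accuracy stays high even for poor predictions, which is one reason to select models on F1 instead.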
69
+
70
+ ## Project Structure
71
+
72
+ ```
73
+ military-base-change-detection/
74
+ ├── configs/config.yaml # All hyperparameters and paths
75
+ ├── data/
76
+ │ ├── download.py # Dataset download & patch cropping
77
+ │ └── dataset.py # PyTorch Dataset with synced augmentations
78
+ ├── models/
79
+ │ ├── __init__.py # get_model() factory
80
+ │ ├── siamese_cnn.py # Siamese CNN baseline
81
+ │ ├── unet_pp.py # UNet++ change detection
82
+ │ └── changeformer.py # ChangeFormer transformer
83
+ ├── utils/
84
+ │ ├── metrics.py # F1, IoU, Precision, Recall, OA
85
+ │ ├── losses.py # BCEDiceLoss, FocalLoss
86
+ │ └── visualization.py # Plotting utilities
87
+ ├── train.py # Training with AMP, early stopping, resume
88
+ ├── evaluate.py # Test set evaluation
89
+ ├── inference.py # Inference on new image pairs
90
+ ├── app.py # Gradio demo
91
+ ├── setup_colab.py # Colab environment setup
92
+ └── requirements.txt # Pinned dependencies
93
+ ```
app.py ADDED
@@ -0,0 +1,193 @@
1
+ """Gradio web demo for change detection inference.
2
+
3
+ Provides an interactive interface to upload before/after satellite image pairs
4
+ and visualize predicted change masks with overlays.
5
+
6
+ Usage:
7
+ python app.py
8
+ """
9
+
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Optional, Tuple
13
+
14
+ import cv2
15
+ import gradio as gr
16
+ import numpy as np
17
+ import torch
18
+ import yaml
19
+
20
+ from data.dataset import IMAGENET_MEAN, IMAGENET_STD
21
+ from inference import preprocess_image, sliding_window_inference
22
+ from models import get_model
23
+ from utils.visualization import denormalize, overlay_changes
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Global model cache
28
+ _model: Optional[torch.nn.Module] = None
29
+ _model_name: Optional[str] = None
30
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ _config = None
32
+
33
+
34
+ def load_config() -> dict:
35
+ """Load project config from YAML.
36
+
37
+ Returns:
38
+ Config dictionary.
39
+ """
40
+ config_path = Path("configs/config.yaml")
41
+ with open(config_path, "r") as f:
42
+ return yaml.safe_load(f)
43
+
44
+
45
+ def load_model(model_name: str, checkpoint_path: str) -> torch.nn.Module:
46
+ """Load a change detection model with caching.
47
+
48
+ Args:
49
+ model_name: Name of the model architecture.
50
+ checkpoint_path: Path to the model checkpoint.
51
+
52
+ Returns:
53
+ Loaded model in eval mode.
54
+ """
55
+ global _model, _model_name, _config
56
+
57
+ if _config is None:
58
+ _config = load_config()
59
+
60
+ if _model is not None and _model_name == f"{model_name}@{checkpoint_path}":
61
+ return _model
62
+
63
+ model = get_model(model_name, _config).to(_device)
64
+ ckpt = torch.load(checkpoint_path, map_location=_device)
65
+ model.load_state_dict(ckpt["model_state_dict"])
66
+ model.eval()
67
+
68
+ _model = model
69
+ _model_name = f"{model_name}@{checkpoint_path}"  # cache key: architecture + checkpoint, so a new path reloads weights
70
+ logger.info("Loaded model: %s from %s", model_name, checkpoint_path)
71
+ return model
72
+
73
+
74
+ def predict(
75
+ before_image: np.ndarray,
76
+ after_image: np.ndarray,
77
+ model_name: str,
78
+ checkpoint_path: str,
79
+ threshold: float,
80
+ ) -> Tuple[np.ndarray, np.ndarray]:
81
+ """Run change detection on a pair of images.
82
+
83
+ Args:
84
+ before_image: Before image as numpy array (RGB, uint8).
85
+ after_image: After image as numpy array (RGB, uint8).
86
+ model_name: Model architecture name.
87
+ checkpoint_path: Path to model weights.
88
+ threshold: Binarization threshold.
89
+
90
+ Returns:
91
+ Tuple of (binary change mask, overlay visualization).
92
+ """
93
+ model = load_model(model_name, checkpoint_path)
94
+ patch_size = 256
95
+
96
+ # Preprocess both images
97
+ def _to_tensor(img: np.ndarray) -> torch.Tensor:
98
+ h, w = img.shape[:2]
99
+ pad_h = (patch_size - h % patch_size) % patch_size
100
+ pad_w = (patch_size - w % patch_size) % patch_size
101
+ if pad_h > 0 or pad_w > 0:
102
+ img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
103
+ img_f = img.astype(np.float32) / 255.0
104
+ mean = np.array(IMAGENET_MEAN, dtype=np.float32)
105
+ std = np.array(IMAGENET_STD, dtype=np.float32)
106
+ img_f = (img_f - mean) / std
107
+ return torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float()
108
+
109
+ orig_h, orig_w = before_image.shape[:2]
110
+ tensor_a = _to_tensor(before_image)
111
+ tensor_b = _to_tensor(after_image)
112
+
113
+ # Run inference
114
+ prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device)
115
+ prob_map = prob_map[:, :, :orig_h, :orig_w]
116
+
117
+ # Binary mask
118
+ mask_np = prob_map.squeeze().numpy()
119
+ binary_mask = (mask_np > threshold).astype(np.uint8) * 255
120
+
121
+ # Overlay on after image
122
+ overlay = after_image.copy().astype(np.float32) / 255.0
123
+ change_pixels = mask_np > threshold
124
+ overlay[change_pixels, 0] = np.clip(overlay[change_pixels, 0] * 0.6 + 0.4, 0, 1)
125
+ overlay[change_pixels, 1] = overlay[change_pixels, 1] * 0.6
126
+ overlay[change_pixels, 2] = overlay[change_pixels, 2] * 0.6
127
+ overlay = (overlay * 255).astype(np.uint8)
128
+
129
+ return binary_mask, overlay
130
+
131
+
132
+ def build_demo() -> gr.Blocks:
133
+ """Build the Gradio demo interface.
134
+
135
+ Returns:
136
+ Gradio Blocks application.
137
+ """
138
+ config = load_config()
139
+ gradio_cfg = config.get("gradio", {})
140
+
141
+ with gr.Blocks(title="Military Base Change Detection") as demo:
142
+ gr.Markdown("# Military Base Change Detection")
143
+ gr.Markdown("Upload before/after satellite image pairs to detect construction and infrastructure changes.")
144
+
145
+ with gr.Row():
146
+ with gr.Column():
147
+ before_img = gr.Image(label="Before Image", type="numpy")
148
+ after_img = gr.Image(label="After Image", type="numpy")
149
+ with gr.Column():
150
+ change_mask = gr.Image(label="Change Mask")
151
+ overlay_img = gr.Image(label="Overlay")
152
+
153
+ with gr.Row():
154
+ model_dropdown = gr.Dropdown(
155
+ choices=["siamese_cnn", "unet_pp", "changeformer"],
156
+ value=gradio_cfg.get("default_model", "unet_pp"),
157
+ label="Model",
158
+ )
159
+ checkpoint_input = gr.Textbox(
160
+ value=gradio_cfg.get("default_checkpoint", "checkpoints/unet_pp_best.pth"),
161
+ label="Checkpoint Path",
162
+ )
163
+ threshold_slider = gr.Slider(
164
+ minimum=0.1, maximum=0.9, value=0.5, step=0.05,
165
+ label="Detection Threshold",
166
+ )
167
+
168
+ detect_btn = gr.Button("Detect Changes", variant="primary")
169
+ detect_btn.click(
170
+ fn=predict,
171
+ inputs=[before_img, after_img, model_dropdown, checkpoint_input, threshold_slider],
172
+ outputs=[change_mask, overlay_img],
173
+ )
174
+
175
+ return demo
176
+
177
+
178
+ def main() -> None:
179
+ """Launch the Gradio demo."""
180
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
181
+
182
+ config = load_config()
183
+ gradio_cfg = config.get("gradio", {})
184
+
185
+ demo = build_demo()
186
+ demo.launch(
187
+ server_port=gradio_cfg.get("server_port", 7860),
188
+ share=gradio_cfg.get("share", False),
189
+ )
190
+
191
+
192
+ if __name__ == "__main__":
193
+ main()
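`predict` pads each image up to the next multiple of `patch_size` before tiling, then crops the probability map back to the original height and width. The padding expression is easy to misread, so here is a small sketch of the arithmetic on its own. `pad_to_multiple` is a hypothetical helper name used only for illustration.

```python
def pad_to_multiple(size: int, patch_size: int = 256) -> int:
    """Pixels to append so that `size` becomes a multiple of `patch_size`."""
    # The outer modulo turns a full-patch pad into 0 when size is already aligned.
    return (patch_size - size % patch_size) % patch_size


for size in (256, 300, 1024, 1025):
    pad = pad_to_multiple(size)
    print(size, pad, size + pad)
# 256 0 256
# 300 212 512
# 1024 0 1024
# 1025 255 1280
```

Without the outer `% patch_size`, an already aligned image would be padded by a full extra patch.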
configs/config.yaml ADDED
@@ -0,0 +1,143 @@
1
+ # =============================================================================
2
+ # Military Base Change Detection — Master Configuration
3
+ # =============================================================================
4
+
5
+ # --- Project metadata ---
6
+ project:
7
+ name: "military-base-change-detection"
8
+ seed: 42
9
+
10
+ # --- Colab / runtime settings ---
11
+ colab:
12
+ enabled: true
13
+ drive_root: "/content/drive/MyDrive/change-detection"
14
+ checkpoint_dir: "/content/drive/MyDrive/change-detection/checkpoints"
15
+ log_dir: "/content/drive/MyDrive/change-detection/logs"
16
+ output_dir: "/content/drive/MyDrive/change-detection/outputs"
17
+ data_dir: "/content/drive/MyDrive/change-detection/processed_data"
18
+
19
+ # --- Local paths (used when colab.enabled is false) ---
20
+ paths:
21
+ raw_data: "./raw_data"
22
+ processed_data: "./processed_data"
23
+ checkpoint_dir: "./checkpoints"
24
+ log_dir: "./logs"
25
+ output_dir: "./outputs"
26
+
27
+ # --- Dataset ---
28
+ dataset:
29
+ name: "levir-cd" # levir-cd | whu-cd
30
+ original_size: 1024
31
+ patch_size: 256
32
+ num_workers: 4
33
+ pin_memory: true
34
+ # ImageNet normalization
35
+ mean: [0.485, 0.456, 0.406]
36
+ std: [0.229, 0.224, 0.225]
37
+
38
+ # --- Augmentation (train only) ---
39
+ augmentation:
40
+ enabled: true
41
+ horizontal_flip: 0.5
42
+ vertical_flip: 0.5
43
+ random_rotate_90: 0.5
44
+ color_jitter:
45
+ brightness: 0.2
46
+ contrast: 0.2
47
+ saturation: 0.1
48
+ hue: 0.05
49
+
50
+ # --- Model selection ---
51
+ model:
52
+ name: "unet_pp" # siamese_cnn | unet_pp | changeformer
53
+
54
+ # --- Model-specific configs ---
55
+ siamese_cnn:
56
+ backbone: "resnet18"
57
+ pretrained: true
58
+
59
+ unet_pp:
60
+ encoder_name: "resnet34"
61
+ pretrained: true
62
+ deep_supervision: false
63
+
64
+ changeformer:
65
+ embed_dims: [64, 128, 320, 512] # MiT-B1 style
66
+ num_heads: [1, 2, 5, 8]
67
+ mlp_ratios: [8, 8, 4, 4]
68
+ depths: [2, 2, 2, 2]
69
+ pretrained_backbone: true
70
+
71
+ # --- Training ---
72
+ training:
73
+ epochs: 100 # 200 for changeformer
74
+ optimizer: "adamw"
75
+ learning_rate: 1.0e-4
76
+ weight_decay: 0.01
77
+ scheduler: "cosine"
78
+ warmup_epochs: 5
79
+ grad_clip_max_norm: 1.0
80
+ gradient_accumulation_steps: 1 # set to 2 for changeformer on T4
81
+ amp: true # mixed precision
82
+ early_stopping:
83
+ enabled: true
84
+ patience: 15
85
+ metric: "f1"
86
+ mode: "max"
87
+ log_interval: 10 # log every N batches
88
+ vis_interval: 5 # visualize predictions every N epochs
89
+
90
+ # --- Loss ---
91
+ loss:
92
+ name: "bce_dice" # bce_dice | focal
93
+ bce_dice:
94
+ bce_weight: 0.5
95
+ dice_weight: 0.5
96
+ focal:
97
+ alpha: 0.25
98
+ gamma: 2.0
99
+
100
+ # --- Evaluation ---
101
+ evaluation:
102
+ threshold: 0.5
103
+ metrics:
104
+ - f1
105
+ - iou
106
+ - precision
107
+ - recall
108
+ - oa
109
+
110
+ # --- GPU-specific batch sizes (auto-detected on Colab) ---
111
+ # model_name -> { gpu_type -> batch_size }
112
+ batch_sizes:
113
+ siamese_cnn:
114
+ T4: 16
115
+ V100: 16
116
+ default: 8
117
+ unet_pp:
118
+ T4: 8
119
+ V100: 12
120
+ default: 4
121
+ changeformer:
122
+ T4: 4
123
+ V100: 6
124
+ default: 2
125
+
126
+ # --- Per-model learning rates ---
127
+ learning_rates:
128
+ siamese_cnn: 1.0e-3
129
+ unet_pp: 1.0e-4
130
+ changeformer: 6.0e-5
131
+
132
+ # --- Per-model epoch counts ---
133
+ epoch_counts:
134
+ siamese_cnn: 100
135
+ unet_pp: 100
136
+ changeformer: 200
137
+
138
+ # --- Gradio demo ---
139
+ gradio:
140
+ server_port: 7860
141
+ share: false
142
+ default_model: "unet_pp"
143
+ default_checkpoint: "checkpoints/unet_pp_best.pth"
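The `training` block combines `scheduler: "cosine"` with `warmup_epochs: 5`. The training script is not shown in this excerpt, so the exact schedule is unknown; the sketch below shows one common interpretation, linear warmup followed by cosine decay, using the config's default `learning_rate` and `epochs`. The function name, the linear warmup shape, and `min_lr` are assumptions.

```python
import math


def lr_at_epoch(
    epoch: int,
    base_lr: float = 1e-4,
    warmup_epochs: int = 5,
    total_epochs: int = 100,
    min_lr: float = 0.0,
) -> float:
    """Learning rate for a 0-indexed epoch under linear warmup + cosine decay."""
    if epoch < warmup_epochs:
        # Ramp linearly so the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup schedule that has elapsed, in [0, 1].
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for epoch in (0, 4, 5, 50, 99):
    print(epoch, f"{lr_at_epoch(epoch):.2e}")
```

For ChangeFormer, the per-model overrides in the config would correspond to `base_lr=6e-5` and `total_epochs=200`.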
data/dataset.py ADDED
@@ -0,0 +1,153 @@
1
+ """PyTorch Dataset for change detection tasks.
2
+
3
+ Loads pre-cropped 256x256 image patches (before/after) and binary change masks.
4
+ Supports synchronized augmentations via albumentations.ReplayCompose.
5
+ """
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional, Tuple
10
+
11
+ import albumentations as A
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ from torch.utils.data import Dataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # ImageNet normalization constants
20
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
21
+ IMAGENET_STD = (0.229, 0.224, 0.225)
22
+
23
+
24
+ def get_train_transforms(config: Dict[str, Any]) -> A.ReplayCompose:
25
+ """Build training augmentation pipeline with synchronized transforms.
26
+
27
+ Args:
28
+ config: Augmentation config dict from config.yaml.
29
+
30
+ Returns:
31
+ ReplayCompose that applies identical spatial transforms to A, B, and mask.
32
+ """
33
+ aug_cfg = config.get("augmentation", {})
34
+ transforms = []
35
+
36
+ if aug_cfg.get("horizontal_flip", 0) > 0:
37
+ transforms.append(A.HorizontalFlip(p=aug_cfg["horizontal_flip"]))
38
+
39
+ if aug_cfg.get("vertical_flip", 0) > 0:
40
+ transforms.append(A.VerticalFlip(p=aug_cfg["vertical_flip"]))
41
+
42
+ if aug_cfg.get("random_rotate_90", 0) > 0:
43
+ transforms.append(A.RandomRotate90(p=aug_cfg["random_rotate_90"]))
44
+
45
+ transforms.append(A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
46
+
47
+ return A.ReplayCompose(
48
+ transforms,
49
+ additional_targets={"image_b": "image", "mask": "mask"},
50
+ )
51
+
52
+
53
+ def get_val_transforms() -> A.Compose:
54
+ """Build validation/test transform pipeline (normalize only).
55
+
56
+ Returns:
57
+ Compose with ImageNet normalization only.
58
+ """
59
+ return A.Compose(
60
+ [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)],
61
+ additional_targets={"image_b": "image"},
62
+ )
63
+
64
+
65
+ class ChangeDetectionDataset(Dataset):
66
+ """Dataset for loading change detection image pairs and masks.
67
+
68
+ Expects directory structure:
69
+ root/
70
+ ├── A/ # before images
71
+ ├── B/ # after images
72
+ └── label/ # binary change masks (0=no change, 255=change)
73
+
74
+ Args:
75
+ root: Path to the split directory (e.g., processed_data/train).
76
+ split: One of 'train', 'val', 'test'.
77
+ config: Full config dict for augmentation settings.
78
+ transform: Optional override for the transform pipeline.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ root: Path,
84
+ split: str = "train",
85
+ config: Optional[Dict[str, Any]] = None,
86
+ transform: Optional[Any] = None,
87
+ ) -> None:
88
+ self.root = Path(root)
89
+ self.split = split
90
+
91
+ self.dir_a = self.root / "A"
92
+ self.dir_b = self.root / "B"
93
+ self.dir_label = self.root / "label"
94
+
95
+ # Collect sorted file lists
96
+ self.filenames = sorted([f.name for f in self.dir_a.iterdir() if f.suffix in (".png", ".jpg", ".tif")])
97
+ logger.info("Loaded %d samples for split '%s' from %s", len(self.filenames), split, root)
98
+
99
+ # Set up transforms
100
+ if transform is not None:
101
+ self.transform = transform
102
+ elif split == "train" and config is not None:
103
+ self.transform = get_train_transforms(config)
104
+ else:
105
+ self.transform = get_val_transforms()
106
+
107
+ def __len__(self) -> int:
108
+ return len(self.filenames)
109
+
110
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
111
+ """Load a single sample.
112
+
113
+ Args:
114
+ idx: Sample index.
115
+
116
+ Returns:
117
+ Dict with keys 'A', 'B', 'mask', 'filename'.
118
+ - A: before image tensor [3, H, W]
119
+ - B: after image tensor [3, H, W]
120
+ - mask: binary change mask tensor [1, H, W] (float, 0 or 1)
121
+ - filename: original filename string
122
+ """
123
+ fname = self.filenames[idx]
124
+
125
+ # Lazy load — read from disk each time (no RAM caching)
126
+ img_a = cv2.imread(str(self.dir_a / fname), cv2.IMREAD_COLOR)
127
+ img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
128
+
129
+ img_b = cv2.imread(str(self.dir_b / fname), cv2.IMREAD_COLOR)
130
+ img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)
131
+
132
+ mask = cv2.imread(str(self.dir_label / fname), cv2.IMREAD_GRAYSCALE)
133
+ # Binarize 0/255 -> 0/1 (thresholding keeps the mask strictly binary)
134
+ mask = (mask > 127).astype(np.float32)
135
+
136
+ # Apply synchronized augmentations
137
+ if isinstance(self.transform, A.ReplayCompose):
138
+ transformed = self.transform(image=img_a, image_b=img_b, mask=mask)
139
+ img_a = transformed["image"]
140
+ img_b = transformed["image_b"]
141
+ mask = transformed["mask"]
142
+ else:
143
+ transformed = self.transform(image=img_a, image_b=img_b)
144
+ img_a = transformed["image"]
145
+ img_b = transformed["image_b"]
146
+ # Normalize only applied to images, mask stays as-is
147
+
148
+ # HWC -> CHW for images, add channel dim for mask
149
+ img_a = torch.from_numpy(img_a).permute(2, 0, 1).float()
150
+ img_b = torch.from_numpy(img_b).permute(2, 0, 1).float()
151
+ mask = torch.from_numpy(mask).unsqueeze(0).float()
152
+
153
+ return {"A": img_a, "B": img_b, "mask": mask, "filename": fname}
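The dataset relies on albumentations to apply the same spatial transform to the before image, the after image, and the mask. To show why that matters, here is a dependency-free sketch of the idea: each random decision is sampled once and reused for all three arrays. It does not reproduce albumentations' implementation; `sync_augment` and the tiny 2×2 grids are illustrative assumptions.

```python
import random
from typing import List, Tuple

Grid = List[List[int]]


def hflip(grid: Grid) -> Grid:
    """Mirror each row (left-right flip)."""
    return [row[::-1] for row in grid]


def vflip(grid: Grid) -> Grid:
    """Reverse the row order (top-bottom flip)."""
    return [row[:] for row in grid[::-1]]


def sync_augment(
    img_a: Grid, img_b: Grid, mask: Grid, rng: random.Random, p: float = 0.5
) -> Tuple[Grid, Grid, Grid]:
    """Apply identical random flips to both images and the mask."""
    # Sample each decision once, then reuse it for all three arrays.
    do_hflip = rng.random() < p
    do_vflip = rng.random() < p
    outputs = []
    for grid in (img_a, img_b, mask):
        if do_hflip:
            grid = hflip(grid)
        if do_vflip:
            grid = vflip(grid)
        outputs.append(grid)
    return outputs[0], outputs[1], outputs[2]


before = [[1, 2], [3, 4]]
after = [[1, 2], [3, 9]]  # bottom-right pixel changed
mask = [[0, 0], [0, 1]]   # label marks that pixel
aug_a, aug_b, aug_mask = sync_augment(before, after, mask, random.Random(0))
print(aug_a, aug_b, aug_mask)
```

If the flips were sampled separately per array, the mask would no longer line up with the pixels that differ between the two images.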
data/download.py ADDED
@@ -0,0 +1,132 @@
1
+ """Download and preprocess change detection datasets.
2
+
3
+ Supports LEVIR-CD and WHU-CD datasets. Downloads raw data, crops 1024x1024
4
+ images into 256x256 non-overlapping patches, and organizes into train/val/test
5
+ splits.
6
+
7
+ Usage:
8
+ python data/download.py --dataset levir-cd --raw_dir ./raw_data --out_dir ./processed_data
9
+ """
10
+
11
+ import argparse
12
+ import logging
13
+ from pathlib import Path
14
+ from typing import Tuple
15
+
16
+ import cv2
17
+ import numpy as np
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def download_levir_cd(raw_dir: Path) -> None:
23
+ """Download the LEVIR-CD dataset.
24
+
25
+ Args:
26
+ raw_dir: Directory to save the raw downloaded files.
27
+ """
28
+ # TODO: Implement download via gdown or direct URL
29
+ raise NotImplementedError("LEVIR-CD download not yet implemented")
30
+
31
+
32
+ def download_whu_cd(raw_dir: Path) -> None:
33
+ """Download the WHU-CD dataset.
34
+
35
+ Args:
36
+ raw_dir: Directory to save the raw downloaded files.
37
+ """
38
+ # TODO: Implement download
39
+ raise NotImplementedError("WHU-CD download not yet implemented")
40
+
41
+
42
+ def crop_to_patches(
43
+ image: np.ndarray,
44
+ patch_size: int = 256,
45
+ ) -> list[np.ndarray]:
46
+ """Crop an image into non-overlapping patches.
47
+
48
+ Args:
49
+ image: Input image of shape (H, W) or (H, W, C).
50
+ patch_size: Size of each square patch.
51
+
52
+ Returns:
53
+ List of cropped patches.
54
+ """
55
+ h, w = image.shape[:2]
56
+ patches = []
57
+ for y in range(0, h - patch_size + 1, patch_size):
58
+ for x in range(0, w - patch_size + 1, patch_size):
59
+ patch = image[y : y + patch_size, x : x + patch_size]
60
+ patches.append(patch)
61
+ return patches
62
+
63
+
64
+ def process_split(
65
+ raw_dir: Path,
66
+ out_dir: Path,
67
+ split: str,
68
+ patch_size: int = 256,
69
+ ) -> int:
70
+ """Process a single dataset split (train/val/test).
71
+
72
+ Reads image pairs and masks from raw_dir, crops into patches, and
73
+ saves to out_dir.
74
+
75
+ Args:
76
+ raw_dir: Root directory of the raw dataset.
77
+ out_dir: Output directory for processed patches.
78
+ split: One of 'train', 'val', 'test'.
79
+ patch_size: Size of each square patch.
80
+
81
+ Returns:
82
+ Number of patch triplets generated.
83
+ """
84
+ # TODO: Implement processing pipeline
85
+ raise NotImplementedError("Split processing not yet implemented")
86
+
87
+
88
+ def preprocess_dataset(
89
+ dataset: str,
90
+ raw_dir: Path,
91
+ out_dir: Path,
92
+ patch_size: int = 256,
93
+ ) -> None:
94
+ """Run full preprocessing pipeline for a dataset.
95
+
96
+ Args:
97
+ dataset: Dataset name ('levir-cd' or 'whu-cd').
98
+ raw_dir: Directory containing raw downloaded data.
99
+ out_dir: Output directory for processed patches.
100
+ patch_size: Size of each square patch.
101
+ """
102
+ logger.info("Preprocessing %s: %s -> %s", dataset, raw_dir, out_dir)
103
+ out_dir.mkdir(parents=True, exist_ok=True)
104
+
105
+ for split in ["train", "val", "test"]:
106
+ count = process_split(raw_dir, out_dir, split, patch_size)
107
+ logger.info(" %s: %d patch triplets", split, count)
108
+
109
+
110
+ def main() -> None:
111
+ """CLI entry point for dataset download and preprocessing."""
112
+ parser = argparse.ArgumentParser(description="Download and preprocess change detection datasets")
113
+ parser.add_argument("--dataset", type=str, default="levir-cd", choices=["levir-cd", "whu-cd"])
114
+ parser.add_argument("--raw_dir", type=Path, default=Path("./raw_data"))
115
+ parser.add_argument("--out_dir", type=Path, default=Path("./processed_data"))
116
+ parser.add_argument("--patch_size", type=int, default=256)
117
+ parser.add_argument("--skip_download", action="store_true", help="Skip download, only preprocess")
118
+ args = parser.parse_args()
119
+
120
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
121
+
122
+ if not args.skip_download:
123
+ if args.dataset == "levir-cd":
124
+ download_levir_cd(args.raw_dir)
125
+ elif args.dataset == "whu-cd":
126
+ download_whu_cd(args.raw_dir)
127
+
128
+ preprocess_dataset(args.dataset, args.raw_dir, args.out_dir, args.patch_size)
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
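The loop bounds in `crop_to_patches` mean that only full patches are emitted; any remainder at the right or bottom edge is silently dropped. The sketch below reproduces just the index arithmetic so the behavior is visible without loading images. `patch_origins` is a hypothetical helper, not part of the module.

```python
from typing import List, Tuple


def patch_origins(height: int, width: int, patch_size: int = 256) -> List[Tuple[int, int]]:
    """Top-left (y, x) corners visited by a non-overlapping full-patch grid."""
    return [
        (y, x)
        for y in range(0, height - patch_size + 1, patch_size)
        for x in range(0, width - patch_size + 1, patch_size)
    ]


print(len(patch_origins(1024, 1024)))  # 16 patches per 1024x1024 image
print(patch_origins(600, 600))         # [(0, 0), (0, 256), (256, 0), (256, 256)]
print(patch_origins(200, 200))         # [] because the image is smaller than one patch
```

For LEVIR-CD's 1024×1024 images nothing is lost, but a dataset with other dimensions would need padding or overlapping crops to keep the edge pixels.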
evaluate.py ADDED
@@ -0,0 +1,135 @@
1
+ """Evaluation script for change detection models.
2
+
3
+ Runs a trained model on the test set, computes all metrics, and generates
4
+ visualization outputs.
5
+
6
+ Usage:
7
+ python evaluate.py --config configs/config.yaml --checkpoint checkpoints/unet_pp_best.pth
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ from pathlib import Path
13
+ from typing import Any, Dict
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.data import DataLoader
18
+ from tqdm import tqdm
19
+ import yaml
20
+
21
+ from data.dataset import ChangeDetectionDataset
22
+ from models import get_model
23
+ from utils.metrics import ConfusionMatrix
24
+ from utils.visualization import plot_prediction
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ def evaluate(
30
+ model: nn.Module,
31
+ loader: DataLoader,
32
+ device: torch.device,
33
+ threshold: float = 0.5,
34
+ output_dir: Path = Path("./outputs"),
35
+ max_vis: int = 20,
36
+ ) -> Dict[str, float]:
37
+ """Evaluate model on the full test set.
38
+
39
+ Args:
40
+ model: Trained change detection model.
41
+ loader: Test DataLoader.
42
+ device: Target device.
43
+ threshold: Binarization threshold for predictions.
44
+ output_dir: Directory to save visualization outputs.
45
+ max_vis: Maximum number of sample predictions to save.
46
+
47
+ Returns:
48
+ Dict of metric name -> value.
49
+ """
50
+ model.eval()
51
+ cm = ConfusionMatrix()
52
+ vis_dir = output_dir / "visualizations"
53
+ vis_dir.mkdir(parents=True, exist_ok=True)
54
+ vis_count = 0
55
+
56
+ with torch.no_grad():
57
+ for batch in tqdm(loader, desc="Evaluating"):
58
+ img_a = batch["A"].to(device)
59
+ img_b = batch["B"].to(device)
60
+ mask = batch["mask"].to(device)
61
+
62
+ logits = model(img_a, img_b)
63
+ preds = (torch.sigmoid(logits) > threshold).float()
64
+ cm.update(preds, mask)
65
+
66
+ # Save sample visualizations
67
+ if vis_count < max_vis:
68
+ for i in range(min(img_a.size(0), max_vis - vis_count)):
69
+ plot_prediction(
70
+ img_a[i], img_b[i], mask[i], preds[i],
71
+ save_path=vis_dir / f"sample_{vis_count:04d}.png",
72
+ )
73
+ vis_count += 1
74
+
75
+ metrics = cm.compute()
76
+ return metrics
77
+
78
+
79
+ def main() -> None:
80
+ """Main evaluation entry point."""
81
+ parser = argparse.ArgumentParser(description="Evaluate change detection model")
82
+ parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
83
+ parser.add_argument("--checkpoint", type=Path, required=True)
84
+ parser.add_argument("--model", type=str, default=None, help="Override model name")
85
+ parser.add_argument("--threshold", type=float, default=None)
86
+ args = parser.parse_args()
87
+
88
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
89
+
90
+ with open(args.config, "r") as f:
91
+ config = yaml.safe_load(f)
92
+
93
+ model_name = args.model or config["model"]["name"]
94
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
+ threshold = args.threshold if args.threshold is not None else config.get("evaluation", {}).get("threshold", 0.5)
96
+
97
+ # Resolve paths
98
+ colab = config.get("colab", {})
99
+ if colab.get("enabled", False):
100
+ data_dir = Path(colab["data_dir"])
101
+ output_dir = Path(colab["output_dir"])
102
+ else:
103
+ data_dir = Path(config["paths"]["processed_data"])
104
+ output_dir = Path(config["paths"]["output_dir"])
105
+
106
+ # Model
107
+ model = get_model(model_name, config).to(device)
108
+ ckpt = torch.load(args.checkpoint, map_location=device)
109
+ model.load_state_dict(ckpt["model_state_dict"])
110
+ logger.info("Loaded checkpoint: %s (epoch %d, F1 %.4f)",
111
+ args.checkpoint, ckpt.get("epoch", -1), ckpt.get("best_f1", -1))
112
+
113
+ # Test data
114
+ ds_cfg = config.get("dataset", {})
115
+ test_ds = ChangeDetectionDataset(data_dir / "test", split="test", config=config)
116
+ test_loader = DataLoader(
117
+ test_ds, batch_size=8, shuffle=False,
118
+ num_workers=ds_cfg.get("num_workers", 4),
119
+ pin_memory=ds_cfg.get("pin_memory", True),
120
+ )
121
+
122
+ # Evaluate
123
+ metrics = evaluate(model, test_loader, device, threshold, output_dir)
124
+
125
+ # Print results
126
+ logger.info("=" * 50)
127
+ logger.info("TEST SET RESULTS — %s", model_name)
128
+ logger.info("=" * 50)
129
+ for name, value in metrics.items():
130
+ logger.info(" %-12s: %.4f", name.upper(), value)
131
+ logger.info("=" * 50)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ main()
inference.py ADDED
@@ -0,0 +1,176 @@
+ """Run inference on arbitrary before/after image pairs.
+
+ Loads a trained change detection model and produces binary change masks
+ for new satellite image pairs.
+
+ Usage:
+     python inference.py --before path/to/before.png --after path/to/after.png \
+         --model changeformer --checkpoint checkpoints/changeformer_best.pth
+ """
+
+ import argparse
+ import logging
+ from pathlib import Path
+ from typing import Any, Dict, Tuple
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import yaml
+
+ from data.dataset import IMAGENET_MEAN, IMAGENET_STD
+ from models import get_model
+ from utils.visualization import overlay_changes, plot_prediction
+
+ logger = logging.getLogger(__name__)
+
+
+ def preprocess_image(
+     image_path: Path,
+     patch_size: int = 256,
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
+     """Load and preprocess a single image for inference.
+
+     Reads the image, pads to a multiple of patch_size, and applies
+     ImageNet normalization.
+
+     Args:
+         image_path: Path to the input image.
+         patch_size: Patch size the model expects.
+
+     Returns:
+         Tuple of (preprocessed tensor [1, 3, H, W], original (H, W)).
+     """
+     img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
+     if img is None:
+         raise FileNotFoundError(f"Could not read image: {image_path}")
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     orig_h, orig_w = img.shape[:2]
+
+     # Pad the bottom and right edges up to the next multiple of patch_size
+     pad_h = (patch_size - orig_h % patch_size) % patch_size
+     pad_w = (patch_size - orig_w % patch_size) % patch_size
+     if pad_h > 0 or pad_w > 0:
+         img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
+
+     # Scale to [0, 1], then apply per-channel ImageNet normalization
+     img = img.astype(np.float32) / 255.0
+     mean = np.array(IMAGENET_MEAN, dtype=np.float32)
+     std = np.array(IMAGENET_STD, dtype=np.float32)
+     img = (img - mean) / std
+
+     # HWC -> CHW, add batch dim
+     tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()
+     return tensor, (orig_h, orig_w)
+
+
+ def sliding_window_inference(
+     model: nn.Module,
+     img_a: torch.Tensor,
+     img_b: torch.Tensor,
+     patch_size: int = 256,
+     device: torch.device = torch.device("cpu"),
+ ) -> torch.Tensor:
+     """Run inference using sliding window for large images.
+
+     Splits images into non-overlapping patches, runs model on each,
+     and stitches results back together.
+
+     Args:
+         model: Trained change detection model.
+         img_a: Before image tensor [1, 3, H, W].
+         img_b: After image tensor [1, 3, H, W].
+         patch_size: Size of each patch.
+         device: Inference device.
+
+     Returns:
+         Probability map [1, 1, H, W] (after sigmoid).
+     """
+     _, _, h, w = img_a.shape
+     # The stitched output stays on the CPU; only one patch pair is on the device at a time
+     output = torch.zeros(1, 1, h, w, device="cpu")
+
+     model.eval()
+     with torch.no_grad():
+         for y in range(0, h, patch_size):
+             for x in range(0, w, patch_size):
+                 patch_a = img_a[:, :, y:y + patch_size, x:x + patch_size].to(device)
+                 patch_b = img_b[:, :, y:y + patch_size, x:x + patch_size].to(device)
+
+                 logits = model(patch_a, patch_b)
+                 probs = torch.sigmoid(logits).cpu()
+                 output[:, :, y:y + patch_size, x:x + patch_size] = probs
+
+     return output
+
+
+ def save_change_mask(
+     mask: np.ndarray,
+     save_path: Path,
+     threshold: float = 0.5,
+ ) -> None:
+     """Save binary change mask as an image.
+
+     Args:
+         mask: Probability map [H, W] with values in [0, 1].
+         save_path: Output file path.
+         threshold: Binarization threshold.
+     """
+     binary = (mask > threshold).astype(np.uint8) * 255
+     save_path.parent.mkdir(parents=True, exist_ok=True)
+     cv2.imwrite(str(save_path), binary)
+     logger.info("Saved change mask: %s", save_path)
+
+
+ def main() -> None:
+     """Main inference entry point."""
+     parser = argparse.ArgumentParser(description="Run change detection inference")
+     parser.add_argument("--before", type=Path, required=True, help="Path to before image")
+     parser.add_argument("--after", type=Path, required=True, help="Path to after image")
+     parser.add_argument("--model", type=str, default=None, help="Model name")
+     parser.add_argument("--checkpoint", type=Path, required=True, help="Path to model checkpoint")
+     parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
+     parser.add_argument("--output", type=Path, default=Path("outputs/inference"))
+     parser.add_argument("--threshold", type=float, default=0.5)
+     args = parser.parse_args()
+
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
+     # Load config
+     with open(args.config, "r") as f:
+         config = yaml.safe_load(f)
+
+     model_name = args.model or config["model"]["name"]
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     patch_size = config.get("dataset", {}).get("patch_size", 256)
+
+     # Load model
+     model = get_model(model_name, config).to(device)
+     ckpt = torch.load(args.checkpoint, map_location=device)
+     model.load_state_dict(ckpt["model_state_dict"])
+     logger.info("Loaded model '%s' from %s", model_name, args.checkpoint)
+
+     # Preprocess images
+     img_a, (orig_h, orig_w) = preprocess_image(args.before, patch_size)
+     img_b, _ = preprocess_image(args.after, patch_size)
+
+     # Both tensors are sliced with the same patch coordinates, so their sizes must match
+     if img_a.shape != img_b.shape:
+         raise ValueError(
+             f"Before/after images differ in size: {tuple(img_a.shape)} vs {tuple(img_b.shape)}"
+         )
+
+     # Run inference
+     prob_map = sliding_window_inference(model, img_a, img_b, patch_size, device)
+
+     # Crop back to original size and save
+     prob_map = prob_map[:, :, :orig_h, :orig_w]
+     mask_np = prob_map.squeeze().numpy()
+
+     args.output.mkdir(parents=True, exist_ok=True)
+     save_change_mask(mask_np, args.output / "change_mask.png", args.threshold)
+
+     # Save overlay visualization
+     overlay = overlay_changes(img_b.squeeze()[:, :orig_h, :orig_w], prob_map.squeeze(0))
+     overlay_uint8 = (overlay * 255).astype(np.uint8)
+     cv2.imwrite(str(args.output / "overlay.png"), cv2.cvtColor(overlay_uint8, cv2.COLOR_RGB2BGR))
+     logger.info("Saved overlay: %s", args.output / "overlay.png")
+
+
+ if __name__ == "__main__":
+     main()
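Note on the padding and tiling arithmetic: `preprocess_image` pads each side up to the next multiple of `patch_size`, and `sliding_window_inference` then walks the padded image in non-overlapping steps. The sketch below reproduces only that arithmetic in plain Python so it can be checked without PyTorch or OpenCV. The helper names `pad_to_multiple` and `tile_origins` and the 1000×700 image size are hypothetical and are not part of the repository.

```python
from typing import List, Tuple


def pad_to_multiple(size: int, patch_size: int = 256) -> int:
    """Return the number of pixels to add so `size` becomes a multiple of `patch_size`."""
    # The outer modulo maps the "already aligned" case (patch_size) back to 0.
    return (patch_size - size % patch_size) % patch_size


def tile_origins(height: int, width: int, patch_size: int = 256) -> List[Tuple[int, int]]:
    """List the (y, x) top-left corners of non-overlapping patches, row by row."""
    return [
        (y, x)
        for y in range(0, height, patch_size)
        for x in range(0, width, patch_size)
    ]


if __name__ == "__main__":
    h, w = 1000, 700  # hypothetical input size, not taken from the dataset
    padded_h = h + pad_to_multiple(h)
    padded_w = w + pad_to_multiple(w)
    print("padded size:", padded_h, padded_w)
    print("patches:", len(tile_origins(padded_h, padded_w)))
```

Because the padded size is always an exact multiple of the patch size, every slice taken in the sliding window has the full `patch_size` extent, which is what the models' fixed-size sanity checks assume.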
models/__init__.py ADDED
@@ -0,0 +1,39 @@
+ """Model factory for change detection models.
+
+ Provides a unified interface to instantiate any supported model by name.
+ """
+
+ from typing import Any, Dict
+
+ import torch.nn as nn
+
+ from .changeformer import ChangeFormer
+ from .siamese_cnn import SiameseCNN
+ from .unet_pp import UNetPPChangeDetection
+
+ _MODEL_REGISTRY: Dict[str, type] = {
+     "siamese_cnn": SiameseCNN,
+     "unet_pp": UNetPPChangeDetection,
+     "changeformer": ChangeFormer,
+ }
+
+
+ def get_model(model_name: str, config: Dict[str, Any]) -> nn.Module:
+     """Instantiate a change detection model by name.
+
+     Args:
+         model_name: One of 'siamese_cnn', 'unet_pp', 'changeformer'.
+         config: Full config dict; model-specific section is extracted internally.
+
+     Returns:
+         Initialized model (nn.Module).
+
+     Raises:
+         ValueError: If model_name is not recognized.
+     """
+     if model_name not in _MODEL_REGISTRY:
+         raise ValueError(f"Unknown model '{model_name}'. Choose from: {list(_MODEL_REGISTRY.keys())}")
+
+     model_cls = _MODEL_REGISTRY[model_name]
+     model_config = config.get(model_name, {})
+     return model_cls(**model_config)
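Note on the factory: `get_model` looks up a class by name and unpacks the config section with the same key as constructor keyword arguments. The sketch below shows that pattern with two stand-in classes so it runs without PyTorch. `TinyModelA`, `TinyModelB`, `REGISTRY`, `build`, and the config values are hypothetical illustrations, not names from the repository.

```python
from typing import Any, Dict


class TinyModelA:
    """Stand-in for a real model class; takes one hyperparameter."""

    def __init__(self, width: int = 8) -> None:
        self.width = width


class TinyModelB:
    """Second stand-in with a different constructor signature."""

    def __init__(self, depth: int = 2, dropout: float = 0.0) -> None:
        self.depth = depth
        self.dropout = dropout


REGISTRY: Dict[str, type] = {"tiny_a": TinyModelA, "tiny_b": TinyModelB}


def build(name: str, config: Dict[str, Any]) -> Any:
    """Instantiate a registered class, using config[name] as keyword arguments."""
    if name not in REGISTRY:
        raise ValueError(f"Unknown model '{name}'. Choose from: {list(REGISTRY)}")
    # A missing section falls back to the constructor defaults.
    return REGISTRY[name](**config.get(name, {}))


if __name__ == "__main__":
    cfg = {"tiny_b": {"depth": 4}}  # hypothetical config section
    model = build("tiny_b", cfg)
    print(model.depth, model.dropout)
```

One consequence of unpacking the section directly: any key in the config section that the constructor does not accept raises a `TypeError`, so the YAML keys must match the constructor parameter names exactly.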
models/changeformer.py ADDED
@@ -0,0 +1,358 @@
+ """ChangeFormer: Transformer-based change detection model.
+
+ Implements a hierarchical vision transformer (MiT-B1 style) with shared-weight
+ Siamese encoder and MLP decoder for change detection. Based on:
+ "A Transformer-Based Siamese Network for Change Detection" (arXiv:2201.01293).
+ """
+
+ from typing import List, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+
+
+ class OverlapPatchEmbed(nn.Module):
+     """Overlapping patch embedding for hierarchical feature extraction.
+
+     Args:
+         in_channels: Number of input channels.
+         embed_dim: Embedding dimension.
+         patch_size: Patch size for convolution.
+         stride: Stride for convolution.
+     """
+
+     def __init__(
+         self,
+         in_channels: int = 3,
+         embed_dim: int = 64,
+         patch_size: int = 7,
+         stride: int = 4,
+     ) -> None:
+         super().__init__()
+         self.proj = nn.Conv2d(
+             in_channels, embed_dim,
+             kernel_size=patch_size, stride=stride,
+             padding=patch_size // 2,
+         )
+         self.norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
+         """Forward pass.
+
+         Args:
+             x: Input tensor [B, C, H, W].
+
+         Returns:
+             Tuple of (tokens [B, N, D], height, width).
+         """
+         x = self.proj(x)
+         _, _, h, w = x.shape
+         x = rearrange(x, "b c h w -> b (h w) c")
+         x = self.norm(x)
+         return x, h, w
+
+
+ class EfficientSelfAttention(nn.Module):
+     """Efficient self-attention with spatial reduction.
+
+     Args:
+         dim: Input dimension.
+         num_heads: Number of attention heads.
+         sr_ratio: Spatial reduction ratio.
+     """
+
+     def __init__(self, dim: int, num_heads: int = 1, sr_ratio: int = 8) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.scale = self.head_dim ** -0.5
+
+         self.q = nn.Linear(dim, dim)
+         self.kv = nn.Linear(dim, dim * 2)
+         self.proj = nn.Linear(dim, dim)
+
+         # Spatial reduction
+         self.sr_ratio = sr_ratio
+         if sr_ratio > 1:
+             self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+             self.sr_norm = nn.LayerNorm(dim)
+
+     def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tokens [B, N, C].
+             h: Feature map height.
+             w: Feature map width.
+
+         Returns:
+             Output tokens [B, N, C].
+         """
+         b, n, c = x.shape
+         q = self.q(x).reshape(b, n, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+         if self.sr_ratio > 1:
+             x_ = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+             x_ = self.sr(x_)
+             x_ = rearrange(x_, "b c h w -> b (h w) c")
+             x_ = self.sr_norm(x_)
+         else:
+             x_ = x
+
+         kv = self.kv(x_).reshape(b, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         k, v = kv[0], kv[1]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         out = (attn @ v).transpose(1, 2).reshape(b, n, c)
+         out = self.proj(out)
+         return out
+
+
+ class MixFFN(nn.Module):
+     """Mix Feed-Forward Network with depthwise convolution.
+
+     Args:
+         dim: Input/output dimension.
+         mlp_ratio: Expansion ratio for hidden dimension.
+     """
+
+     def __init__(self, dim: int, mlp_ratio: int = 4) -> None:
+         super().__init__()
+         hidden = dim * mlp_ratio
+         self.fc1 = nn.Linear(dim, hidden)
+         self.dwconv = nn.Conv2d(hidden, hidden, 3, 1, 1, groups=hidden)
+         self.fc2 = nn.Linear(hidden, dim)
+         self.act = nn.GELU()
+
+     def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tokens [B, N, C].
+             h: Feature map height.
+             w: Feature map width.
+
+         Returns:
+             Output tokens [B, N, C].
+         """
+         x = self.fc1(x)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+         x = self.act(self.dwconv(x))
+         x = rearrange(x, "b c h w -> b (h w) c")
+         x = self.fc2(x)
+         return x
+
+
+ class TransformerBlock(nn.Module):
+     """Single transformer block with efficient attention and MixFFN.
+
+     Args:
+         dim: Feature dimension.
+         num_heads: Number of attention heads.
+         mlp_ratio: MLP expansion ratio.
+         sr_ratio: Spatial reduction ratio for attention.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 1,
+         mlp_ratio: int = 4,
+         sr_ratio: int = 8,
+     ) -> None:
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = EfficientSelfAttention(dim, num_heads, sr_ratio)
+         self.norm2 = nn.LayerNorm(dim)
+         self.ffn = MixFFN(dim, mlp_ratio)
+
+     def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x: Input tokens [B, N, C].
+             h: Feature map height.
+             w: Feature map width.
+
+         Returns:
+             Output tokens [B, N, C].
+         """
+         x = x + self.attn(self.norm1(x), h, w)
+         x = x + self.ffn(self.norm2(x), h, w)
+         return x
+
+
+ class MiTEncoder(nn.Module):
+     """Mix Transformer (MiT) encoder: hierarchical vision transformer.
+
+     Args:
+         embed_dims: Embedding dimensions at each stage.
+         num_heads: Number of attention heads at each stage.
+         mlp_ratios: MLP expansion ratios at each stage.
+         depths: Number of transformer blocks at each stage.
+     """
+
+     def __init__(
+         self,
+         embed_dims: List[int] = [64, 128, 320, 512],
+         num_heads: List[int] = [1, 2, 5, 8],
+         mlp_ratios: List[int] = [8, 8, 4, 4],
+         depths: List[int] = [2, 2, 2, 2],
+     ) -> None:
+         super().__init__()
+         self.num_stages = len(embed_dims)
+
+         sr_ratios = [8, 4, 2, 1]
+         patch_sizes = [7, 3, 3, 3]
+         strides = [4, 2, 2, 2]
+
+         self.patch_embeds = nn.ModuleList()
+         self.blocks = nn.ModuleList()
+         self.norms = nn.ModuleList()
+
+         for i in range(self.num_stages):
+             in_ch = 3 if i == 0 else embed_dims[i - 1]
+             self.patch_embeds.append(
+                 OverlapPatchEmbed(in_ch, embed_dims[i], patch_sizes[i], strides[i])
+             )
+             self.blocks.append(
+                 nn.ModuleList([
+                     TransformerBlock(embed_dims[i], num_heads[i], mlp_ratios[i], sr_ratios[i])
+                     for _ in range(depths[i])
+                 ])
+             )
+             self.norms.append(nn.LayerNorm(embed_dims[i]))
+
+     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+         """Extract hierarchical features.
+
+         Args:
+             x: Input image [B, 3, H, W].
+
+         Returns:
+             List of feature maps at each stage [B, C_i, H_i, W_i].
+         """
+         features = []
+         for i in range(self.num_stages):
+             x, h, w = self.patch_embeds[i](x)
+             for blk in self.blocks[i]:
+                 x = blk(x, h, w)
+             x = self.norms[i](x)
+             x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+             features.append(x)
+         return features
+
+
+ class MLPDecoder(nn.Module):
+     """MLP-based decoder that fuses multi-scale difference features.
+
+     Args:
+         embed_dims: Embedding dimensions from each encoder stage.
+         out_channels: Number of output channels (1 for binary change mask).
+     """
+
+     def __init__(
+         self,
+         embed_dims: List[int] = [64, 128, 320, 512],
+         out_channels: int = 1,
+     ) -> None:
+         super().__init__()
+         unified_dim = embed_dims[0]
+
+         self.linear_projections = nn.ModuleList([
+             nn.Conv2d(dim, unified_dim, kernel_size=1)
+             for dim in embed_dims
+         ])
+
+         self.fuse = nn.Sequential(
+             nn.Conv2d(unified_dim * len(embed_dims), unified_dim, kernel_size=1),
+             nn.BatchNorm2d(unified_dim),
+             nn.ReLU(inplace=True),
+         )
+         self.head = nn.Conv2d(unified_dim, out_channels, kernel_size=1)
+
+     def forward(self, features: List[torch.Tensor], target_size: Tuple[int, int]) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             features: List of difference feature maps.
+             target_size: (H, W) of the desired output.
+
+         Returns:
+             Logits [B, 1, H, W].
+         """
+         projected = []
+         for i, (feat, proj) in enumerate(zip(features, self.linear_projections)):
+             p = proj(feat)
+             p = F.interpolate(p, size=target_size, mode="bilinear", align_corners=False)
+             projected.append(p)
+
+         fused = self.fuse(torch.cat(projected, dim=1))
+         out = self.head(fused)
+         return out
+
+
+ class ChangeFormer(nn.Module):
+     """ChangeFormer: Transformer-based Siamese network for change detection.
+
+     Args:
+         embed_dims: Embedding dims at each hierarchical stage.
+         num_heads: Attention heads at each stage.
+         mlp_ratios: MLP expansion ratios at each stage.
+         depths: Transformer block counts at each stage.
+         pretrained_backbone: Whether to load pretrained MiT weights.
+     """
+
+     def __init__(
+         self,
+         embed_dims: List[int] = [64, 128, 320, 512],
+         num_heads: List[int] = [1, 2, 5, 8],
+         mlp_ratios: List[int] = [8, 8, 4, 4],
+         depths: List[int] = [2, 2, 2, 2],
+         pretrained_backbone: bool = True,
+     ) -> None:
+         super().__init__()
+
+         # Shared Siamese encoder
+         self.encoder = MiTEncoder(embed_dims, num_heads, mlp_ratios, depths)
+
+         # MLP decoder
+         self.decoder = MLPDecoder(embed_dims, out_channels=1)
+
+         # TODO: Load pretrained MiT-B1 weights if pretrained_backbone is True
+
+     def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x1: Before image [B, 3, 256, 256].
+             x2: After image [B, 3, 256, 256].
+
+         Returns:
+             Raw logits [B, 1, 256, 256].
+         """
+         # Extract hierarchical features
+         feats_1 = self.encoder(x1)
+         feats_2 = self.encoder(x2)
+
+         # Compute difference at each scale
+         diff_feats = [torch.abs(f1 - f2) for f1, f2 in zip(feats_1, feats_2)]
+
+         # Decode to change mask
+         target_size = (x1.shape[2], x1.shape[3])
+         out = self.decoder(diff_feats, target_size)
+         return out
+
+
+ if __name__ == "__main__":
+     # Quick sanity check
+     model = ChangeFormer(pretrained_backbone=False)
+     x1 = torch.randn(1, 3, 256, 256)
+     x2 = torch.randn(1, 3, 256, 256)
+     out = model(x1, x2)
+     print(f"Input: {x1.shape}, Output: {out.shape}")
+     assert out.shape == (1, 1, 256, 256), f"Unexpected shape: {out.shape}"
+     print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
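Note on feature-map sizes: the encoder's four stages use patch sizes `[7, 3, 3, 3]`, strides `[4, 2, 2, 2]`, padding `patch_size // 2`, and spatial-reduction ratios `[8, 4, 2, 1]`. The sketch below applies the standard convolution output-size formula to those values to show how many query tokens and how many reduced key/value tokens each stage sees. It is plain Python and only mirrors the arithmetic; `conv_out` and `stage_shapes` are hypothetical helper names that do not exist in the repository.

```python
from typing import List, Sequence, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output side length of a 2D convolution (dilation 1) along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_shapes(
    input_size: int = 256,
    patch_sizes: Sequence[int] = (7, 3, 3, 3),
    strides: Sequence[int] = (4, 2, 2, 2),
    sr_ratios: Sequence[int] = (8, 4, 2, 1),
) -> List[Tuple[int, int, int]]:
    """Return (feature_side, query_tokens, key_tokens) for each encoder stage."""
    shapes = []
    size = input_size
    for kernel, stride, sr in zip(patch_sizes, strides, sr_ratios):
        # Overlapping patch embedding: padding is kernel // 2, as in OverlapPatchEmbed.
        size = conv_out(size, kernel, stride, kernel // 2)
        # Spatial reduction: an unpadded conv with kernel == stride == sr_ratio.
        kv_side = conv_out(size, sr, sr, 0) if sr > 1 else size
        shapes.append((size, size * size, kv_side * kv_side))
    return shapes


if __name__ == "__main__":
    for side, n_query, n_key in stage_shapes():
        print(f"side={side:3d}  queries={n_query:5d}  keys={n_key:3d}")
```

For a 256×256 input the feature side halves from 64 down to 8 while the key/value sequence stays at 64 tokens in every stage, which is why the attention cost stays manageable at the high-resolution early stages.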
models/siamese_cnn.py ADDED
@@ -0,0 +1,85 @@
+ """Siamese CNN baseline for change detection.
+
+ Architecture: Shared-weight ResNet18 backbone extracts features from both
+ images. Feature difference is decoded via transposed convolutions to produce
+ a binary change mask.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+
+
+ class SiameseCNN(nn.Module):
+     """Siamese CNN with shared ResNet18 encoder and transposed-conv decoder.
+
+     Args:
+         backbone: Backbone architecture name (default: 'resnet18').
+         pretrained: Whether to use ImageNet-pretrained weights.
+     """
+
+     def __init__(self, backbone: str = "resnet18", pretrained: bool = True) -> None:
+         super().__init__()
+
+         # Shared encoder. weights="DEFAULT" lets torchvision resolve the pretrained
+         # weights for the named backbone instead of always passing ResNet18 weights.
+         # The decoder below expects 512 encoder channels (resnet18 or resnet34).
+         resnet = getattr(models, backbone)(
+             weights="DEFAULT" if pretrained else None
+         )
+         # Remove avgpool and fc; keep feature extraction layers
+         self.encoder = nn.Sequential(
+             resnet.conv1,
+             resnet.bn1,
+             resnet.relu,
+             resnet.maxpool,
+             resnet.layer1,  # 64 channels
+             resnet.layer2,  # 128 channels
+             resnet.layer3,  # 256 channels
+             resnet.layer4,  # 512 channels
+         )
+
+         # Decoder: upsample difference features back to input resolution
+         self.decoder = nn.Sequential(
+             nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(256),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ReLU(inplace=True),
+             nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
+         )
+
+     def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x1: Before image [B, 3, 256, 256].
+             x2: After image [B, 3, 256, 256].
+
+         Returns:
+             Raw logits [B, 1, 256, 256].
+         """
+         f1 = self.encoder(x1)
+         f2 = self.encoder(x2)
+
+         # Feature difference
+         diff = torch.abs(f1 - f2)
+
+         # Decode to change mask
+         out = self.decoder(diff)
+         return out
+
+
+ if __name__ == "__main__":
+     # Quick sanity check
+     model = SiameseCNN(pretrained=False)
+     x1 = torch.randn(2, 3, 256, 256)
+     x2 = torch.randn(2, 3, 256, 256)
+     out = model(x1, x2)
+     print(f"Input: {x1.shape}, Output: {out.shape}")
+     assert out.shape == (2, 1, 256, 256), f"Unexpected shape: {out.shape}"
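Note on the decoder's upsampling: the ResNet18 encoder reduces a 256×256 input by a factor of 32, and the decoder stacks five transposed convolutions with `kernel_size=4, stride=2, padding=1`. The sketch below uses the standard transposed-convolution output-size formula (with no output padding and dilation 1) to show that each such layer exactly doubles the side length. `conv_transpose_out` and `decoder_sizes` are hypothetical helper names, not part of the repository.

```python
from typing import List


def conv_transpose_out(size: int, kernel: int = 4, stride: int = 2, padding: int = 1) -> int:
    """Output side length of a transposed convolution (output_padding=0, dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel


def decoder_sizes(encoder_side: int, num_layers: int = 5) -> List[int]:
    """Side length after each of `num_layers` stacked transposed convolutions."""
    sizes = []
    side = encoder_side
    for _ in range(num_layers):
        side = conv_transpose_out(side)
        sizes.append(side)
    return sizes


if __name__ == "__main__":
    encoder_side = 256 // 32  # total stride of the ResNet18 feature extractor
    print(decoder_sizes(encoder_side))
```

With these settings the formula simplifies to `2 * size`, so five layers undo the 32× downsampling and the logits match the input resolution, which is what the sanity check in `siamese_cnn.py` asserts.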
models/unet_pp.py ADDED
@@ -0,0 +1,78 @@
+ """UNet++ (Nested U-Net) for change detection.
+
+ Uses a shared ResNet34 encoder from segmentation-models-pytorch. Features from
+ both temporal images are differenced and decoded through nested skip connections.
+ The deep_supervision flag is accepted and stored but is not yet used in forward().
+ """
+
+ import torch
+ import torch.nn as nn
+ import segmentation_models_pytorch as smp
+
+
+ class UNetPPChangeDetection(nn.Module):
+     """UNet++ adapted for bitemporal change detection.
+
+     A shared encoder processes both images. The absolute difference of
+     encoder features is fed into the UNet++ decoder.
+
+     Args:
+         encoder_name: Encoder backbone (default: 'resnet34').
+         pretrained: Use ImageNet-pretrained encoder weights.
+         deep_supervision: Stored for future use; currently has no effect on the output.
+     """
+
+     def __init__(
+         self,
+         encoder_name: str = "resnet34",
+         pretrained: bool = True,
+         deep_supervision: bool = False,
+     ) -> None:
+         super().__init__()
+         self.deep_supervision = deep_supervision
+
+         # Shared encoder via SMP
+         encoder_weights = "imagenet" if pretrained else None
+         self.base_model = smp.UnetPlusPlus(
+             encoder_name=encoder_name,
+             encoder_weights=encoder_weights,
+             in_channels=3,
+             classes=1,
+         )
+
+         # We'll use the encoder and decoder separately
+         self.encoder = self.base_model.encoder
+         self.decoder = self.base_model.decoder
+         self.segmentation_head = self.base_model.segmentation_head
+
+     def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             x1: Before image [B, 3, 256, 256].
+             x2: After image [B, 3, 256, 256].
+
+         Returns:
+             Raw logits [B, 1, 256, 256].
+         """
+         # Extract multi-scale features from both images
+         features_1 = self.encoder(x1)
+         features_2 = self.encoder(x2)
+
+         # Compute absolute difference at each scale
+         diff_features = [torch.abs(f1 - f2) for f1, f2 in zip(features_1, features_2)]
+
+         # Decode (the decoder in the pinned SMP 0.3.3 takes the features as positional args)
+         decoder_output = self.decoder(*diff_features)
+         out = self.segmentation_head(decoder_output)
+         return out
+
+
+ if __name__ == "__main__":
+     # Quick sanity check
+     model = UNetPPChangeDetection(pretrained=False)
+     x1 = torch.randn(2, 3, 256, 256)
+     x2 = torch.randn(2, 3, 256, 256)
+     out = model(x1, x2)
+     print(f"Input: {x1.shape}, Output: {out.shape}")
+     assert out.shape == (2, 1, 256, 256), f"Unexpected shape: {out.shape}"
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ torch==2.1.2
+ torchvision==0.16.2
+ segmentation-models-pytorch==0.3.3
+ timm==0.9.12
+ einops==0.7.0
+ albumentations==1.3.1
+ opencv-python-headless==4.9.0.80
+ scikit-learn==1.4.0
+ matplotlib==3.8.2
+ numpy==1.26.3
+ Pillow==10.2.0
+ PyYAML==6.0.1
+ tqdm==4.66.1
+ tensorboard==2.15.1
+ gradio==4.14.0
+ gdown==5.1.0
setup_colab.py ADDED
@@ -0,0 +1,172 @@
+ """Google Colab setup script.
+
+ Handles Drive mounting, GPU verification, dependency installation,
+ and path configuration. Run this at the start of every Colab session.
+
+ Usage (in Colab cell):
+     !python setup_colab.py
+     # Or import and call:
+     from setup_colab import setup
+     setup()
+ """
+
+ import logging
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+ from typing import Dict, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ def mount_drive() -> None:
+     """Mount Google Drive at /content/drive.
+
+     Skips if not running in Colab or already mounted.
+     """
+     if not is_colab():
+         logger.info("Not running in Colab; skipping Drive mount.")
+         return
+
+     if Path("/content/drive/MyDrive").exists():
+         logger.info("Google Drive already mounted.")
+         return
+
+     from google.colab import drive
+     drive.mount("/content/drive")
+     logger.info("Google Drive mounted successfully.")
+
+
+ def is_colab() -> bool:
+     """Check if running inside Google Colab.
+
+     Returns:
+         True if running in Colab environment.
+     """
+     try:
+         import google.colab  # noqa: F401
+         return True
+     except ImportError:
+         return False
+
+
+ def check_gpu() -> Optional[str]:
+     """Check GPU availability and print device info.
+
+     Returns:
+         GPU name string, or None if no GPU available.
+     """
+     import torch
+
+     if not torch.cuda.is_available():
+         logger.warning("No GPU detected! Training will be very slow.")
+         return None
+
+     gpu_name = torch.cuda.get_device_name(0)
+     # The device properties attribute is `total_memory` (bytes), not `total_mem`
+     vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+     logger.info("GPU: %s (%.1f GB VRAM)", gpu_name, vram_gb)
+     return gpu_name
+
+
+ def detect_gpu_type() -> str:
+     """Detect GPU type for batch size selection.
+
+     Returns:
+         One of 'T4', 'V100', or 'default'.
+     """
+     import torch
+
+     if not torch.cuda.is_available():
+         return "default"
+
+     name = torch.cuda.get_device_name(0).upper()
+     if "T4" in name:
+         return "T4"
+     elif "V100" in name:
+         return "V100"
+     return "default"
+
+
+ def install_requirements() -> None:
+     """Install project dependencies from requirements.txt."""
+     req_path = Path("requirements.txt")
+     if not req_path.exists():
+         logger.warning("requirements.txt not found in %s", Path.cwd())
+         return
+
+     logger.info("Installing dependencies...")
+     subprocess.check_call([
+         sys.executable, "-m", "pip", "install", "-q", "-r", str(req_path)
+     ])
+     logger.info("Dependencies installed.")
+
+
+ def create_drive_dirs(drive_root: str = "/content/drive/MyDrive/change-detection") -> Dict[str, Path]:
+     """Create project directories on Google Drive.
+
+     Args:
+         drive_root: Root directory on Drive for this project.
+
+     Returns:
+         Dict mapping directory names to their paths.
+     """
+     dirs = {
+         "root": Path(drive_root),
+         "checkpoints": Path(drive_root) / "checkpoints",
+         "logs": Path(drive_root) / "logs",
+         "outputs": Path(drive_root) / "outputs",
+         "data": Path(drive_root) / "processed_data",
+     }
+
+     for name, path in dirs.items():
+         path.mkdir(parents=True, exist_ok=True)
+         logger.info("  %s: %s", name, path)
+
+     return dirs
+
+
+ def setup(
+     drive_root: str = "/content/drive/MyDrive/change-detection",
+     install_deps: bool = True,
+ ) -> Dict[str, Path]:
+     """Run full Colab setup.
+
+     Args:
+         drive_root: Root directory on Google Drive.
+         install_deps: Whether to install pip dependencies.
+
+     Returns:
+         Dict of project directory paths.
+     """
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+
+     logger.info("=" * 60)
+     logger.info("Military Base Change Detection: Colab Setup")
+     logger.info("=" * 60)
+
+     # 1. Mount Drive
+     mount_drive()
+
+     # 2. Check GPU
+     gpu_name = check_gpu()
+     gpu_type = detect_gpu_type()
+     logger.info("GPU type for batch sizing: %s", gpu_type)
+
+     # 3. Install dependencies
+     if install_deps:
+         install_requirements()
+
+     # 4. Create Drive directories
+     logger.info("Creating project directories on Drive...")
+     dirs = create_drive_dirs(drive_root)
+
+     logger.info("=" * 60)
+     logger.info("Setup complete! Ready to train.")
+     logger.info("=" * 60)
+
+     return dirs
+
+
+ if __name__ == "__main__":
+     setup()
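Note on GPU-based batch sizing: `detect_gpu_type` reduces the CUDA device name to one of `'T4'`, `'V100'`, or `'default'`, and that label is then used to look up a per-model batch size with a fallback. The sketch below separates the string matching and the lookup from the CUDA calls so the logic can be exercised on any machine. `classify_gpu`, `pick_batch_size`, the device-name strings, and the numbers in the example config are hypothetical; the real values live in the project's config file, which is not shown here.

```python
from typing import Any, Dict


def classify_gpu(device_name: str) -> str:
    """Map a CUDA device name to 'T4', 'V100', or 'default' by substring match."""
    name = device_name.upper()
    if "T4" in name:
        return "T4"
    if "V100" in name:
        return "V100"
    return "default"


def pick_batch_size(config: Dict[str, Any], model_name: str, gpu_type: str) -> int:
    """Look up config['batch_sizes'][model_name][gpu_type] with layered fallbacks."""
    per_model = config.get("batch_sizes", {}).get(model_name, {})
    # Fall back to the model's 'default' entry, then to a hard-coded 4.
    return per_model.get(gpu_type, per_model.get("default", 4))


if __name__ == "__main__":
    # Hypothetical batch sizes, for illustration only.
    cfg = {"batch_sizes": {"unet_pp": {"T4": 8, "default": 4}}}
    gpu = classify_gpu("Tesla T4")
    print(gpu, pick_batch_size(cfg, "unet_pp", gpu))
```

Keeping the classification as a pure function of the device name makes the fallback behavior easy to test: an unrecognized GPU or a model with no `batch_sizes` entry both end up at the conservative default.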
train.py ADDED
@@ -0,0 +1,418 @@
+ """Main training script for change detection models.
+
+ Supports AMP, gradient clipping, early stopping, checkpoint saving to Google
+ Drive, and resume from checkpoint after Colab disconnects.
+
+ Usage:
+     python train.py --config configs/config.yaml --model unet_pp
+     python train.py --config configs/config.yaml --model changeformer --resume checkpoints/changeformer_last.pth
+ """
+
+ import argparse
+ import logging
+ import random
+ from pathlib import Path
+ from typing import Any, Dict, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.cuda.amp import GradScaler, autocast
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+ import yaml
+
+ from data.dataset import ChangeDetectionDataset
+ from models import get_model
+ from utils.losses import get_loss
+ from utils.metrics import ConfusionMatrix
+ from utils.visualization import plot_prediction
+
+ logger = logging.getLogger(__name__)
+
+
+ def set_seed(seed: int) -> None:
+     """Set random seeds for reproducibility.
+
+     Args:
+         seed: Random seed value.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def detect_gpu_type() -> str:
+     """Detect the current GPU type for batch size selection.
+
+     Returns:
+         GPU type string ('T4', 'V100', or 'default').
+     """
+     if not torch.cuda.is_available():
+         return "default"
+     name = torch.cuda.get_device_name(0).upper()
+     if "T4" in name:
+         return "T4"
+     elif "V100" in name:
+         return "V100"
+     return "default"
+
+
+ def get_batch_size(config: Dict[str, Any], model_name: str) -> int:
+     """Get appropriate batch size based on GPU and model.
+
+     Args:
+         config: Full config dict.
+         model_name: Model name string.
+
+     Returns:
+         Batch size integer.
+     """
+     gpu_type = detect_gpu_type()
+     batch_sizes = config.get("batch_sizes", {}).get(model_name, {})
+     return batch_sizes.get(gpu_type, batch_sizes.get("default", 4))
+
+
+ def get_paths(config: Dict[str, Any]) -> Dict[str, Path]:
+     """Resolve paths based on whether running on Colab or locally.
+
+     Args:
+         config: Full config dict.
+
+     Returns:
+         Dict with keys: 'data', 'checkpoints', 'logs', 'outputs'.
+     """
+     if config.get("colab", {}).get("enabled", False):
+         colab = config["colab"]
+         return {
+             "data": Path(colab["data_dir"]),
+             "checkpoints": Path(colab["checkpoint_dir"]),
+             "logs": Path(colab["log_dir"]),
+             "outputs": Path(colab["output_dir"]),
+         }
+     else:
+         paths = config.get("paths", {})
+         return {
+             "data": Path(paths.get("processed_data", "./processed_data")),
+             "checkpoints": Path(paths.get("checkpoint_dir", "./checkpoints")),
+             "logs": Path(paths.get("log_dir", "./logs")),
+             "outputs": Path(paths.get("output_dir", "./outputs")),
+         }
+
+
+ def build_dataloaders(
+     config: Dict[str, Any],
+     data_dir: Path,
+     batch_size: int,
+ ) -> Tuple[DataLoader, DataLoader]:
+     """Create train and validation DataLoaders.
+
+     Args:
+         config: Full config dict.
+         data_dir: Path to processed dataset root.
+         batch_size: Batch size.
+
+     Returns:
+         Tuple of (train_loader, val_loader).
+     """
+     ds_cfg = config.get("dataset", {})
+     num_workers = ds_cfg.get("num_workers", 4)
+     pin_memory = ds_cfg.get("pin_memory", True)
+
+     train_ds = ChangeDetectionDataset(data_dir / "train", split="train", config=config)
+     val_ds = ChangeDetectionDataset(data_dir / "val", split="val", config=config)
+
+     train_loader = DataLoader(
+         train_ds, batch_size=batch_size, shuffle=True,
+         num_workers=num_workers, pin_memory=pin_memory, drop_last=True,
+     )
+     val_loader = DataLoader(
+         val_ds, batch_size=batch_size, shuffle=False,
+         num_workers=num_workers, pin_memory=pin_memory,
+     )
+     return train_loader, val_loader
+
+
+ def train_one_epoch(
+     model: nn.Module,
+     loader: DataLoader,
+     criterion: nn.Module,
+     optimizer: torch.optim.Optimizer,
+     scaler: GradScaler,
148
+ device: torch.device,
149
+ config: Dict[str, Any],
150
+ ) -> Tuple[float, Dict[str, float]]:
151
+ """Run one training epoch.
152
+
153
+ Args:
154
+ model: The change detection model.
155
+ loader: Training DataLoader.
156
+ criterion: Loss function.
157
+ optimizer: Optimizer.
158
+ scaler: GradScaler for AMP.
159
+ device: Target device.
160
+ config: Full config dict.
161
+
162
+ Returns:
163
+ Tuple of (average loss, metrics dict).
164
+ """
165
+ model.train()
166
+ running_loss = 0.0
167
+ cm = ConfusionMatrix()
168
+ train_cfg = config.get("training", {})
169
+ accum_steps = train_cfg.get("gradient_accumulation_steps", 1)
170
+ grad_clip = train_cfg.get("grad_clip_max_norm", 1.0)
171
+ threshold = config.get("evaluation", {}).get("threshold", 0.5)
172
+
173
+ optimizer.zero_grad()
174
+
175
+ for step, batch in enumerate(tqdm(loader, desc="Train", leave=False)):
176
+ img_a = batch["A"].to(device)
177
+ img_b = batch["B"].to(device)
178
+ mask = batch["mask"].to(device)
179
+
180
+ with autocast():
181
+ logits = model(img_a, img_b)
182
+ loss = criterion(logits, mask) / accum_steps
183
+
184
+ scaler.scale(loss).backward()
185
+
186
+ if (step + 1) % accum_steps == 0 or (step + 1) == len(loader):
187
+ scaler.unscale_(optimizer)
188
+ nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
189
+ scaler.step(optimizer)
190
+ scaler.update()
191
+ optimizer.zero_grad()
192
+
193
+ running_loss += loss.item() * accum_steps
194
+
195
+ # Metrics
196
+ with torch.no_grad():
197
+ preds = (torch.sigmoid(logits) > threshold).float()
198
+ cm.update(preds, mask)
199
+
200
+ avg_loss = running_loss / len(loader)
201
+ metrics = cm.compute()
202
+ return avg_loss, metrics
203
+
204
+
205
+ @torch.no_grad()
206
+ def validate(
207
+ model: nn.Module,
208
+ loader: DataLoader,
209
+ criterion: nn.Module,
210
+ device: torch.device,
211
+ threshold: float = 0.5,
212
+ ) -> Tuple[float, Dict[str, float]]:
213
+ """Run validation.
214
+
215
+ Args:
216
+ model: The change detection model.
217
+ loader: Validation DataLoader.
218
+ criterion: Loss function.
219
+ device: Target device.
220
+ threshold: Binarization threshold.
221
+
222
+ Returns:
223
+ Tuple of (average loss, metrics dict).
224
+ """
225
+ model.eval()
226
+ running_loss = 0.0
227
+ cm = ConfusionMatrix()
228
+
229
+ for batch in tqdm(loader, desc="Val", leave=False):
230
+ img_a = batch["A"].to(device)
231
+ img_b = batch["B"].to(device)
232
+ mask = batch["mask"].to(device)
233
+
234
+ logits = model(img_a, img_b)
235
+ loss = criterion(logits, mask)
236
+ running_loss += loss.item()
237
+
238
+ preds = (torch.sigmoid(logits) > threshold).float()
239
+ cm.update(preds, mask)
240
+
241
+ avg_loss = running_loss / len(loader)
242
+ metrics = cm.compute()
243
+ return avg_loss, metrics
244
+
245
+
246
+ def save_checkpoint(
247
+ model: nn.Module,
248
+ optimizer: torch.optim.Optimizer,
249
+ scheduler: Any,
250
+ scaler: GradScaler,
251
+ epoch: int,
252
+ best_f1: float,
253
+ save_path: Path,
254
+ ) -> None:
255
+ """Save a training checkpoint.
256
+
257
+ Args:
258
+ model: Model to save.
259
+ optimizer: Optimizer state.
260
+ scheduler: LR scheduler state.
261
+ scaler: GradScaler state.
262
+ epoch: Current epoch number.
263
+ best_f1: Best validation F1 so far.
264
+ save_path: Path to save the checkpoint.
265
+ """
266
+ save_path.parent.mkdir(parents=True, exist_ok=True)
267
+ torch.save({
268
+ "epoch": epoch,
269
+ "model_state_dict": model.state_dict(),
270
+ "optimizer_state_dict": optimizer.state_dict(),
271
+ "scheduler_state_dict": scheduler.state_dict(),
272
+ "scaler_state_dict": scaler.state_dict(),
273
+ "best_f1": best_f1,
274
+ }, save_path)
275
+ logger.info("Saved checkpoint: %s", save_path)
276
+
277
+
278
+ def load_checkpoint(
279
+ path: Path,
280
+ model: nn.Module,
281
+ optimizer: torch.optim.Optimizer,
282
+ scheduler: Any,
283
+ scaler: GradScaler,
284
+ device: torch.device,
285
+ ) -> Tuple[int, float]:
286
+ """Load a training checkpoint for resume.
287
+
288
+ Args:
289
+ path: Path to the checkpoint file.
290
+ model: Model to load weights into.
291
+ optimizer: Optimizer to load state into.
292
+ scheduler: Scheduler to load state into.
293
+ scaler: GradScaler to load state into.
294
+ device: Target device.
295
+
296
+ Returns:
297
+ Tuple of (start_epoch, best_f1).
298
+ """
299
+ ckpt = torch.load(path, map_location=device)
300
+ model.load_state_dict(ckpt["model_state_dict"])
301
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
302
+ scheduler.load_state_dict(ckpt["scheduler_state_dict"])
303
+ scaler.load_state_dict(ckpt["scaler_state_dict"])
304
+ logger.info("Resumed from epoch %d (best F1: %.4f)", ckpt["epoch"], ckpt["best_f1"])
305
+ return ckpt["epoch"], ckpt["best_f1"]
306
+
307
+
308
+ def main() -> None:
309
+ """Main training entry point."""
310
+ parser = argparse.ArgumentParser(description="Train change detection model")
311
+ parser.add_argument("--config", type=Path, default=Path("configs/config.yaml"))
312
+ parser.add_argument("--model", type=str, default=None, help="Override model name from config")
313
+ parser.add_argument("--resume", type=Path, default=None, help="Path to checkpoint for resume")
314
+ args = parser.parse_args()
315
+
316
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
317
+
318
+ # Load config
319
+ with open(args.config, "r") as f:
320
+ config = yaml.safe_load(f)
321
+
322
+ model_name = args.model or config["model"]["name"]
323
+ seed = config.get("project", {}).get("seed", 42)
324
+ set_seed(seed)
325
+
326
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
327
+ logger.info("Device: %s", device)
328
+
329
+ # Resolve paths
330
+ paths = get_paths(config)
331
+ for p in paths.values():
332
+ p.mkdir(parents=True, exist_ok=True)
333
+
334
+ # Model
335
+ model = get_model(model_name, config).to(device)
336
+ logger.info("Model: %s (%.1fM params)", model_name,
337
+ sum(p.numel() for p in model.parameters()) / 1e6)
338
+
339
+ # Data
340
+ batch_size = get_batch_size(config, model_name)
341
+ train_loader, val_loader = build_dataloaders(config, paths["data"], batch_size)
342
+
343
+ # Loss, optimizer, scheduler
344
+ criterion = get_loss(config)
345
+ lr = config.get("learning_rates", {}).get(model_name, config["training"]["learning_rate"])
346
+ epochs = config.get("epoch_counts", {}).get(model_name, config["training"]["epochs"])
347
+
348
+ optimizer = AdamW(model.parameters(), lr=lr, weight_decay=config["training"]["weight_decay"])
349
+ scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
350
+ scaler = GradScaler()
351
+
352
+ # TensorBoard
353
+ writer = SummaryWriter(log_dir=str(paths["logs"] / model_name))
354
+
355
+ # Resume
356
+ start_epoch = 0
357
+ best_f1 = 0.0
358
+ if args.resume and args.resume.exists():
359
+ start_epoch, best_f1 = load_checkpoint(
360
+ args.resume, model, optimizer, scheduler, scaler, device
361
+ )
362
+
363
+ # Early stopping state
364
+ es_cfg = config["training"]["early_stopping"]
365
+ patience = es_cfg.get("patience", 15)
366
+ patience_counter = 0
367
+ threshold = config.get("evaluation", {}).get("threshold", 0.5)
368
+
369
+ # Training loop
370
+ for epoch in range(start_epoch, epochs):
371
+ logger.info("Epoch %d/%d", epoch + 1, epochs)
372
+
373
+ train_loss, train_metrics = train_one_epoch(
374
+ model, train_loader, criterion, optimizer, scaler, device, config
375
+ )
376
+ val_loss, val_metrics = validate(model, val_loader, criterion, device, threshold)
377
+ scheduler.step()
378
+
379
+ # Log
380
+ writer.add_scalar("Loss/train", train_loss, epoch)
381
+ writer.add_scalar("Loss/val", val_loss, epoch)
382
+ for k, v in val_metrics.items():
383
+ writer.add_scalar(f"Val/{k}", v, epoch)
384
+
385
+ logger.info(
386
+ " Train Loss: %.4f | Val Loss: %.4f | Val F1: %.4f | Val IoU: %.4f",
387
+ train_loss, val_loss, val_metrics["f1"], val_metrics["iou"],
388
+ )
389
+
390
+ # Update the best checkpoint first so the "last" checkpoint records the current best F1
391
+ if val_metrics["f1"] > best_f1:
392
+ best_f1 = val_metrics["f1"]
393
+ patience_counter = 0
394
+ save_checkpoint(
395
+ model, optimizer, scheduler, scaler, epoch + 1, best_f1,
396
+ paths["checkpoints"] / f"{model_name}_best.pth",
397
+ )
398
+ logger.info(" New best F1: %.4f", best_f1)
399
+ else:
400
+ patience_counter += 1
401
+
402
+ # Save last checkpoint (always, outside the if/else), after best_f1 has been updated
403
+ save_checkpoint(
404
+ model, optimizer, scheduler, scaler, epoch + 1, best_f1,
405
+ paths["checkpoints"] / f"{model_name}_last.pth",
406
+ )
408
+ # Early stopping
409
+ if es_cfg.get("enabled", True) and patience_counter >= patience:
410
+ logger.info("Early stopping triggered at epoch %d", epoch + 1)
411
+ break
412
+
413
+ writer.close()
414
+ logger.info("Training complete. Best F1: %.4f", best_f1)
415
+
416
+
417
+ if __name__ == "__main__":
418
+ main()
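The gradient-accumulation rule used in `train_one_epoch` can be illustrated without PyTorch. The sketch below is a standalone illustration, not part of the commit: `optimizer_step_indices` and its `flush_last` flag are hypothetical names. It lists which batch indices would trigger an optimizer step under a modulo-only rule, and under a rule that also steps on the final batch so that a trailing partial group of gradients is not discarded.

```python
from typing import List


def optimizer_step_indices(
    num_batches: int, accum_steps: int, flush_last: bool = False
) -> List[int]:
    """Return the 0-based batch indices at which the optimizer would step.

    flush_last is a hypothetical option: when True, the final batch always
    triggers a step, even if it does not fall on an accumulation boundary.
    """
    if accum_steps < 1:
        raise ValueError("accum_steps must be >= 1")
    steps = []
    for step in range(num_batches):
        at_boundary = (step + 1) % accum_steps == 0
        at_end = flush_last and (step + 1) == num_batches
        if at_boundary or at_end:
            steps.append(step)
    return steps


print(optimizer_step_indices(10, 4))                   # [3, 7]
print(optimizer_step_indices(10, 4, flush_last=True))  # [3, 7, 9]
```

With 10 batches and `accum_steps=4`, the gradients from batches 8 and 9 are applied only if the final batch also triggers a step.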
utils/__init__.py ADDED
File without changes
utils/losses.py ADDED
@@ -0,0 +1,139 @@
1
+ """Loss functions for binary change detection.
2
+
3
+ Provides BCEDiceLoss (default) and FocalLoss, both operating on raw logits.
4
+ A factory function ``get_loss`` reads the project config and returns the
5
+ selected loss module.
6
+ """
7
+
8
+ from typing import Any, Dict
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class BCEDiceLoss(nn.Module):
16
+ """Combined Binary Cross-Entropy and Dice Loss.
17
+
18
+ Both components operate on raw logits — sigmoid is applied internally so
19
+ the caller should **not** pre-apply it.
20
+
21
+ Args:
22
+ bce_weight: Scalar weight for the BCE component.
23
+ dice_weight: Scalar weight for the Dice component.
24
+ smooth: Smoothing constant for Dice to avoid division by zero.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ bce_weight: float = 0.5,
30
+ dice_weight: float = 0.5,
31
+ smooth: float = 1.0,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.bce_weight = bce_weight
35
+ self.dice_weight = dice_weight
36
+ self.smooth = smooth
37
+
38
+ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
39
+ """Compute the combined BCE + Dice loss.
40
+
41
+ Args:
42
+ logits: Raw model output of shape ``[B, 1, H, W]``.
43
+ targets: Binary ground-truth masks of shape ``[B, 1, H, W]``
44
+ with values in {0, 1}.
45
+
46
+ Returns:
47
+ Scalar loss tensor on the same device as the inputs.
48
+ """
49
+ # --- BCE component (numerically stable, operates on logits) ---
50
+ bce_loss = F.binary_cross_entropy_with_logits(logits, targets)
51
+
52
+ # --- Dice component ---
53
+ probs = torch.sigmoid(logits)
54
+ # Flatten spatial dims per sample for stable dice computation
55
+ probs_flat = probs.view(probs.size(0), -1)
56
+ targets_flat = targets.view(targets.size(0), -1)
57
+
58
+ intersection = (probs_flat * targets_flat).sum(dim=1)
59
+ union = probs_flat.sum(dim=1) + targets_flat.sum(dim=1)
60
+ dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth)
61
+ dice_loss = 1.0 - dice_score.mean()
62
+
63
+ return self.bce_weight * bce_loss + self.dice_weight * dice_loss
64
+
65
+
66
+ class FocalLoss(nn.Module):
67
+ """Focal Loss for addressing class imbalance in change detection.
68
+
69
+ Down-weights well-classified (easy) pixels so the model focuses on hard
70
+ examples near the decision boundary. Operates on raw logits.
71
+
72
+ Args:
73
+ alpha: Balancing factor for the positive class (1 − alpha for negative).
74
+ gamma: Focusing exponent — higher values down-weight easy examples more.
75
+ """
76
+
77
+ def __init__(self, alpha: float = 0.25, gamma: float = 2.0) -> None:
78
+ super().__init__()
79
+ self.alpha = alpha
80
+ self.gamma = gamma
81
+
82
+ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
83
+ """Compute focal loss.
84
+
85
+ Args:
86
+ logits: Raw model output of shape ``[B, 1, H, W]``.
87
+ targets: Binary ground-truth masks of shape ``[B, 1, H, W]``
88
+ with values in {0, 1}.
89
+
90
+ Returns:
91
+ Scalar loss tensor on the same device as the inputs.
92
+ """
93
+ # Per-pixel BCE (unreduced)
94
+ bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
95
+
96
+ probs = torch.sigmoid(logits)
97
+ # p_t = probability of the true class
98
+ p_t = probs * targets + (1.0 - probs) * (1.0 - targets)
99
+ # alpha_t = alpha for positives, (1-alpha) for negatives
100
+ alpha_t = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets)
101
+
102
+ focal_weight = alpha_t * (1.0 - p_t) ** self.gamma
103
+ return (focal_weight * bce).mean()
104
+
105
+
106
+ def get_loss(config: Dict[str, Any]) -> nn.Module:
107
+ """Factory function — instantiate a loss module from the project config.
108
+
109
+ Reads ``config["loss"]["name"]`` to select the loss type and extracts
110
+ the matching sub-key for constructor arguments.
111
+
112
+ Args:
113
+ config: Full project config dict (as loaded from ``config.yaml``).
114
+
115
+ Returns:
116
+ An ``nn.Module`` loss function ready for ``loss(logits, targets)``.
117
+
118
+ Raises:
119
+ ValueError: If the requested loss name is not recognised.
120
+ """
121
+ loss_cfg = config.get("loss", {})
122
+ name = loss_cfg.get("name", "bce_dice")
123
+
124
+ if name == "bce_dice":
125
+ params = loss_cfg.get("bce_dice", {})
126
+ return BCEDiceLoss(
127
+ bce_weight=params.get("bce_weight", 0.5),
128
+ dice_weight=params.get("dice_weight", 0.5),
129
+ )
130
+ elif name == "focal":
131
+ params = loss_cfg.get("focal", {})
132
+ return FocalLoss(
133
+ alpha=params.get("alpha", 0.25),
134
+ gamma=params.get("gamma", 2.0),
135
+ )
136
+ else:
137
+ raise ValueError(
138
+ f"Unknown loss '{name}'. Choose from: bce_dice, focal"
139
+ )
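To make the two loss formulas concrete, here is a minimal pure-Python sketch of the Dice term and the per-pixel focal term, written for flat lists and single pixels rather than tensors. The helper names `dice_loss` and `focal_term` are illustrative and not part of the commit. The arithmetic mirrors the `forward` methods above under the same defaults (`smooth=1.0`, `alpha=0.25`, `gamma=2.0`).

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dice_loss(
    probs: Sequence[float], targets: Sequence[float], smooth: float = 1.0
) -> float:
    """Dice loss for one flattened sample; smooth keeps the ratio defined."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    union = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (union + smooth)


def focal_term(
    logit: float, target: int, alpha: float = 0.25, gamma: float = 2.0
) -> float:
    """Focal loss contribution of a single pixel, computed from a raw logit."""
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p              # probability of the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha  # class-balancing weight
    bce = -math.log(p_t)                             # per-pixel cross-entropy
    return alpha_t * (1.0 - p_t) ** gamma * bce


# An empty mask predicted as empty costs nothing thanks to the smoothing term.
print(dice_loss([0.0, 0.0], [0.0, 0.0]))  # 0.0
# A confident correct pixel contributes far less than an uncertain one.
print(focal_term(4.0, 1) < focal_term(0.0, 1))  # True
```

The second print shows the focusing effect: the `(1 - p_t) ** gamma` factor shrinks the contribution of pixels the model already classifies well.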
utils/metrics.py ADDED
@@ -0,0 +1,226 @@
1
+ """Evaluation metrics for binary change detection.
2
+
3
+ Provides a ``ConfusionMatrix`` accumulator, standalone metric functions, and a
4
+ high-level ``MetricTracker`` that accepts raw logits and handles sigmoid +
5
+ thresholding internally.
6
+
7
+ Boolean tensor operations run on the inputs' device; ``update()`` moves only
8
+ the four resulting counts to the CPU via ``.item()``, once per batch.
9
+ """
10
+
11
+ from typing import Dict
12
+
13
+ import torch
14
+
15
+ # Small constant to prevent division-by-zero in metric formulas.
16
+ _EPS: float = 1e-7
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Low-level confusion-matrix accumulator
21
+ # ---------------------------------------------------------------------------
22
+
23
+ class ConfusionMatrix:
24
+ """Accumulates TP / FP / FN / TN counts across batches.
25
+
26
+ Counts are kept as plain Python ints (moved off GPU via four ``.item()``
27
+ calls per update, one per counter) so that accumulated values never
28
+ overflow a GPU scalar.
29
+
30
+ Example::
31
+
32
+ cm = ConfusionMatrix()
33
+ for preds, targets in loader:
34
+ cm.update(preds, targets)
35
+ metrics = cm.compute()
36
+ """
37
+
38
+ def __init__(self) -> None:
39
+ self.reset()
40
+
41
+ def reset(self) -> None:
42
+ """Reset all counters to zero."""
43
+ self.tp: int = 0
44
+ self.fp: int = 0
45
+ self.fn: int = 0
46
+ self.tn: int = 0
47
+
48
+ def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
49
+ """Accumulate one batch of binary predictions.
50
+
51
+ All boolean logic runs on whatever device the tensors live on; only
52
+ the four resulting scalars are moved to CPU via ``.item()``.
53
+
54
+ Args:
55
+ preds: Binary predictions ``[B, 1, H, W]`` with values in {0, 1}.
56
+ targets: Ground-truth masks ``[B, 1, H, W]`` with values in {0, 1}.
57
+ """
58
+ p = preds.bool().flatten()
59
+ t = targets.bool().flatten()
60
+
61
+ self.tp += (p & t).sum().item()
62
+ self.fp += (p & ~t).sum().item()
63
+ self.fn += (~p & t).sum().item()
64
+ self.tn += (~p & ~t).sum().item()
65
+
66
+ def compute(self) -> Dict[str, float]:
67
+ """Derive all metrics from the accumulated counts.
68
+
69
+ Returns:
70
+ Dict with keys ``'f1'``, ``'iou'``, ``'precision'``, ``'recall'``,
71
+ ``'oa'`` — each a plain Python float.
72
+ """
73
+ precision = self.tp / (self.tp + self.fp + _EPS)
74
+ recall = self.tp / (self.tp + self.fn + _EPS)
75
+ f1 = 2.0 * precision * recall / (precision + recall + _EPS)
76
+ iou = self.tp / (self.tp + self.fp + self.fn + _EPS)
77
+ oa = (self.tp + self.tn) / (self.tp + self.fp + self.fn + self.tn + _EPS)
78
+
79
+ return {
80
+ "f1": f1,
81
+ "iou": iou,
82
+ "precision": precision,
83
+ "recall": recall,
84
+ "oa": oa,
85
+ }
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Standalone convenience functions (single-batch, binary inputs)
90
+ # ---------------------------------------------------------------------------
91
+
92
+ def _quick_cm(preds: torch.Tensor, targets: torch.Tensor) -> ConfusionMatrix:
93
+ """Create and populate a ConfusionMatrix from a single batch.
94
+
95
+ Args:
96
+ preds: Binary predictions ``[B, 1, H, W]``.
97
+ targets: Ground-truth masks ``[B, 1, H, W]``.
98
+
99
+ Returns:
100
+ Populated ``ConfusionMatrix`` instance.
101
+ """
102
+ cm = ConfusionMatrix()
103
+ cm.update(preds, targets)
104
+ return cm
105
+
106
+
107
+ def compute_f1(preds: torch.Tensor, targets: torch.Tensor) -> float:
108
+ """Compute F1 score for a single batch.
109
+
110
+ Args:
111
+ preds: Binary predictions ``[B, 1, H, W]``.
112
+ targets: Ground-truth masks ``[B, 1, H, W]``.
113
+
114
+ Returns:
115
+ F1 score as a float in [0, 1].
116
+ """
117
+ return _quick_cm(preds, targets).compute()["f1"]
118
+
119
+
120
+ def compute_iou(preds: torch.Tensor, targets: torch.Tensor) -> float:
121
+ """Compute IoU (Jaccard index) for a single batch.
122
+
123
+ Args:
124
+ preds: Binary predictions ``[B, 1, H, W]``.
125
+ targets: Ground-truth masks ``[B, 1, H, W]``.
126
+
127
+ Returns:
128
+ IoU score as a float in [0, 1].
129
+ """
130
+ return _quick_cm(preds, targets).compute()["iou"]
131
+
132
+
133
+ def compute_precision(preds: torch.Tensor, targets: torch.Tensor) -> float:
134
+ """Compute precision for a single batch.
135
+
136
+ Args:
137
+ preds: Binary predictions ``[B, 1, H, W]``.
138
+ targets: Ground-truth masks ``[B, 1, H, W]``.
139
+
140
+ Returns:
141
+ Precision score as a float in [0, 1].
142
+ """
143
+ return _quick_cm(preds, targets).compute()["precision"]
144
+
145
+
146
+ def compute_recall(preds: torch.Tensor, targets: torch.Tensor) -> float:
147
+ """Compute recall for a single batch.
148
+
149
+ Args:
150
+ preds: Binary predictions ``[B, 1, H, W]``.
151
+ targets: Ground-truth masks ``[B, 1, H, W]``.
152
+
153
+ Returns:
154
+ Recall score as a float in [0, 1].
155
+ """
156
+ return _quick_cm(preds, targets).compute()["recall"]
157
+
158
+
159
+ def compute_oa(preds: torch.Tensor, targets: torch.Tensor) -> float:
160
+ """Compute overall accuracy for a single batch.
161
+
162
+ Args:
163
+ preds: Binary predictions ``[B, 1, H, W]``.
164
+ targets: Ground-truth masks ``[B, 1, H, W]``.
165
+
166
+ Returns:
167
+ Overall accuracy as a float in [0, 1].
168
+ """
169
+ return _quick_cm(preds, targets).compute()["oa"]
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # High-level tracker (accepts raw logits)
174
+ # ---------------------------------------------------------------------------
175
+
176
+ class MetricTracker:
177
+ """End-to-end metric tracker for training / validation loops.
178
+
179
+ Wraps a ``ConfusionMatrix`` and transparently applies sigmoid +
180
+ thresholding to raw model logits before accumulating counts.
181
+
182
+ Args:
183
+ threshold: Decision threshold applied after sigmoid (default 0.5).
184
+
185
+ Example::
186
+
187
+ tracker = MetricTracker(threshold=0.5)
188
+ for batch in val_loader:
189
+ logits = model(batch["A"], batch["B"])
190
+ tracker.update(logits, batch["mask"])
191
+ results = tracker.compute() # {"f1": ..., "iou": ..., ...}
192
+ tracker.reset()
193
+ """
194
+
195
+ def __init__(self, threshold: float = 0.5) -> None:
196
+ self.threshold = threshold
197
+ self.cm = ConfusionMatrix()
198
+
199
+ def reset(self) -> None:
200
+ """Reset the internal confusion matrix."""
201
+ self.cm.reset()
202
+
203
+ @torch.no_grad()
204
+ def update(self, logits: torch.Tensor, targets: torch.Tensor) -> None:
205
+ """Apply sigmoid + threshold and accumulate counts.
206
+
207
+ This method is wrapped with ``@torch.no_grad()`` so it can be
208
+ called safely inside a validation loop without affecting autograd.
209
+ All operations run on the input tensor's device.
210
+
211
+ Args:
212
+ logits: Raw model output ``[B, 1, H, W]`` (pre-sigmoid).
213
+ targets: Binary ground-truth masks ``[B, 1, H, W]`` with
214
+ values in {0, 1}.
215
+ """
216
+ preds = (torch.sigmoid(logits) >= self.threshold).float()
217
+ self.cm.update(preds, targets)
218
+
219
+ def compute(self) -> Dict[str, float]:
220
+ """Compute all metrics from accumulated counts.
221
+
222
+ Returns:
223
+ Dict with keys ``'f1'``, ``'iou'``, ``'precision'``, ``'recall'``,
224
+ ``'oa'``.
225
+ """
226
+ return self.cm.compute()
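A small worked example may help check the metric formulas by hand. The sketch below is a pure-Python restatement of the same arithmetic on flat lists of 0/1 labels. `binary_metrics` is a hypothetical helper, not part of the commit, and it reuses the same `1e-7` epsilon as an assumption.

```python
from typing import Dict, Sequence

EPS = 1e-7  # same role as _EPS: avoids division by zero


def binary_metrics(preds: Sequence[int], targets: Sequence[int]) -> Dict[str, float]:
    """Compute F1, IoU, precision, recall and overall accuracy from 0/1 labels."""
    pairs = list(zip(preds, targets))
    tp = sum(1 for p, t in pairs if p == 1 and t == 1)
    fp = sum(1 for p, t in pairs if p == 1 and t == 0)
    fn = sum(1 for p, t in pairs if p == 0 and t == 1)
    tn = sum(1 for p, t in pairs if p == 0 and t == 0)

    precision = tp / (tp + fp + EPS)
    recall = tp / (tp + fn + EPS)
    f1 = 2.0 * precision * recall / (precision + recall + EPS)
    iou = tp / (tp + fp + fn + EPS)
    oa = (tp + tn) / (tp + fp + fn + tn + EPS)
    return {"f1": f1, "iou": iou, "precision": precision, "recall": recall, "oa": oa}


# Five pixels: TP=2, FP=1, FN=1, TN=1
metrics = binary_metrics(preds=[1, 1, 0, 0, 1], targets=[1, 0, 0, 1, 1])
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

For this input, precision, recall and F1 all come out at roughly 2/3, IoU at roughly 0.5 and overall accuracy at roughly 0.6. This is consistent with the identity F1 = 2·IoU / (1 + IoU).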
utils/visualization.py ADDED
@@ -0,0 +1,141 @@
1
+ """Visualization utilities for change detection results.
2
+
3
+ Provides functions to plot predictions, overlay change maps, and track
4
+ training metrics over time.
5
+ """
6
+
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+
14
+
15
+ def denormalize(
16
+ img: np.ndarray,
17
+ mean: tuple = (0.485, 0.456, 0.406),
18
+ std: tuple = (0.229, 0.224, 0.225),
19
+ ) -> np.ndarray:
20
+ """Reverse ImageNet normalization for display.
21
+
22
+ Args:
23
+ img: Normalized image array [H, W, 3].
24
+ mean: Channel means used for normalization.
25
+ std: Channel stds used for normalization.
26
+
27
+ Returns:
28
+ Denormalized image clipped to [0, 1].
29
+ """
30
+ img = img * np.array(std) + np.array(mean)
31
+ return np.clip(img, 0, 1)
32
+
33
+
34
+ def plot_prediction(
35
+ img_a: torch.Tensor,
36
+ img_b: torch.Tensor,
37
+ mask_gt: torch.Tensor,
38
+ mask_pred: torch.Tensor,
39
+ save_path: Optional[Path] = None,
40
+ ) -> plt.Figure:
41
+ """Plot a single change detection prediction.
42
+
43
+ Shows: Before | After | Ground Truth | Prediction in a 1x4 grid.
44
+
45
+ Args:
46
+ img_a: Before image tensor [3, H, W] (normalized).
47
+ img_b: After image tensor [3, H, W] (normalized).
48
+ mask_gt: Ground truth mask [1, H, W] (binary).
49
+ mask_pred: Predicted mask [1, H, W] (binary or probability).
50
+ save_path: Optional path to save the figure.
51
+
52
+ Returns:
53
+ Matplotlib figure.
54
+ """
55
+ fig, axes = plt.subplots(1, 4, figsize=(16, 4))
56
+
57
+ # Convert tensors to numpy
58
+ a = denormalize(img_a.permute(1, 2, 0).cpu().numpy())
59
+ b = denormalize(img_b.permute(1, 2, 0).cpu().numpy())
60
+ gt = mask_gt.squeeze(0).cpu().numpy()
61
+ pred = mask_pred.squeeze(0).cpu().numpy()
62
+
63
+ titles = ["Before (A)", "After (B)", "Ground Truth", "Prediction"]
64
+ images = [a, b, gt, pred]
65
+ cmaps = [None, None, "gray", "gray"]
66
+
67
+ for ax, img, title, cmap in zip(axes, images, titles, cmaps):
68
+ ax.imshow(img, cmap=cmap, vmin=0, vmax=1)
69
+ ax.set_title(title)
70
+ ax.axis("off")
71
+
72
+ plt.tight_layout()
73
+
74
+ if save_path is not None:
75
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
76
+
77
+ return fig
78
+
79
+
80
+ def overlay_changes(
81
+ img_b: torch.Tensor,
82
+ mask_pred: torch.Tensor,
83
+ alpha: float = 0.4,
84
+ color: tuple = (1.0, 0.0, 0.0),
85
+ ) -> np.ndarray:
86
+ """Overlay predicted change mask on the after image.
87
+
88
+ Args:
89
+ img_b: After image tensor [3, H, W] (normalized).
90
+ mask_pred: Predicted binary mask [1, H, W].
91
+ alpha: Overlay transparency.
92
+ color: RGB color for the overlay (default: red).
93
+
94
+ Returns:
95
+ Overlaid image as numpy array [H, W, 3].
96
+ """
97
+ b = denormalize(img_b.permute(1, 2, 0).cpu().numpy())
98
+ mask = mask_pred.squeeze(0).cpu().numpy()
99
+
100
+ overlay = b.copy()
101
+ for c in range(3):
102
+ overlay[:, :, c] = np.where(
103
+ mask > 0.5,
104
+ b[:, :, c] * (1 - alpha) + color[c] * alpha,
105
+ b[:, :, c],
106
+ )
107
+ return overlay
108
+
109
+
110
+ def plot_metrics_history(
111
+ history: Dict[str, List[float]],
112
+ save_path: Optional[Path] = None,
113
+ ) -> plt.Figure:
114
+ """Plot training metric curves over epochs.
115
+
116
+ Args:
117
+ history: Dict mapping metric names to lists of per-epoch values.
118
+ save_path: Optional path to save the figure.
119
+
120
+ Returns:
121
+ Matplotlib figure.
122
+ """
123
+ n_metrics = len(history)
124
+ fig, axes = plt.subplots(1, n_metrics, figsize=(5 * n_metrics, 4))
125
+
126
+ if n_metrics == 1:
127
+ axes = [axes]
128
+
129
+ for ax, (name, values) in zip(axes, history.items()):
130
+ ax.plot(values, marker="o", markersize=2)
131
+ ax.set_title(name)
132
+ ax.set_xlabel("Epoch")
133
+ ax.set_ylabel(name)
134
+ ax.grid(True, alpha=0.3)
135
+
136
+ plt.tight_layout()
137
+
138
+ if save_path is not None:
139
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
140
+
141
+ return fig
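The per-pixel arithmetic behind `denormalize` and `overlay_changes` can be sketched on a single RGB pixel. The following is a pure-Python illustration only: `denormalize_pixel` and `blend_pixel` are hypothetical helpers that assume the same ImageNet statistics and the same defaults (`alpha=0.4`, red overlay) as the functions above.

```python
from typing import Tuple

Pixel = Tuple[float, float, float]

MEAN: Pixel = (0.485, 0.456, 0.406)
STD: Pixel = (0.229, 0.224, 0.225)


def denormalize_pixel(pixel: Pixel, mean: Pixel = MEAN, std: Pixel = STD) -> Pixel:
    """Undo (x - mean) / std per channel and clip the result to [0, 1]."""
    r, g, b = (min(1.0, max(0.0, v * s + m)) for v, s, m in zip(pixel, std, mean))
    return (r, g, b)


def blend_pixel(
    pixel: Pixel,
    changed: bool,
    alpha: float = 0.4,
    color: Pixel = (1.0, 0.0, 0.0),
) -> Pixel:
    """Alpha-blend the overlay color into the pixel only where change is predicted."""
    if not changed:
        return pixel
    r, g, b = (v * (1.0 - alpha) + c * alpha for v, c in zip(pixel, color))
    return (r, g, b)


# A normalized value of 0 in every channel maps back to the channel means.
print(denormalize_pixel((0.0, 0.0, 0.0)))  # (0.485, 0.456, 0.406)
# Unchanged pixels pass through untouched; changed pixels are tinted red.
print(blend_pixel((0.2, 0.4, 0.6), changed=False))  # (0.2, 0.4, 0.6)
print(blend_pixel((0.2, 0.4, 0.6), changed=True))
```

With `alpha=0.4`, a changed pixel keeps 60% of its original value in each channel and gains 40% of the overlay color, so `(0.2, 0.4, 0.6)` becomes approximately `(0.52, 0.24, 0.36)`.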