| # Cleaned Code |
| ```python |
| import os |
| import math |
| import zipfile |
| import urllib.request |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| |
| from torch.utils.data import DataLoader |
| from torchvision import datasets, transforms |
| |
| |
| # ========================================================= |
| # 1. TINY-IMAGENET DOWNLOAD + PREPARATION |
| # ========================================================= |
| |
def prepare_tiny_imagenet(
    dataset_url="http://cs231n.stanford.edu/tiny-imagenet-200.zip",
):
    """
    Download and extract Tiny-ImageNet if it is not already present.

    The archive is stored as ``./tiny-imagenet-200.zip`` and unpacked
    into ``./tiny-imagenet-200`` (both relative to the current working
    directory). Each step is skipped when its output already exists.

    Args:
        dataset_url: URL the archive is fetched from. Defaults to the
            Stanford CS231n mirror.

    Returns:
        (train_dir, val_dir): paths of the ``train`` and ``val``
        sub-folders inside the extracted dataset directory.
    """

    zip_path = "./tiny-imagenet-200.zip"
    extract_path = "./tiny-imagenet-200"

    # -----------------------------------------------------
    # Download dataset archive
    # -----------------------------------------------------
    if not os.path.exists(zip_path):

        print(
            "Downloading Tiny-ImageNet (~230MB)... "
            "Please wait..."
        )

        # Download under a temporary name and rename only once the
        # transfer has finished. Otherwise an interrupted download would
        # leave a truncated file at zip_path, and every later run would
        # skip the download and fail while extracting it. A stale
        # ".part" file is simply overwritten by the next attempt.
        partial_path = zip_path + ".part"
        urllib.request.urlretrieve(dataset_url, partial_path)
        os.replace(partial_path, zip_path)

        print("Download complete!")

    # -----------------------------------------------------
    # Extract dataset archive
    # -----------------------------------------------------
    if not os.path.exists(extract_path):

        print("Extracting dataset...")

        # The archive is expected to contain a top-level
        # "tiny-imagenet-200" folder, so extracting into "./" yields
        # extract_path.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall("./")

        print("Extraction complete!")

    return (
        os.path.join(extract_path, "train"),
        os.path.join(extract_path, "val")
    )
| |
| |
train_dir, val_dir = prepare_tiny_imagenet()


# =========================================================
# 2. VALIDATION FOLDER RESTRUCTURING
# =========================================================
#
# Tiny-ImageNet validation images are originally placed
# in a single shared "images" folder, with the class of
# each image listed in val_annotations.txt.
#
# This section moves every image into a folder named
# after its class so torchvision.datasets.ImageFolder
# can load the validation split.
#
# The step runs only while the shared "images" folder
# still exists, so repeated runs skip it.
#

# Derived from val_dir so the paths stay in sync with
# whatever prepare_tiny_imagenet() returned.
val_img_dir = os.path.join(val_dir, "images")

val_annotations = os.path.join(val_dir, "val_annotations.txt")

if os.path.exists(val_img_dir):

    print(
        "Reorganizing Tiny-ImageNet validation "
        "folder structure..."
    )

    with open(val_annotations, "r", encoding="utf-8") as f:

        for line in f:

            # Tab-separated record: image file name first,
            # class name second; further columns are unused.
            parts = line.strip().split("\t")

            # Skip blank or malformed lines instead of
            # failing with an IndexError.
            if len(parts) < 2:
                continue

            img_name, class_name = parts[0], parts[1]

            class_dir = os.path.join(val_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)

            src_path = os.path.join(val_img_dir, img_name)
            dst_path = os.path.join(class_dir, img_name)

            # An interrupted earlier run may already have
            # moved this image.
            if os.path.exists(src_path):
                os.rename(src_path, dst_path)

    # os.rmdir raises OSError if anything is left in the
    # folder (e.g. images missing from the annotation file).
    # NOTE(review): presumably a leftover "images" folder
    # would be picked up by ImageFolder as an extra class,
    # so failing here looks preferable to ignoring it.
    os.rmdir(val_img_dir)

    print(
        "Validation folder restructuring complete!"
    )
| |
| |
| # ========================================================= |
| # 3. DATA AUGMENTATION + NORMALIZATION |
| # ========================================================= |
| |
# Per-channel statistics used to normalize Tiny-ImageNet
# images; shared by the training and validation pipelines.
_norm_mean = (0.4802, 0.4481, 0.3975)
_norm_std = (0.2302, 0.2265, 0.2262)

# Training pipeline: random flip and rotation (up to 15
# degrees) for augmentation, then tensor conversion and
# normalization.
_train_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transform_train = transforms.Compose(_train_steps)

# Validation pipeline: no augmentation, only tensor
# conversion and the same normalization.
_val_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transform_val = transforms.Compose(_val_steps)
| |
| |
| # ========================================================= |
| # 4. DATASET + DATALOADER SETUP |
| # ========================================================= |
| |
# Both splits are read from class-named folders on disk.
train_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_dataset = datasets.ImageFolder(val_dir, transform=transform_val)

# Worker / memory settings common to both loaders.
_loader_options = dict(num_workers=2, pin_memory=True)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    **_loader_options
)

# Validation uses larger batches in a fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    **_loader_options
)
| |
| |
| # ========================================================= |
| # 5. CORE RELATIONAL LAYER — LOOKTHEM LAYER |
| # ========================================================= |
| |
class LookThemLayer(nn.Module):
    """
    Token-relational processing layer.

    Each token owns two independent two-layer micro-networks
    (branch 1 and branch 2), each producing one scalar per token.
    For every ordered token pair the branch-1 scalar of one token is
    divided by the branch-2 scalar of the other and squashed with
    tanh. The resulting pairwise maps are scaled and shifted by
    learned parameters, used to weight the token features, and
    averaged over all *other* tokens.

    Args:
        num_tokens: number of tokens T expected on dim 1 of the input.
        in_features: feature size of each token.
        hidden_dim: hidden width of each per-token micro-network.
        eps: minimum magnitude enforced on the ratio denominator.
    """

    def __init__(self, num_tokens, in_features, hidden_dim, eps=1e-5):

        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features
        self.eps = eps

        # NOTE: parameters are created in a fixed order with
        # torch.randn so that seeded initialization stays reproducible.

        # =================================================
        # BRANCH 1 PARAMETERS
        # =================================================
        self.mod1_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod1_w2 = nn.Parameter(
            torch.randn(num_tokens, hidden_dim, 1)
        )
        self.mod1_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # =================================================
        # BRANCH 2 PARAMETERS
        # =================================================
        self.mod2_w1 = nn.Parameter(
            torch.randn(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod2_w2 = nn.Parameter(
            torch.randn(num_tokens, hidden_dim, 1)
        )
        self.mod2_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # =================================================
        # RELATIONAL TRANSFORMATION PARAMETERS
        # =================================================
        # One scalar weight and bias per token, applied to the
        # pairwise comparison maps.
        self.trans_w = nn.Parameter(
            torch.randn(num_tokens, 1, 1)
        )
        self.trans_b = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """
        Re-initialize all weight tensors with Kaiming-uniform
        (a = sqrt(5)). Biases keep their zero initialization.
        """

        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w
        ]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    @staticmethod
    def _branch(x, w1, b1, w2, b2):
        """Run one per-token micro-network: linear, GELU, linear -> [B, T, 1]."""

        hidden = torch.einsum('bti,tij->btj', x, w1) + b1

        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Input shape:
            [B, Tokens, Features]

        Output shape:
            [B, Tokens, Features]
        """

        N = self.num_tokens

        # =================================================
        # PER-TOKEN BRANCH OUTPUTS, each [B, N, 1]
        # =================================================
        out_m1 = self._branch(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._branch(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # Numerical stabilization.
        # out_m2 is signed, so simply adding eps would turn a value of
        # -eps into an exact zero (division by zero -> inf/NaN
        # activations and NaN gradients). Pushing the value away from
        # zero in the direction of its sign guarantees |denominator|
        # >= eps. Non-negative values are treated as before.
        out_m2_safe = torch.where(
            out_m2 >= 0,
            out_m2 + self.eps,
            out_m2 - self.eps
        )

        # =================================================
        # PAIRWISE TOKEN RELATIONAL COMPARISON
        # =================================================
        # compare[b, i, j]  = tanh(m1[i] / m2[j])
        # compare2[b, i, j] = tanh(m1[j] / m2[i])
        compare = torch.tanh(
            out_m1.unsqueeze(2) /
            out_m2_safe.unsqueeze(1)
        )

        compare2 = torch.tanh(
            out_m1.unsqueeze(1) /
            out_m2_safe.unsqueeze(2)
        )

        # =================================================
        # RELATIONAL MAP TRANSFORMATION
        # =================================================
        # Weight and bias are indexed by the second token axis (j).
        bias_reshaped = self.trans_b.view(1, 1, N, 1)

        trans_compare = (
            torch.einsum('bije,jef->bijf', compare, self.trans_w)
            + bias_reshaped
        )

        trans_compare2 = (
            torch.einsum('bije,jef->bijf', compare2, self.trans_w)
            + bias_reshaped
        )

        # =================================================
        # BIDIRECTIONAL INTERACTION FUSION
        # =================================================
        # Broadcasts to [B, N, N, Features].
        interaction = (
            trans_compare * x.unsqueeze(2)
            + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Remove self-interaction (zero the i == j diagonal)
        mask = 1.0 - torch.eye(N, device=x.device)

        interaction_masked = interaction * mask.view(1, N, N, 1)

        # Average over the N - 1 other tokens. With a single token
        # there are no other tokens; dividing by 1 then returns zeros
        # instead of 0 / 0 = NaN.
        return interaction_masked.sum(dim=2) / max(N - 1, 1)
| |
| |
| # ========================================================= |
| # 6. MAIN ARCHITECTURE — LOOKTHEM V5 |
| # ========================================================= |
| |
class LookThemV5(nn.Module):
    """
    Dual-stream asymmetric relational architecture.

    Stream A:
        Grayscale macro-structure stream. Two stride-2 convolutions
        (16x16 feature map for 64x64 inputs), then a linear bridge
        that compresses the 256 spatial positions into 64 tokens.

    Stream B:
        RGB color-essence stream. Three stride-2 convolutions
        (8x8 feature map for 64x64 inputs), i.e. 64 tokens.

    Both streams yield 64 tokens with 32 features each. They are
    concatenated along the feature axis, processed by the relational
    LookThem core and classified by a small MLP head.

    Args:
        num_classes: number of output logits (default 200).
    """

    def __init__(self, num_classes=200):

        super().__init__()

        # =================================================
        # RGB -> GRAYSCALE CONVERSION WEIGHTS
        # =================================================
        # Registered as a buffer so it follows the module across
        # devices without being trained.
        self.register_buffer(
            'grayscale_weights',
            torch.tensor(
                [0.299, 0.587, 0.114]
            ).view(1, 3, 1, 1)
        )

        # =================================================
        # STREAM A — MACRO STRUCTURE STREAM
        # =================================================
        self.stream_a = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU()
        )

        # =================================================
        # TOKEN BRIDGE
        # =================================================
        # 256 spatial positions -> 64 tokens; feature channels are
        # left untouched because the linear map acts on the last axis.
        self.token_bridge = nn.Linear(256, 64)

        # =================================================
        # STREAM B — COLOR ESSENCE STREAM
        # =================================================
        self.stream_b = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.GELU(),

            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU()
        )

        # =================================================
        # RELATIONAL COGNITION CORE
        # =================================================
        self.lookthem = LookThemLayer(
            num_tokens=64,
            in_features=64,
            hidden_dim=32
        )

        # =================================================
        # CLASSIFIER HEAD
        # =================================================
        # Flattened [64 tokens x 64 features] representation followed
        # by a small MLP with dropout.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 64, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """
        Args:
            x: RGB image batch [B, 3, H, W]. The layer sizes are fixed
               for inputs whose stream-A map has 256 positions and
               whose stream-B map has 64 (e.g. 64x64 images).

        Returns:
            Class logits of shape [B, num_classes].
        """

        # =================================================
        # STREAM A — GRAYSCALE MACRO EXTRACTION
        # =================================================

        # Weighted channel sum converts RGB into one gray channel
        x_gray = torch.sum(
            x * self.grayscale_weights,
            dim=1,
            keepdim=True
        )

        # [B, 32, 16, 16] for 64x64 inputs
        feat_a = self.stream_a(x_gray)

        # flatten(2) merges the spatial axes -> [B, 32, 256]. Unlike a
        # fixed-shape .view it also works when the conv output is not
        # contiguous (e.g. channels_last memory format).
        feat_a_flat = feat_a.flatten(2)

        # Spatial compression: 256 -> 64, then tokens first:
        # [B, 64 Tokens, 32 Features]
        feat_a_tokens = self.token_bridge(feat_a_flat).transpose(1, 2)

        # =================================================
        # STREAM B — RGB COLOR EXTRACTION
        # =================================================

        # [B, 32, 8, 8] -> [B, 32, 64] -> [B, 64 Tokens, 32 Features]
        feat_b_tokens = self.stream_b(x).flatten(2).transpose(1, 2)

        # =================================================
        # ASYMMETRIC FEATURE FUSION
        # =================================================
        # Token count remains fixed while feature dimensionality is
        # doubled: [B, 64 Tokens, 64 Features]
        tokens_combined = torch.cat(
            [feat_a_tokens, feat_b_tokens],
            dim=2
        )

        # =================================================
        # RELATIONAL COGNITION
        # =================================================
        out_lookthem = self.lookthem(tokens_combined)

        # =================================================
        # CLASSIFICATION
        # =================================================
        return self.classifier(out_lookthem)
| |
| |
| # ========================================================= |
| # 7. TRAINING RUNTIME + CHECKPOINT SYSTEM |
| # ========================================================= |
| |
# Use the GPU when one is available, otherwise the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

model = LookThemV5().to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4
)

# Cosine learning-rate schedule over 20 scheduler steps
# (the scheduler is stepped once per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=20
)

# Index of the first epoch to run; replaced below when a
# checkpoint is found.
start_epoch = 0

checkpoint_path = "lookthem_v5_checkpoint.pth"


# =========================================================
# CHECKPOINT RESUME
# =========================================================

if os.path.exists(checkpoint_path):

    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )

    # map_location lets a checkpoint written on a GPU be
    # resumed on a CPU-only machine (and vice versa); without
    # it torch.load tries to restore tensors onto the device
    # they were saved from.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=device
    )

    model.load_state_dict(
        checkpoint['model_state_dict']
    )

    optimizer.load_state_dict(
        checkpoint['optimizer_state_dict']
    )

    scheduler.load_state_dict(
        checkpoint['scheduler_state_dict']
    )

    # Stored value is the number of completed epochs, i.e.
    # the zero-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']

    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )


print(
    f"Starting LookThem V5 "
    f"(Asymmetric Fusion) on {device}..."
)
| |
| |
| # ========================================================= |
| # 8. TRAINING LOOP |
| # ========================================================= |
| |
# Runs epochs start_epoch .. 19 (zero-based); after a resume
# the already completed epochs are skipped.
for epoch in range(start_epoch, 20):

    model.train()

    # Running statistics for this epoch
    total_loss = 0
    correct = 0
    total = 0

    for data, target in train_loader:

        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Predicted class = index of the largest logit
        _, predicted = output.max(1)

        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # Learning rate is updated once per epoch
    scheduler.step()

    acc = 100. * correct / total

    current_lr = optimizer.param_groups[0]['lr']

    # Reported loss is the mean of the per-batch losses
    print(
        f"Epoch {epoch+1:02d}/20 | "
        f"Train Loss: "
        f"{total_loss / len(train_loader):.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )

    # -----------------------------------------------------
    # Periodic checkpoint save (every 5th epoch)
    # -----------------------------------------------------
    if (epoch + 1) % 5 == 0:

        # Write to a temporary file and swap it into place so
        # that an interruption during the save cannot leave a
        # truncated checkpoint that would break the next resume.
        tmp_checkpoint_path = checkpoint_path + ".tmp"

        torch.save({

            # Number of completed epochs
            'epoch': epoch + 1,

            'model_state_dict':
                model.state_dict(),

            'optimizer_state_dict':
                optimizer.state_dict(),

            'scheduler_state_dict':
                scheduler.state_dict(),

        }, tmp_checkpoint_path)

        os.replace(tmp_checkpoint_path, checkpoint_path)

        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )
| |
| |
| # ========================================================= |
| # 9. FINAL VALIDATION |
| # ========================================================= |
| |
# Switch the model to evaluation mode for the final pass
model.eval()

# Accumulators over the whole validation loader
test_loss = 0
test_correct = 0
test_total = 0

print("\nStarting final validation...")

# No gradients are needed while evaluating
with torch.no_grad():

    for data, target in val_loader:

        data, target = data.to(device), target.to(device)

        output = model(data)

        loss = criterion(output, target)
        test_loss += loss.item()

        # Index of the largest logit is the predicted class
        predicted = output.max(1).indices

        test_total += target.size(0)
        test_correct += (predicted == target).sum().item()

final_test_acc = 100. * test_correct / test_total

# Mean of the per-batch losses
_mean_test_loss = test_loss / len(val_loader)

print("=== FINAL LOOKTHEM V5 RESULTS ===")

print(
    f"Test Loss: {_mean_test_loss:.4f} | "
    f"Test Accuracy: {final_test_acc:.2f}%"
)

# Save final trained weights (state dict only)
_final_weights_path = "LookThem_V5_Final.pth"

torch.save(model.state_dict(), _final_weights_path)

# File size on disk, converted from bytes to MB
_final_size_mb = os.path.getsize(_final_weights_path) / (1024*1024)

print(
    f"Training complete! "
    f"Final model size: "
    f"{_final_size_mb:.2f} MB"
)
| ``` |