# Cleaned Code
import os
import math
import zipfile
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# =========================================================
# 1. TINY-IMAGENET DOWNLOAD + PREPARATION
# =========================================================
def prepare_tiny_imagenet(root="."):
    """
    Download and extract Tiny-ImageNet if it is not already present.

    Both steps are made restart-safe: the archive is downloaded under a
    temporary name and the dataset is extracted into a staging directory,
    so an interrupted run never leaves a partial file/folder that a later
    run would mistake for a complete one.

    Args:
        root: Directory that holds the downloaded archive and the
            extracted ``tiny-imagenet-200`` folder. Defaults to the
            current working directory.

    Returns:
        (train_dir, val_dir): paths of the extracted ``train`` and
        ``val`` folders.
    """
    import shutil

    dataset_url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    zip_path = os.path.join(root, "tiny-imagenet-200.zip")
    extract_path = os.path.join(root, "tiny-imagenet-200")
    splits = (
        os.path.join(extract_path, "train"),
        os.path.join(extract_path, "val"),
    )

    # An already-extracted dataset is all we need; don't re-download the
    # archive just because it was deleted after extraction.
    if os.path.exists(extract_path):
        return splits

    # -----------------------------------------------------
    # Download dataset archive
    # -----------------------------------------------------
    if not os.path.exists(zip_path):
        print(
            "Downloading Tiny-ImageNet (~230MB)... "
            "Please wait..."
        )
        # Download under a temporary name and publish it only once the
        # transfer finished, so an interrupted download is retried instead
        # of being treated as a complete archive next time.
        partial_path = zip_path + ".part"
        urllib.request.urlretrieve(dataset_url, partial_path)
        os.replace(partial_path, zip_path)
        print("Download complete!")

    # -----------------------------------------------------
    # Extract dataset archive
    # -----------------------------------------------------
    print("Extracting dataset...")
    # Extract into a staging directory and move the finished tree into
    # place, so a half-extracted folder never passes the existence check
    # above.
    staging_path = extract_path + ".extracting"
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(staging_path)
    # The archive is expected to contain a top-level "tiny-imagenet-200/"
    # folder.
    os.replace(
        os.path.join(staging_path, "tiny-imagenet-200"),
        extract_path,
    )
    shutil.rmtree(staging_path, ignore_errors=True)
    print("Extraction complete!")
    return splits
train_dir, val_dir = prepare_tiny_imagenet()
# =========================================================
# 2. VALIDATION FOLDER RESTRUCTURING
# =========================================================
#
# Tiny-ImageNet validation images are originally placed in a single
# shared folder (val/images), with their labels listed in
# val/val_annotations.txt (tab-separated: image name, class name, ...).
#
# This section moves every image into val/<class_name>/ so
# torchvision.datasets.ImageFolder can load them. It only runs while
# val/images still exists, so later runs skip it.
#
val_img_dir = os.path.join(val_dir, "images")
val_annotations = os.path.join(val_dir, "val_annotations.txt")
if os.path.exists(val_img_dir):
    print(
        "Reorganizing Tiny-ImageNet validation "
        "folder structure..."
    )
    with open(val_annotations, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            # Skip blank or malformed lines (e.g. a trailing empty line)
            # instead of crashing with IndexError.
            if len(parts) < 2:
                continue
            img_name, class_name = parts[0], parts[1]
            class_dir = os.path.join(val_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)
            src_path = os.path.join(val_img_dir, img_name)
            dst_path = os.path.join(class_dir, img_name)
            # Images already moved by an interrupted earlier run are skipped.
            if os.path.exists(src_path):
                os.replace(src_path, dst_path)
    # val/images must end up gone: if it survived, ImageFolder would treat
    # it as an extra class and shift every label index. Fail loudly with a
    # useful message rather than an opaque "directory not empty".
    leftovers = os.listdir(val_img_dir)
    if leftovers:
        raise RuntimeError(
            f"{len(leftovers)} file(s) in {val_img_dir} have no entry in "
            f"{val_annotations}; cannot finish restructuring."
        )
    os.rmdir(val_img_dir)
    print(
        "Validation folder restructuring complete!"
    )
# =========================================================
# 3. DATA AUGMENTATION + NORMALIZATION
# =========================================================
# Tiny-ImageNet per-channel normalization statistics, shared by the train
# and validation pipelines so the two can never drift apart.
TINY_IMAGENET_MEAN = (0.4802, 0.4481, 0.3975)
TINY_IMAGENET_STD = (0.2302, 0.2265, 0.2262)

transform_train = transforms.Compose([
    # Horizontal augmentation
    transforms.RandomHorizontalFlip(),
    # Mild rotational augmentation
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD),
])

# Validation: no augmentation, identical normalization.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD),
])
# =========================================================
# 4. DATASET + DATALOADER SETUP
# =========================================================
train_dataset = datasets.ImageFolder(
    root=train_dir,
    transform=transform_train
)
val_dataset = datasets.ImageFolder(
    root=val_dir,
    transform=transform_val
)

# ImageFolder builds its class -> label mapping separately for each root.
# If the two splits disagree on their class folders, validation labels
# would be silently shifted, so fail loudly instead.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise RuntimeError(
        "Train/val class folders differ; validation labels would not "
        "match training labels."
    )

# Pinned host memory only speeds up host -> GPU copies; on CPU-only runs
# it is useless (and PyTorch warns about it).
use_pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory
)
val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory
)
# =========================================================
# 5. CORE RELATIONAL LAYER - LOOKTHEM LAYER
# =========================================================
class LookThemLayer(nn.Module):
    """
    Token-relational processing layer.

    Each token owns two independent two-layer micro-networks (branch 1 and
    branch 2) that each map its feature vector to one scalar, m1[i] and
    m2[i]. For every token pair (i, j) the layer forms the ratio-based
    comparisons

        compare[i, j]  = tanh(m1[i] / m2[j])
        compare2[i, j] = tanh(m1[j] / m2[i])

    passes each through a learnable per-token affine map t_j(c) =
    c * trans_w[j] + trans_b[j], and redistributes them into token space:

        out[i] = sum_{j != i} (t_j(compare[i, j]) * x[i]
                               + t_j(compare2[i, j]) * x[j]) / (2 * (N - 1))

    Args:
        num_tokens: Number of tokens N; every token has its own weights.
        in_features: Feature size of each token.
        hidden_dim: Hidden width of each per-token micro-network.
        eps: Minimum magnitude of the ratio denominator m2.

    Input shape:  [B, num_tokens, in_features]
    Output shape: [B, num_tokens, in_features]
    """

    def __init__(self, num_tokens, in_features, hidden_dim, eps=1e-5):
        super(LookThemLayer, self).__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features
        self.hidden_dim = hidden_dim
        # Minimum |denominator| used in the pairwise ratio comparisons.
        self.eps = eps
        # Keep the parameter creation order stable: it fixes the order of
        # parameters() that saved optimizer states refer to.
        # =================================================
        # BRANCH 1 PARAMETERS
        # =================================================
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # =================================================
        # BRANCH 2 PARAMETERS
        # =================================================
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        # =================================================
        # RELATIONAL TRANSFORMATION PARAMETERS
        # =================================================
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """
        Initialise the per-token projection matrices.

        Uses the distribution of ``nn.Linear``'s default init
        (``kaiming_uniform_`` with ``a=sqrt(5)``), i.e.
        U(-1/sqrt(fan_in), 1/sqrt(fan_in)), where fan_in is dim 1 of the
        ``[num_tokens, fan_in, fan_out]`` weight (the dimension each token's
        matmul reduces over). Calling ``kaiming_uniform_`` directly on these
        3-D tensors would treat dim 2 as a conv receptive field and use
        fan_in = dim1 * dim2, shrinking the ``*_w1`` matrices by
        sqrt(hidden_dim).
        """
        for w in [
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w
        ]:
            bound = 1.0 / math.sqrt(w.shape[1])
            nn.init.uniform_(w, -bound, bound)

    @staticmethod
    def _branch(x, w1, b1, w2, b2):
        """Apply one set of per-token micro-networks: [B, T, F] -> [B, T, 1]."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Input shape:
            [B, Tokens, Features]
        Output shape:
            [B, Tokens, Features]
        """
        n = self.num_tokens
        # =================================================
        # PER-TOKEN BRANCHES, each [B, N, 1]
        # =================================================
        out_m1 = self._branch(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._branch(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )
        # Numerical stabilization: push the denominator *away* from zero so
        # that |denominator| >= eps. A plain "out_m2 + eps" would move
        # negative values toward zero (reaching exactly 0 at out_m2 == -eps),
        # giving infinite ratios and NaN gradients.
        out_m2_safe = torch.where(
            out_m2 >= 0, out_m2 + self.eps, out_m2 - self.eps
        )
        # =================================================
        # PAIRWISE TOKEN RELATIONAL COMPARISON, each [B, N, N, 1]
        #   compare[b, i, j]  = tanh(m1[i] / m2[j])
        #   compare2[b, i, j] = tanh(m1[j] / m2[i])
        # =================================================
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))
        # =================================================
        # RELATIONAL MAP TRANSFORMATION (affine map indexed by token j)
        # =================================================
        bias_reshaped = self.trans_b.view(1, 1, n, 1)
        trans_compare = (
            torch.einsum('bije,jef->bijf', compare, self.trans_w)
            + bias_reshaped
        )
        trans_compare2 = (
            torch.einsum('bije,jef->bijf', compare2, self.trans_w)
            + bias_reshaped
        )
        # =================================================
        # BIDIRECTIONAL INTERACTION FUSION
        # =================================================
        # Remove self-interaction (i == j).
        mask = 1.0 - torch.eye(n, device=x.device, dtype=x.dtype)
        # Coefficient on x[i]: sum over partners j != i -> [B, N, 1]
        self_weight = (trans_compare.squeeze(-1) * mask).sum(dim=2, keepdim=True)
        # Coefficient on x[j] for each pair (i, j != i) -> [B, N, N]
        pair_weight = trans_compare2.squeeze(-1) * mask
        # Same as summing (tc[i, j] * x[i] + tc2[i, j] * x[j]) / 2 over
        # j != i, without materialising a [B, N, N, F] interaction tensor.
        interaction = (self_weight * x + torch.bmm(pair_weight, x)) / 2
        # Average over the N - 1 partner tokens (N == 1 has no partners and
        # yields zeros instead of 0 / 0).
        return interaction / max(n - 1, 1)
# =========================================================
# 6. MAIN ARCHITECTURE - LOOKTHEM V5
# =========================================================
class LookThemV5(nn.Module):
    """
    Dual-stream asymmetric relational classifier with 200 output classes.

    Stream A collapses the image to one luminance-weighted channel, keeps a
    16x16 feature map, and linearly mixes its 256 spatial positions into 64
    tokens of 32 features ("macro structure").
    Stream B runs on all three input channels and downsamples to an 8x8
    map whose 64 positions become 64 tokens of 32 features ("color essence").

    The two token sets are concatenated feature-wise (64 tokens x 64
    features), processed by a LookThemLayer, flattened and classified.

    The reshapes in ``forward`` hard-code the 16x16 / 8x8 spatial sizes,
    so inputs are expected to be [B, 3, 64, 64] (Tiny-ImageNet resolution).
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in a fixed order (gray buffer, stream A,
        # bridge, stream B, core, head); this keeps state_dict keys and
        # parameter order unchanged.

        # ITU-R BT.601 luma coefficients for R, G, B, shaped to broadcast
        # over [B, 3, H, W].
        self.register_buffer(
            'grayscale_weights',
            torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
        )

        # Stream A (macro structure): 1 -> 16 -> 32 channels, two stride-2
        # stages, ending at 16x16 spatial resolution.
        self.stream_a = nn.Sequential(
            *self._downsample_stage(1, 16),
            *self._downsample_stage(16, 32),
        )

        # Token bridge: per channel, linearly mixes the 256 (= 16 * 16)
        # spatial positions of stream A into 64 token slots.
        self.token_bridge = nn.Linear(256, 64)

        # Stream B (color essence): 3 -> 16 -> 32 -> 32 channels, three
        # stride-2 stages, ending at 8x8 spatial resolution.
        self.stream_b = nn.Sequential(
            *self._downsample_stage(3, 16),
            *self._downsample_stage(16, 32),
            *self._downsample_stage(32, 32),
        )

        # Relational core over 64 tokens x 64 fused features.
        self.lookthem = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)

        # Classifier head: flattened tokens -> 256 hidden -> 200 logits,
        # with dropout against overfitting.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 64, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 200),
        )

    @staticmethod
    def _downsample_stage(in_channels, out_channels):
        """Return [Conv2d(3x3, stride 2, pad 1), BatchNorm2d, GELU] modules."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        ]

    def forward(self, x):
        n = x.size(0)

        # ---- Stream A: luminance-weighted single channel ----
        # NOTE(review): if the input is already mean/std-normalized (as the
        # data pipeline in this script does), this is a weighted sum of
        # normalized channels rather than true luminance.
        gray = (x * self.grayscale_weights).sum(dim=1, keepdim=True)
        a_flat = self.stream_a(gray).view(n, 32, 256)           # [B, 32, 256]
        a_tokens = self.token_bridge(a_flat).transpose(1, 2)    # [B, 64, 32]

        # ---- Stream B: RGB color features ----
        b_tokens = self.stream_b(x).view(n, 32, 64).transpose(1, 2)  # [B, 64, 32]

        # ---- Asymmetric fusion: same 64 tokens, features doubled ----
        fused = torch.cat((a_tokens, b_tokens), dim=2)          # [B, 64, 64]

        # ---- Relational cognition + classification ----
        related = self.lookthem(fused)
        return self.classifier(related)
# =========================================================
# 7. TRAINING RUNTIME + CHECKPOINT SYSTEM
# =========================================================
device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
model = LookThemV5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4
)
# Cosine decay spanning the 20-epoch training run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=20
)
start_epoch = 0
checkpoint_path = "lookthem_v5_checkpoint.pth"
# =========================================================
# CHECKPOINT RESUME
# =========================================================
if os.path.exists(checkpoint_path):
    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )
    # map_location remaps saved tensors onto the current device, so a
    # checkpoint written on a GPU machine can be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # The training loop stores epoch + 1, i.e. the number of completed
    # epochs, which is also the 0-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']
    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )
print(
    f"Starting LookThem V5 "
    f"(Asymmetric Fusion) on {device}..."
)
# =========================================================
# 8. TRAINING LOOP
# =========================================================
for epoch in range(start_epoch, 20):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        # non_blocking lets host -> device copies overlap with compute when
        # the batch sits in pinned memory; it is a no-op otherwise.
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_size = target.size(0)
        # criterion returns the batch *mean*; weight it by batch size so a
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        _, predicted = output.max(1)
        total += batch_size
        correct += predicted.eq(target).sum().item()
    scheduler.step()
    acc = 100. * correct / total
    current_lr = optimizer.param_groups[0]['lr']
    print(
        f"Epoch {epoch+1:02d}/20 | "
        f"Train Loss: "
        f"{total_loss / total:.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )
    # -----------------------------------------------------
    # Periodic checkpoint save
    # -----------------------------------------------------
    if (epoch + 1) % 5 == 0:
        # Write to a temporary file and atomically swap it in, so a crash
        # mid-save cannot leave a truncated checkpoint that breaks resuming.
        tmp_checkpoint_path = checkpoint_path + ".tmp"
        torch.save({
            # Number of completed epochs.
            'epoch': epoch + 1,
            'model_state_dict':
                model.state_dict(),
            'optimizer_state_dict':
                optimizer.state_dict(),
            'scheduler_state_dict':
                scheduler.state_dict(),
        }, tmp_checkpoint_path)
        os.replace(tmp_checkpoint_path, checkpoint_path)
        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )
# =========================================================
# 9. FINAL VALIDATION
# =========================================================
# Reported as "Test" below, but computed on the validation split
# (val_loader).
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
print("\nStarting final validation...")
with torch.no_grad():
    for data, target in val_loader:
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion(output, target)
        # criterion returns the batch mean; weight by batch size because the
        # last batch may be smaller and a plain mean of batch means would
        # over-weight it.
        test_loss += loss.item() * target.size(0)
        _, predicted = output.max(1)
        test_total += target.size(0)
        test_correct += predicted.eq(target).sum().item()
final_test_acc = (
    100. * test_correct / test_total
)
print("=== FINAL LOOKTHEM V5 RESULTS ===")
print(
    f"Test Loss: "
    f"{test_loss / test_total:.4f} | "
    f"Test Accuracy: {final_test_acc:.2f}%"
)
# Save final trained weights (state_dict only, no optimizer state).
final_weights_path = "LookThem_V5_Final.pth"
torch.save(
    model.state_dict(),
    final_weights_path
)
print(
    f"Training complete! "
    f"Final model size: "
    f"{os.path.getsize(final_weights_path) / (1024*1024):.2f} MB"
)