DamageLensAI / scripts /prediction_helper.py
junaid17's picture
Upload 43 files
eef8873 verified
import os
import logging
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image, UnidentifiedImageError
from transformers import ConvNextModel, ConvNextImageProcessor
# ---------------- LOGGING SETUP ----------------
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)
# ---------------- RESNET MODEL ----------------
class Car_Classifier_Resnet(nn.Module):
def __init__(self, num_classes):
super().__init__()
logger.info("Initializing ResNet18 architecture...")
self.model = models.resnet18(weights="DEFAULT")
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.layer3.parameters():
param.requires_grad = True
for param in self.model.layer4.parameters():
param.requires_grad = True
self.model.fc = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(self.model.fc.in_features, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, num_classes)
)
logger.info("ResNet architecture initialized successfully.")
def forward(self, x):
return self.model(x)
class ResnetCarDamagePredictor:
def __init__(self, checkpoint_path, class_map):
logger.info("Initializing ResNet predictor...")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.class_map = class_map
logger.info(f"Using device for ResNet: {self.device}")
self.test_transforms = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]
)
])
try:
self.model = Car_Classifier_Resnet(num_classes=len(class_map))
logger.info(f"Loading ResNet checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
state_dict = checkpoint.get("model_state_dict", checkpoint)
self.model.load_state_dict(state_dict)
self.model.to(self.device)
self.model.eval()
logger.info("ResNet model loaded successfully.")
except Exception as e:
logger.exception("Failed to load ResNet model.")
raise RuntimeError(f"Failed to load ResNet model: {str(e)}")
def resnet_predict(self, image_input):
logger.info("Starting ResNet prediction...")
try:
if isinstance(image_input, str):
logger.info(f"Loading image from file path: {image_input}")
image = Image.open(image_input).convert("RGB")
elif isinstance(image_input, Image.Image):
logger.info("Using PIL image input.")
image = image_input.convert("RGB")
else:
raise TypeError("image_input must be a file path or PIL.Image")
image = self.test_transforms(image)
image = image.unsqueeze(0).to(self.device)
with torch.no_grad():
outputs = self.model(image)
probs = torch.nn.functional.softmax(outputs, dim=1)[0]
class_probs = {
self.class_map[i]: float(probs[i].item())
for i in range(len(self.class_map))
}
sorted_probs = dict(
sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
)
logger.info("ResNet prediction completed successfully.")
return sorted_probs
except UnidentifiedImageError:
logger.error("Invalid image file provided to ResNet predictor.")
raise ValueError("Invalid image file provided")
except Exception as e:
logger.exception("ResNet prediction failed.")
raise RuntimeError(f"ResNet prediction failed: {str(e)}")
# ---------------- FUSION MODEL ----------------
class FusionClassifier(nn.Module):
def __init__(self, num_classes, convnext_model_name="facebook/convnext-small-224"):
super().__init__()
logger.info("Initializing Fusion model architecture...")
eff = models.efficientnet_v2_s(
weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
)
for param in eff.parameters():
param.requires_grad = False
for param in eff.features[5].parameters():
param.requires_grad = True
for param in eff.features[6].parameters():
param.requires_grad = True
for param in eff.features[7].parameters():
param.requires_grad = True
self.eff_features = eff.features
self.eff_avgpool = eff.avgpool
self.eff_out_dim = eff.classifier[1].in_features
logger.info("Loading ConvNeXt backbone...")
cnx = ConvNextModel.from_pretrained(convnext_model_name)
for param in cnx.parameters():
param.requires_grad = False
for param in cnx.encoder.stages[2].parameters():
param.requires_grad = True
for param in cnx.encoder.stages[3].parameters():
param.requires_grad = True
for param in cnx.layernorm.parameters():
param.requires_grad = True
self.cnx_backbone = cnx
self.cnx_out_dim = 768
fused_dim = self.eff_out_dim + self.cnx_out_dim
self.fusion_head = nn.Sequential(
nn.Dropout(p=0.4),
nn.Linear(fused_dim, 512),
nn.LayerNorm(512),
nn.GELU(),
nn.Dropout(p=0.3),
nn.Linear(512, 256),
nn.LayerNorm(256),
nn.GELU(),
nn.Dropout(p=0.2),
nn.Linear(256, num_classes)
)
logger.info("Fusion architecture initialized successfully.")
def forward(self, pixel_values_eff, pixel_values_cnx):
x_eff = self.eff_features(pixel_values_eff)
x_eff = self.eff_avgpool(x_eff)
x_eff = torch.flatten(x_eff, 1)
cnx_out = self.cnx_backbone(
pixel_values=pixel_values_cnx,
return_dict=True
)
x_cnx = cnx_out.pooler_output
fused = torch.cat([x_eff, x_cnx], dim=1)
return self.fusion_head(fused)
class FusionCarDamagePredictor:
def __init__(self, checkpoint_path, class_map, convnext_model_name="facebook/convnext-small-224"):
logger.info("Initializing Fusion predictor...")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.class_map = class_map
logger.info(f"Using device for Fusion: {self.device}")
self.eff_normalize = transforms.Compose([
transforms.Resize((260, 260)),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]
)
])
logger.info("Loading ConvNeXt image processor...")
self.convnext_processor = ConvNextImageProcessor.from_pretrained(
convnext_model_name
)
try:
self.model = FusionClassifier(
num_classes=len(class_map),
convnext_model_name=convnext_model_name
)
logger.info(f"Loading Fusion checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
state_dict = checkpoint.get("model_state_dict", checkpoint)
first_tensor = next(iter(state_dict.values()))
if first_tensor.dtype == torch.float16:
logger.info("FP16 checkpoint detected. Converting model to half precision.")
self.model = self.model.half()
self.model.load_state_dict(state_dict)
self.model.to(self.device)
self.model.eval()
logger.info("Fusion model loaded successfully.")
except Exception as e:
logger.exception("Failed to load Fusion model.")
raise RuntimeError(f"Failed to load Fusion model: {str(e)}")
def predict(self, image_input):
logger.info("Starting Fusion prediction...")
try:
if isinstance(image_input, str):
logger.info(f"Loading image from file path: {image_input}")
image = Image.open(image_input).convert("RGB")
elif isinstance(image_input, Image.Image):
logger.info("Using PIL image input.")
image = image_input.convert("RGB")
else:
raise TypeError("image_input must be a file path or PIL.Image")
pixel_eff = self.eff_normalize(image)
pixel_eff = pixel_eff.unsqueeze(0).to(self.device)
inputs_cnx = self.convnext_processor(
images=image,
return_tensors="pt"
)
pixel_cnx = inputs_cnx["pixel_values"].to(self.device)
if next(self.model.parameters()).dtype == torch.float16:
logger.info("Converting input tensors to FP16.")
pixel_eff = pixel_eff.half()
pixel_cnx = pixel_cnx.half()
with torch.no_grad():
logits = self.model(pixel_eff, pixel_cnx)
probs = torch.nn.functional.softmax(logits, dim=1)[0]
class_probs = {
self.class_map[i]: float(probs[i].item())
for i in range(len(self.class_map))
}
sorted_probs = dict(
sorted(class_probs.items(), key=lambda x: x[1], reverse=True)
)
logger.info("Fusion prediction completed successfully.")
return sorted_probs
except UnidentifiedImageError:
logger.error("Invalid image file provided to Fusion predictor.")
raise ValueError("Invalid image file provided")
except Exception as e:
logger.exception("Fusion prediction failed.")
raise RuntimeError(f"Fusion prediction failed: {str(e)}")