| import torch
|
| from PIL import Image
|
| from torchvision import transforms
|
| from model import load_model
|
| import json
|
| import os
|
|
|
class GarbageClassifier:
    """Classify garbage images using a model and config loaded from disk.

    The model directory must contain ``config.json``, which must provide
    ``normalization.mean``, ``normalization.std``, ``input_size`` and
    ``class_names``. It must also contain the weights file
    ``pytorch_model.bin``, which is handed to ``load_model``.
    """

    def __init__(self, model_dir="."):
        """Initialize the garbage classifier.

        Args:
            model_dir: Directory holding ``config.json`` and
                ``pytorch_model.bin``. Defaults to the current directory.

        Raises:
            FileNotFoundError: If ``config.json`` does not exist.
            KeyError: If ``config.json`` lacks one of the keys read below.
        """
        config_path = os.path.join(model_dir, "config.json")
        with open(config_path, encoding="utf-8") as f:
            self.config = json.load(f)

        # Prefer the GPU when available; predict() moves inputs here too.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_path = os.path.join(model_dir, "pytorch_model.bin")
        self.model = load_model(model_path, self.device)
        # load_model is defined elsewhere and may not switch the model to
        # evaluation mode. Do it here so dropout / batch-norm layers behave
        # deterministically at inference time. eval() is idempotent.
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()

        mean = self.config["normalization"]["mean"]
        std = self.config["normalization"]["std"]
        size = tuple(self.config["input_size"])

        # Resize to the configured input size, convert to a tensor, then
        # normalize channel-wise with the mean/std taken from the config.
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        self.class_names = self.config["class_names"]

    @staticmethod
    def _load_rgb_image(image):
        """Return ``image`` as an RGB image, opening it from disk if needed."""
        if hasattr(image, "convert"):
            # Already an in-memory image (e.g. a PIL.Image.Image). convert()
            # returns a new image, so the caller's object is left untouched.
            return image.convert('RGB')
        # Image.open reads lazily and keeps the file handle open. The
        # context manager closes it once convert() has produced a new image
        # holding the pixel data.
        with Image.open(image) as img:
            return img.convert('RGB')

    def predict(self, image_path):
        """
        Predict the class of a garbage image.

        Args:
            image_path: Path to the image file, or anything ``Image.open``
                accepts (such as a binary file object). An already-loaded
                image object exposing ``convert`` (e.g. a PIL image) is also
                accepted and used directly.

        Returns:
            dict: Contains 'class' (the predicted class name), 'confidence'
            (the softmax probability of that class) and 'all_probabilities'
            (each class name mapped to its softmax probability).
        """
        image = self._load_rgb_image(image_path)
        # Add a leading batch dimension of size 1, then move to the device.
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        # Copy the single row of probabilities to the host once, instead of
        # calling .item() per class (each .item() forces a device sync).
        probs = probabilities[0].tolist()
        all_probs = {
            name: probs[i] for i, name in enumerate(self.class_names)
        }

        return {
            "class": self.class_names[predicted.item()],
            "confidence": confidence.item(),
            "all_probabilities": all_probs,
        }
|
|
|
|
|
if __name__ == "__main__":
    # Local import: only the command-line entry point needs argument parsing.
    import argparse

    parser = argparse.ArgumentParser(description="Classify a garbage image.")
    # Keep the old placeholder as the default so a no-argument run behaves
    # exactly as before.
    parser.add_argument(
        "image",
        nargs="?",
        default="path/to/image.jpg",
        help="Path to the image file to classify.",
    )
    parser.add_argument(
        "--model-dir",
        default=".",
        help="Directory containing config.json and pytorch_model.bin.",
    )
    args = parser.parse_args()

    classifier = GarbageClassifier(args.model_dir)
    result = classifier.predict(args.image)
    print(f"Predicted class: {result['class']}")
    print(f"Confidence: {result['confidence']:.2%}")
|
|
|