| | import torch |
| | from pathlib import Path |
| | from huggingface_hub import hf_hub_download |
| | from PIL import Image |
| | from torchvision import transforms |
| | from medmnist import INFO |
| | import gradio as gr |
| | import os |
| | import base64 |
| | from io import BytesIO |
| | from huggingface_hub import HfApi |
| | from datetime import datetime |
| | import io |
| |
|
| | from model import resnet18, resnet50 |
| |
|
def _select_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Compute device shared by model loading and inference.
DEVICE = _select_device()

# Runtime configuration read from the environment (each is None when unset).
AUTH_TOKEN = os.getenv("APP_TOKEN")  # password for the Gradio login, if any
DATASET_REPO = os.getenv("Dataset_repo")  # HF dataset repo receiving uploads
HF_TOKEN = os.getenv("HF_TOKEN")  # token used for dataset uploads
MODEL = os.getenv("Model_repo")  # HF repo holding the model checkpoint
| |
|
| | |
def load_model_from_hf(
    repo_id: str,
    filename: str,
    model_type: str,
    num_classes: int,
    in_channels: int,
    device: "str | torch.device",
) -> torch.nn.Module:
    """Load a trained ResNet checkpoint from the Hugging Face Hub.

    The checkpoint file must be a dict containing a ``"model_state_dict"``
    entry; it is read with ``weights_only=True`` so arbitrary pickled objects
    are refused.

    Args:
        repo_id: Hugging Face repository ID holding the checkpoint.
        filename: Checkpoint filename inside the repository.
        model_type: Architecture to build, ``"resnet18"`` or ``"resnet50"``.
        num_classes: Number of output classes of the classifier head.
        in_channels: Number of input image channels.
        device: Device (name or ``torch.device``) to map and move weights to.

    Returns:
        The model with loaded weights, on ``device``, in eval mode.

    Raises:
        ValueError: If ``model_type`` is not a supported architecture.
    """
    builders = {"resnet18": resnet18, "resnet50": resnet50}
    # Validate before downloading so a typo does not cost a download and
    # does not silently fall back to a different architecture.
    if model_type not in builders:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {sorted(builders)}"
        )

    print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

    model = builders[model_type](num_classes=num_classes, in_channels=in_channels)

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model
| | |
| |
|
| |
|
| | |
def get_preprocessing_pipeline(data_flag: str = "organamnist") -> transforms.Compose:
    """Build the image preprocessing pipeline used before inference.

    Resizes the shorter side to 256, centre-crops to 224x224, converts to a
    tensor in [0, 1] and normalizes every channel with mean 0.5 / std 0.5,
    i.e. maps values to [-1, 1].

    Args:
        data_flag: MedMNIST dataset key in ``medmnist.INFO``; its
            ``n_channels`` determines how many normalization values are used.

    Returns:
        A ``transforms.Compose`` applying the steps above to a PIL image.
    """
    n_channels = INFO[data_flag]["n_channels"]

    mean = (0.5,) * n_channels
    std = (0.5,) * n_channels

    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
def get_class_labels(data_flag: str = "organamnist") -> dict[str, str]:
    """Return the class-index -> label-name mapping for a MedMNIST dataset.

    Args:
        data_flag: MedMNIST dataset key in ``medmnist.INFO``.

    Returns:
        The ``INFO[data_flag]["label"]`` mapping. Keys are class indices as
        strings, so callers look labels up with ``str(class_index)``.
    """
    return INFO[data_flag]["label"]
| |
|
def save_image_to_hf_folder(image_path, prediction_label):
    """Upload an image plus a prediction-metadata file to the HF dataset repo.

    Both files go under ``uploads/`` in the ``DATASET_REPO`` dataset
    repository, prefixed with a local-time ``YYYYmmdd_HHMMSS`` timestamp, and
    are committed together in a single commit, so an image is never stored
    without its metadata.

    Args:
        image_path: Local path of the uploaded image file.
        prediction_label: Predicted class label written into the metadata file.
    """
    # Imported here to keep this change self-contained; same package as HfApi.
    from huggingface_hub import CommitOperationAdd

    if not DATASET_REPO:
        # No target repository configured (e.g. local development): the upload
        # could only fail, so skip it instead of breaking the classification.
        print(f"Dataset_repo is not set; skipping upload of {image_path}")
        return

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    source = Path(image_path)
    prefix = f"uploads/{timestamp}_"
    metadata = f"prediction: {prediction_label}\ntimestamp: {timestamp}"

    # NOTE(review): with second-resolution timestamps, two images sharing a
    # file name uploaded within the same second map to the same repo path.
    operations = [
        CommitOperationAdd(
            path_in_repo=f"{prefix}{source.name}",
            path_or_fileobj=image_path,
        ),
        CommitOperationAdd(
            path_in_repo=f"{prefix}{source.stem}_metadata.txt",
            path_or_fileobj=metadata.encode(),
        ),
    ]

    HfApi().create_commit(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        operations=operations,
        commit_message=f"Upload {source.name} (prediction: {prediction_label})",
        token=HF_TOKEN,
    )
| |
|
def classify_images(images) -> str:
    """Classify uploaded images and render the results as an HTML gallery.

    Each image is converted to single-channel grayscale, preprocessed with the
    module-level ``preprocess`` pipeline and passed to ``model``. The arg-max
    class is mapped to a name through ``class_labels``. The image is embedded
    in the HTML as a base64 JPEG and uploaded with its predicted label via
    ``save_image_to_hf_folder``.

    Args:
        images: A single file path, a list of file paths, or ``None``, as
            supplied by the ``gr.File`` component.

    Returns:
        An HTML string with one card per image, or a short message when no
        images were supplied.
    """
    # Local import: the module has no top-level `html` import.
    from html import escape

    if not images:
        return "<p>No images uploaded</p>"

    if isinstance(images, str):
        images = [images]

    cards = []
    for image_path in images:
        # Close the file handle promptly; convert() returns an in-memory copy.
        with Image.open(image_path) as source:
            img = source.convert("L")

        # The input must live on the same device as the model (CUDA/MPS/CPU).
        input_tensor = preprocess(img).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            logits = model(input_tensor)
        probs = torch.nn.functional.softmax(logits[0], dim=0)
        top_class = probs.argmax().item()

        label = class_labels[str(top_class)]

        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        # File names come from the uploader: escape them before putting them
        # into HTML to prevent markup/script injection.
        safe_filename = escape(Path(image_path).name)
        safe_label = escape(label)

        cards.append(f"""
        <div style='border: 2px solid #ddd; padding: 15px; border-radius: 8px; background: #f9f9f9; width: 280px;'>
            <p style='font-size: 14px; color: #666; margin: 0 0 10px 0; text-align: center; font-weight: bold;'>{safe_filename}</p>
            <img src='data:image/jpeg;base64,{img_str}' style='width: 250px; height: 250px; object-fit: contain; display: block; margin: 0 auto 10px;'>
            <p style='font-size: 18px; color: #0066cc; margin: 10px 0 0 0; text-align: center; font-weight: bold;'>{safe_label}</p>
        </div>
        """)

        save_image_to_hf_folder(image_path, label)

    return (
        "<div style='display: flex; flex-wrap: wrap; gap: 30px; padding: 20px; justify-content: center;'>"
        + "".join(cards)
        + "</div>"
    )
| |
|
| | |
| |
|
| | |
# Fail fast with a readable message rather than an opaque Hub validation
# error when the model repository is not configured.
if not MODEL:
    raise RuntimeError(
        "Environment variable 'Model_repo' must name the Hugging Face model repository."
    )

preprocess = get_preprocessing_pipeline()
class_labels = get_class_labels()

model = load_model_from_hf(
    repo_id=MODEL,
    filename="resnet18_best.pth",
    model_type="resnet18",
    # Size the head from the label map so predicted indices always have a name.
    num_classes=len(class_labels),
    # classify_images feeds grayscale ("L" mode) images, i.e. one channel.
    in_channels=1,
    device=DEVICE,
)
| | |
# Gradio UI: a consent notice, a multi-file image upload, Classify/Reset
# buttons, and an HTML area showing classify_images results.
with gr.Blocks() as demo:

    gr.Markdown("<h1 style='text-align: center;'> MLOps project - MedMNIST dataset Image Classifier</h1>")

    gr.Markdown(
        "This is a Gradio web application for the MLOps course project. Uploaded images are stored in our dataset. "
        "By uploading images you agree that they will be stored by us and confirm that you have the right to share them. "
        "If you somehow got past the login and are not connected to the project, please do not upload any images. "
    )

    with gr.Column():

        gr.Markdown("<h2 style='text-align: center;'> Upload Images</h2>")

        images_input = gr.File(file_count="multiple", file_types=["image"], label="Upload Images")

        with gr.Row():
            submit_btn = gr.Button("Classify")
            reset_btn = gr.Button("Reset")

        gr.Markdown("<h2 style='text-align: center;'> Results</h2>")

        output = gr.HTML(label="Results")

    def reset():
        """Clear the uploaded files and the results panel."""
        return None, ""

    submit_btn.click(classify_images, inputs=images_input, outputs=output)
    reset_btn.click(reset, outputs=[images_input, output])
| |
|
| |
|
| | |
# Interface to bind; defaults to loopback unless GRADIO_SERVER_NAME is set.
server_name = os.getenv("GRADIO_SERVER_NAME", "127.0.0.1")

# Require a login (fixed username "user") only when APP_TOKEN is configured.
if AUTH_TOKEN:
    demo.launch(server_name=server_name, auth=[("user", AUTH_TOKEN)])
else:
    demo.launch(server_name=server_name, auth=None)