import os
import subprocess
import sys

# Pin the Gradio version this app is written against (it uses the 2.x
# `gr.inputs` / `gr.outputs` API below). Install it into the interpreter that
# is actually running this script, and fail loudly if pip fails instead of
# silently continuing (os.system's exit status was ignored).
# This must run before `import gradio`.
subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio==2.7.5.2"])

import urllib

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Fetch the example image offered in the interface's `examples` list.
# Skip the network fetch when the file is already on disk (e.g. after a
# restart), so startup does not depend on the remote host being reachable.
if not os.path.exists('bird.png'):
    torch.hub.download_url_to_file(
        'https://static.scientificamerican.com/sciam/cache/file/7A715AD8-449D-4B5A-ABA2C5D92D9B5A21_source.png',
        'bird.png',
    )

# Pretrained NTS-Net (200 output classes) fetched from torch.hub, on CPU.
# Loaded exactly once: the script previously issued a second, identical
# torch.hub.load that wasted time and memory and rebound `model`.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
)
# Put the network in inference mode before serving predictions (the upstream
# usage example does this). Without it, training-time behaviour stays active.
# NOTE(review): the backbone presumably contains BatchNorm, which would then use
# per-batch statistics on a batch of one.
model.eval()

# Test-time preprocessing: resize to 600x600 (bilinear), center-crop to
# 448x448, convert to a float tensor, then normalise each of the 3 channels
# with the usual ImageNet mean/std.
transform_test = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

def birds(img):
    """Classify a bird image and return the predicted class name.

    Args:
        img: a PIL image (the Gradio input is configured with type='pil'),
            or None when the form is submitted without an image.

    Returns:
        The entry of ``model.bird_classes`` for the top-scoring class, with
        everything up to and including its first '.' removed. If no image
        was given, a short prompt asking the user to upload one.
    """
    if img is None:
        return "Please upload an image."

    # transform_test normalises exactly three channels, so RGBA, grayscale
    # or palette images must be converted first (a no-op copy for RGB input).
    scaled_img = transform_test(img.convert('RGB'))
    torch_images = scaled_img.unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        # The model returns a 7-tuple; only concat_logits (4th item) is used.
        _, _, _, concat_logits, _, _, _ = model(torch_images)

    pred_id = concat_logits.argmax(dim=1).item()
    # Class names are presumably "<number>.<name>"; keep the part after the
    # first dot only, so a name that itself contains a dot is not truncated.
    return model.bird_classes[pred_id].split('.', 1)[1]

# Gradio 2.x I/O components: a PIL image in, a single text label out.
inputs = gr.inputs.Image(type='pil', label="Original Image")
outputs = gr.outputs.Textbox(label="bird class")

# Page text shown above and below the demo.
title = "ntsnet"
description = (
    "demo for ntsnet to classify birds. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf'>"
    "Are These Birds Similar: Learning Branched Networks for Fine-grained Representations</a>"
    " | "
    "<a href='https://github.com/nicolalandro/ntsnet-cub200'>Github Repo</a>"
    "</p>"
)

# One example row per clickable sample; each row lists the input values.
examples = [
    ['bird.png'],
]

demo = gr.Interface(
    birds,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch(cache_examples=True, enable_queue=True)