import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import onnxruntime
import gradio as gr
import json

def get_image(x):
    # Return the first field of a comma-separated entry
    return x.split(', ')[0]


def to_numpy(tensor):
    # Detach (if needed) and move a tensor to CPU as a NumPy array
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Convert a NumPy image array into a normalized input tensor for the model
def transform_image(myarray):
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.fromarray(np.uint8(myarray)).convert('RGB')
    image = transform(image).unsqueeze(0)  # add batch dimension
    return image

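# Illustrative check (not part of the original app): for any RGB input array,
# transform_image should yield a (1, 3, 224, 224) batch tensor, e.g.
#   transform_image(np.zeros((200, 200, 3), dtype=np.uint8)).shape == torch.Size([1, 3, 224, 224])
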
# Load the ImageNet class index -> label mapping
with open('imagenet_label.json') as f:
    label_map = json.load(f)

# Load the list of images used for similarity search
with open('img_list.txt', 'r') as f:
    sub_test_list = [i.strip() for i in f]

# Load precomputed image embeddings for similarity search
embeddings = torch.load('embeddings.pt')

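# Illustrative sketch (not used by the app below): assuming `embeddings` is an
# (N, D) float tensor whose rows line up with the entries of `sub_test_list`,
# a top-k nearest-image lookup by cosine similarity could look like this.
# The function name `get_similar_images` is hypothetical, not part of the original app.
def get_similar_images(query_embedding, k=5):
    # query_embedding: a 1-D tensor of length D for the query image
    sims = F.cosine_similarity(embeddings, query_embedding.unsqueeze(0), dim=1)
    top = torch.topk(sims, k=k).indices.tolist()
    return [sub_test_list[i] for i in top]
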
# Configure ONNX Runtime session threading
options = onnxruntime.SessionOptions()
options.intra_op_num_threads = 8
options.inter_op_num_threads = 8

# Load the exported ONNX model
PATH = 'model_onnx.onnx'
ort_session = onnxruntime.InferenceSession(PATH, sess_options=options)
input_name = ort_session.get_inputs()[0].name
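# Optional sanity check (illustrative): inspect the model's expected input, e.g.
#   print(ort_session.get_inputs()[0].name, ort_session.get_inputs()[0].shape)
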
# Predict the top-5 ImageNet classes for an input image
def get_classification(img):
    image_tensor = transform_image(img)
    ort_inputs = {input_name: to_numpy(image_tensor)}
    x = ort_session.run(None, ort_inputs)
    logits = torch.from_numpy(x[0])
    probs = F.softmax(logits, dim=1)  # turn raw logits into probabilities for the Label output
    predictions = torch.topk(probs, k=5).indices.squeeze(0).tolist()
    result = {}
    for i in predictions:
        label = label_map[str(i)]
        result[label] = probs[0, i].item()
    return result
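# Example usage (hypothetical file name): get_classification expects a NumPy image array, e.g.
#   get_classification(np.array(Image.open('example.jpg')))
# and returns a dict mapping the top-5 labels to their probabilities.
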
iface = gr.Interface(
    get_classification,
    gr.inputs.Image(shape=(200, 200)),
    outputs="label",
    title='Universal Image Classification',
    description="ImageNet classification with MobileNetV3 converted to ONNX Runtime",
    article="Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>.",
)
iface.launch()