| import numpy as np |
| import tensorflow as tf |
| import gradio as gr |
| from huggingface_hub import from_pretrained_keras |
|
|
# Pretrained Keras checkpoints for the teacher and the student model, fetched
# from the Hugging Face Hub by `from_pretrained_keras` when this module runs
# (network I/O at import time). The teacher is loaded first, then the student.
teacher_model, student_model = (
    from_pretrained_keras(repo_id)
    for repo_id in (
        "keras-io/consistency_training_with_supervision_teacher_model",
        "keras-io/consistency_training_with_supervision_student_model",
    )
)
|
|
# Human-readable labels, indexed by the position of the model's highest score
# (see the argmax lookup in the prediction helpers). These are the CIFAR-10
# class names; presumably listed in the dataset's label order -- confirm
# against the training pipeline before reordering.
class_names = [
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]
|
|
# Clickable sample inputs for the Gradio UI. Each entry is the argument list
# for one example run: a single image path relative to the working directory.
examples = [
    [path]
    for path in ("./aeroplane.png", "./horse.png", "./ship.png", "./truck.png")
]

# Side length, in pixels, of the square image given to both models. It is used
# for the Gradio input shape and for the resize in `infer`. Presumably this is
# the resolution the models were trained at -- confirm before changing.
IMG_SIZE = 72
|
|
def _predict_label(model, image_tensor):
    """Return the top-1 class name that ``model`` predicts for one image.

    Shared by ``teacher_model_output`` and ``student_model_output`` so both
    models go through exactly the same post-processing.

    Args:
        model: A loaded Keras model (anything with a compatible ``predict``).
        image_tensor: A single image with no batch axis. ``infer`` passes the
            ``(IMG_SIZE, IMG_SIZE, 3)`` output of ``tf.image.resize``.

    Returns:
        ``class_names[i]`` as a ``str``, where ``i`` is the index of the
        highest score. Assumes output index ``i`` corresponds to
        ``class_names[i]`` -- TODO confirm against the training labels.
    """
    # predict() works on batches, so wrap the single image in a batch of one.
    batch = np.expand_dims(image_tensor, axis=0)
    # Drop the size-1 batch axis so only the per-class scores remain.
    scores = np.squeeze(model.predict(batch))
    return str(class_names[int(np.argmax(scores))])


def teacher_model_output(image_tensor):
    """Return the teacher model's predicted class name for ``image_tensor``."""
    return _predict_label(teacher_model, image_tensor)


def student_model_output(image_tensor):
    """Return the student model's predicted class name for ``image_tensor``."""
    return _predict_label(student_model, image_tensor)
|
|
def infer(input_image):
    """Classify one uploaded image with both the teacher and the student model.

    Args:
        input_image: The image handed over by the Gradio ``Image`` input.
            Presumably an ``(H, W, 3)`` uint8 NumPy array (the
            ``gr.inputs.Image`` default) -- confirm for the installed Gradio.

    Returns:
        A ``(teacher_label, student_label)`` tuple of class-name strings.
    """
    tensor = tf.convert_to_tensor(input_image)
    # Declare the image as height x width x 3 channels; set_shape raises if
    # the actual shape is incompatible (e.g. a 4-channel RGBA image).
    tensor.set_shape([None, None, 3])
    # Bring the image to the square size both models are fed.
    resized = tf.image.resize(tensor, (IMG_SIZE, IMG_SIZE))

    teacher_label = teacher_model_output(resized)
    student_label = student_model_output(resized)
    return teacher_label, student_label
| |
# Gradio I/O components. There is one image input cropped/resized to
# IMG_SIZE x IMG_SIZE, and one Label output per model, in the same
# (teacher, student) order as the tuple `infer` returns.
# NOTE(review): `input` shadows the builtin; the name is kept because the
# Interface construction refers to it.
input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))
output = [
    gr.outputs.Label(label=caption)
    for caption in ("Teacher Model Output", "Student Model Output")
]
|
|
# Static HTML/text shown around the demo. The class list in `description` is
# built from `class_names`, so the UI text cannot drift from the labels the
# models actually report. The rendered string is unchanged.
title = "Image Classification using Consistency training with supervision"
description = (
    "Upload an image or select from examples to classify it.<br>"
    f"The allowed classes are - {', '.join(class_names)}.<br>"
    "<p><b>Teacher Model Repo - "
    "https://huggingface.co/keras-io/consistency_training_with_supervision_teacher_model"
    "</b> <br><b> Student Model Repo - "
    "https://huggingface.co/keras-io/consistency_training_with_supervision_student_model"
    " </b><br><b>Keras Example - "
    "https://keras.io/examples/vision/consistency_training/</b></p>"
)
article = (
    "<div style='text-align: center;'>"
    "<a href='https://twitter.com/_Blazer_007' target='_blank'>Space by Vivek Rai</a>"
    "<br>"
    "<a href='https://twitter.com/RisingSayak' target='_blank'>Keras example by Sayak Paul</a>"
    "</div>"
)
|
|
# Build the demo first and launch it as a separate step, so `gr_interface`
# names the Interface object itself rather than whatever `launch()` returns.
gr_interface = gr.Interface(
    fn=infer,
    inputs=input,
    outputs=output,
    examples=examples,
    allow_flagging=False,  # no "Flag" button in the UI
    analytics_enabled=False,
    title=title,
    description=description,
    article=article,
)

# enable_queue serves requests through Gradio's queue instead of parallel
# threads; debug=True blocks the main thread and surfaces errors in the console.
gr_interface.launch(enable_queue=True, debug=True)