| import random | |
| import gradio as gr | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from torch import nn | |
| from torchvision.models import mobilenet_v2, resnet18 | |
| from torchvision.transforms.functional import InterpolationMode | |
# Number of target classes for every dataset served by the demo.
datasets_n_classes = dict(Imagenette=10, Imagewoof=10, Stanford_dogs=120)

# Checkpoint variants available per dataset. Every dataset gets its own list
# object; "augment_clean_200" is listed for all datasets except Stanford_dogs.
datasets_model_types = {
    name: [
        "base_200",
        "base_200+100",
        "synthetic_200",
        "augment_noisy_200",
        "augment_noisy_200+100",
    ]
    + ([] if name == "Stanford_dogs" else ["augment_clean_200"])
    for name in datasets_n_classes
}
# Architectures for which checkpoints exist.
model_arch = [
    "resnet18",
    "mobilenet_v2",
]

# UI "Methods" choice -> "Training Dataset" choices offered for it.
methods_map = {
    "200 Epochs": [
        "Original",
        "Synthetic",
        "Original + Synthetic (Noisy)",
        "Original + Synthetic (Clean)",
    ],
    "200 Epochs on Original + 100": ["Base+100", "AugmentNoisy+100"],
}

# Direct aliases of the lists above (same objects, not copies).
list_200 = methods_map["200 Epochs"]
list_200_100 = methods_map["200 Epochs on Original + 100"]
# Translates every label shown in the UI into the identifier used to build
# model keys and file paths.
label_map = {
    # Datasets
    "Imagenette (10 classes)": "Imagenette",
    "Imagewoof (10 classes)": "Imagewoof",
    "Stanford Dogs (120 classes)": "Stanford_dogs",
    # Architectures
    "ResNet-18": "resnet18",
    "MobileNetV2": "mobilenet_v2",
    # Methods
    "200 Epochs": "200",
    "200 Epochs on Original + 100": "200+100",
    # Training datasets offered for "200 Epochs"
    "Original": "base",
    "Synthetic": "synthetic",
    "Original + Synthetic (Noisy)": "augment_noisy",
    "Original + Synthetic (Clean)": "augment_clean",
    # Training datasets offered for "200 Epochs on Original + 100"
    "Base+100": "base",
    "AugmentNoisy+100": "augment_noisy",
}
# dataset -> {"<arch>_<model_type>": (model, checkpoint_path)}.
# The networks are created here with weights=None; their checkpoints are only
# read from disk later, by load_model().
dataset_models = {}
for dataset, n_classes in datasets_n_classes.items():
    models = {}
    for model_type in datasets_model_types[dataset]:
        for arch in model_arch:
            # Reject anything we do not know how to build before touching it.
            if arch not in ("resnet18", "mobilenet_v2"):
                raise ValueError(f"Model architecture unavailable: {arch}")
            if arch == "resnet18":
                model = resnet18(weights=None, num_classes=n_classes)
            else:
                model = mobilenet_v2(weights=None, num_classes=n_classes)
            models[f"{arch}_{model_type}"] = (
                model,
                f"./models/{arch}/{dataset}/{dataset}_{model_type}.pth",
            )
    dataset_models[dataset] = models
def get_random_image(dataset, label_map=label_map) -> Image.Image:
    """Return a random validation image of ``dataset`` resized to 256x256.

    Args:
        dataset: Dataset name as shown in the UI; it is translated to the
            on-disk folder name through ``label_map``.
        label_map: Mapping from UI labels to internal identifiers.

    Returns:
        A 256x256 PIL image drawn at random from ``./data/<dataset>/val``.

    Raises:
        KeyError: If ``dataset`` is not a key of ``label_map``.
    """
    dataset_root = f"./data/{label_map[dataset]}/val"
    # With no transform, ImageFolder hands back the loader's PIL image
    # directly, so the former PIL -> tensor -> PIL round trip is unnecessary.
    # NOTE(review): relies on torchvision's default PIL image loader.
    dataset_img = torchvision.datasets.ImageFolder(dataset_root)
    image, _ = dataset_img[random.randrange(len(dataset_img))]
    return image.resize((256, 256))
def load_model(model_dict, model_name: str) -> nn.Module:
    """Load the checkpoint registered for ``model_name`` into its model.

    Args:
        model_dict: Mapping of lower-case model key to a
            ``(model, checkpoint_path)`` pair.
        model_name: Model key; it is lower-cased before the lookup.

    Returns:
        The model object stored in ``model_dict`` (the same instance, not a
        copy) after its weights were loaded from the checkpoint file.

    Raises:
        ValueError: If the lower-cased ``model_name`` is not in ``model_dict``.
    """
    key = model_name.lower()
    if key not in model_dict:
        raise ValueError(
            f"Model {model_name} is not available for image prediction. Please choose from {[name.capitalize() for name in model_dict.keys()]}."
        )

    entry = model_dict[key]
    model = entry[0]
    model_path = entry[1]

    # Without CUDA every tensor is remapped to the CPU; with CUDA the
    # checkpoint keeps whatever device placement it was saved with.
    map_location = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(model_path, map_location=map_location)

    # A checkpoint without a "setup" entry is treated as a bare state dict.
    if "setup" not in checkpoint:
        model.load_state_dict(checkpoint)
        return model

    # Otherwise the weights live under "model"; for distributed runs the
    # "module." key prefix is removed in place before loading.
    if checkpoint["setup"]["distributed"]:
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint["model"], "module."
        )
    model.load_state_dict(checkpoint["model"])
    return model
def postprocess_default(labels, output, top_k: int = 5) -> dict:
    """Turn raw model outputs into a ``{label: probability}`` mapping.

    Only the first row of ``output`` is used; it is soft-maxed and the most
    probable classes are kept.

    Args:
        labels: Object indexable by class index that yields the class name.
        output: Tensor of raw scores whose first row holds one score per class.
        top_k: Maximum number of classes to return. Clamped to the number of
            available classes so that small label sets do not raise.

    Returns:
        Dict mapping class name to probability, ordered from most to least
        probable.
    """
    probabilities = nn.functional.softmax(output[0], dim=0)
    # torch.topk raises when k exceeds the number of entries.
    k = min(top_k, probabilities.shape[0])
    top_prob, top_catid = torch.topk(probabilities, k)
    # Convert each tensor once instead of once per comprehension step.
    return dict(
        (labels[class_idx], prob)
        for class_idx, prob in zip(top_catid.tolist(), top_prob.tolist())
    )
def classify(
    input_image: Image,
    dataset_type: str,
    arch_type: str,
    methods: str,
    training_ds: str,
    dataset_models=dataset_models,
    label_map=label_map,
) -> dict:
    """Classify ``input_image`` with the model selected in the UI.

    Args:
        input_image: Image to classify.
        dataset_type: UI label of the dataset.
        arch_type: UI label of the model architecture.
        methods: UI label of the training schedule.
        training_ds: UI label of the training dataset.
        dataset_models: Per-dataset registry of ``(model, checkpoint_path)``.
        label_map: Mapping from UI labels to internal identifiers.

    Returns:
        The ``{class name: probability}`` mapping produced by
        ``postprocess_default``.

    Raises:
        ValueError: If any option is ``None``, if no image was provided, or
            (from ``load_model``) if the selected model is not registered.
        KeyError: If an option is not a key of ``label_map``.
    """
    if any(
        choice is None for choice in (dataset_type, arch_type, methods, training_ds)
    ):
        raise ValueError("Please select all options.")

    # Translate the UI labels into the identifiers used for keys and paths.
    dataset_key = label_map[dataset_type]
    arch_key = label_map[arch_type]
    schedule_key = label_map[methods]
    training_key = label_map[training_ds]

    if input_image is None:
        raise ValueError("No image was provided.")

    # Resize the shorter side to 256, centre-crop to 224, convert to a tensor
    # and normalise per channel.
    resize = transforms.Resize(
        256,
        interpolation=InterpolationMode.BILINEAR,
        antialias=True,
    )
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    preprocess_input = transforms.Compose(
        [resize, transforms.CenterCrop(224), transforms.ToTensor(), normalize]
    )
    input_tensor: torch.Tensor = preprocess_input(input_image)
    input_batch = input_tensor.unsqueeze(0)  # add the batch dimension

    model = load_model(
        dataset_models[dataset_key], f"{arch_key}_{training_key}_{schedule_key}"
    )

    if torch.cuda.is_available():
        input_batch = input_batch.to("cuda")
        model.to("cuda")

    model.eval()
    with torch.inference_mode():
        output: torch.Tensor = model(input_batch)

    # One class name per line; the line number is the class index.
    with open(f"./data/{dataset_key}.txt", "r") as f:
        labels = dict(enumerate(line.strip() for line in f))

    return postprocess_default(labels, output)
def update_methods(method, ds_type):
    """Refresh the "Training Dataset" choices for the current selection.

    Args:
        method: Selected "Methods" label; must be a key of ``methods_map``
            unless the Stanford Dogs special case applies.
        ds_type: Selected "Dataset" label.

    Returns:
        A ``gr.update`` carrying the new choices with the value cleared.
    """
    # For Stanford Dogs on the plain 200-epoch schedule the last entry of
    # list_200 is not offered.
    drop_last = ds_type == "Stanford Dogs (120 classes)" and method == "200 Epochs"
    choices = list_200[:-1] if drop_last else methods_map[method]
    return gr.update(choices=choices, value=None)
def downloadModel(
    dataset_type, arch_type, methods, training_ds, dataset_models=dataset_models
):
    """Point the download button at the checkpoint of the selected model.

    Args:
        dataset_type: UI label of the dataset.
        arch_type: UI label of the model architecture.
        methods: UI label of the training schedule.
        training_ds: UI label of the training dataset.
        dataset_models: Per-dataset registry of ``(model, checkpoint_path)``.

    Returns:
        A ``gr.update`` for the download button. It holds the checkpoint path
        when the selection is complete and registered, and a "Select Model"
        placeholder with no file otherwise.
    """
    placeholder = gr.update(label="Select Model", value=None)

    selections = (dataset_type, arch_type, methods, training_ds)
    if any(choice is None for choice in selections):
        return placeholder

    # The UI labels are translated with the module-level label_map.
    dataset_key, arch_key, schedule_key, training_key = (
        label_map[choice] for choice in selections
    )
    model_key = f"{arch_key}_{training_key}_{schedule_key}"

    registry = dataset_models[dataset_key]
    if model_key not in registry:
        return placeholder

    return gr.update(
        label=f"Download Model: '{dataset_key}_{model_key}'",
        value=registry[model_key][1],
    )
if __name__ == "__main__":
    # NOTE(review): the layout nesting below was reconstructed from statement
    # order; confirm the column/row boundaries against the deployed UI.
    with gr.Blocks(title="Generative Augmented Image Classifiers") as demo:
        gr.Markdown(
            """
            # Generative Augmented Image Classifiers
            Main GitHub Repo: [Generative Data Augmentation](https://github.com/zhulinchng/generative-data-augmentation) | Generative Data Augmentation Demo: [Generative Data Augmented](https://huggingface.co/spaces/czl/generative-data-augmentation-demo).
            """
        )
        with gr.Row():
            # Left column: model selection and the image to classify.
            with gr.Column():
                dataset_type = gr.Radio(
                    choices=[
                        "Imagenette (10 classes)",
                        "Imagewoof (10 classes)",
                        "Stanford Dogs (120 classes)",
                    ],
                    label="Dataset",
                    value="Imagenette (10 classes)",
                )
                arch_type = gr.Radio(
                    choices=["ResNet-18", "MobileNetV2"],
                    label="Model Architecture",
                    value="ResNet-18",
                    interactive=True,
                )
                methods = gr.Radio(
                    label="Methods",
                    choices=["200 Epochs", "200 Epochs on Original + 100"],
                    interactive=True,
                    value="200 Epochs",
                )
                training_ds = gr.Radio(
                    label="Training Dataset",
                    choices=methods_map["200 Epochs"],
                    interactive=True,
                    value="Original",
                )
                # Changing the dataset or the method refreshes the
                # training-dataset choices.
                for trigger in (dataset_type, methods):
                    trigger.change(
                        fn=update_methods,
                        inputs=[methods, dataset_type],
                        outputs=[training_ds],
                    )
                random_image_output = gr.Image(type="pil", label="Image to Classify")
                with gr.Row():
                    generate_button = gr.Button("Sample Random Image")
                    classify_button_random = gr.Button("Classify")
            # Right column: prediction, checkpoint download and notes.
            with gr.Column():
                output_label_random = gr.Label(num_top_classes=5)
                # The button starts out pointing at the default selection.
                initial_dataset = label_map[dataset_type.value]
                initial_model = (
                    f"{label_map[arch_type.value]}"
                    f"_{label_map[training_ds.value]}"
                    f"_{label_map[methods.value]}"
                )
                download_model = gr.DownloadButton(
                    label=f"Download Model: '{initial_dataset}_{initial_model}'",
                    value=dataset_models[initial_dataset][initial_model][1],
                )
                # Any change of a selector re-targets the download button.
                for selector in (dataset_type, arch_type, methods, training_ds):
                    selector.change(
                        fn=downloadModel,
                        inputs=[dataset_type, arch_type, methods, training_ds],
                        outputs=[download_model],
                    )
                gr.Markdown(
                    """
                    This demo showcases the performance of image classifiers trained on various datasets as part of the project 'Improving Fine-Grained Image Classification Using Diffusion-Based Generated Synthetic Images' dissertation.
                    View the models and files used in this demo [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/tree/main).
                    Usage Instructions & Documentation [here](https://huggingface.co/spaces/czl/generative-augmented-classifiers/blob/main/README.md).
                    """
                )
        generate_button.click(
            get_random_image,
            inputs=[dataset_type],
            outputs=random_image_output,
        )
        classify_button_random.click(
            classify,
            inputs=[random_image_output, dataset_type, arch_type, methods, training_ds],
            outputs=output_label_random,
        )
    demo.launch(show_error=True)