""" Gradio App for Bird Species Classification Deployed on Hugging Face Spaces """ import gradio as gr import torch import torch.nn as nn from torchvision import transforms from torchvision.models import convnext_base from PIL import Image import json # Load class names with open('class_names.json', 'r') as f: class_names = json.load(f) # Device configuration device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Create model architecture (same as training) def create_model(num_classes=200): """Create ConvNeXt model with same architecture as training""" model = convnext_base(weights=None) # Same classifier architecture as training num_ftrs = model.classifier[2].in_features model.classifier = nn.Sequential( nn.Flatten(1), nn.LayerNorm((num_ftrs,)), nn.Dropout(0.6), nn.Linear(num_ftrs, 512), nn.GELU(), nn.Dropout(0.5), nn.Linear(512, num_classes) ) return model # Load the trained model print("Loading model...") model = create_model(num_classes=200) # Load weights import gradio as gr import torch import torch.nn as nn from torchvision import transforms from torchvision.models import convnext_base from PIL import Image import json from huggingface_hub import hf_hub_download # Load class names with open('class_names.json', 'r') as f: class_names = json.load(f) # Device configuration device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Create model architecture (same as training) def create_model(num_classes=200): """Create ConvNeXt model with same architecture as training""" model = convnext_base(weights=None) # Same classifier architecture as training num_ftrs = model.classifier[2].in_features model.classifier = nn.Sequential( nn.Flatten(1), nn.LayerNorm((num_ftrs,)), nn.Dropout(0.6), nn.Linear(num_ftrs, 512), nn.GELU(), nn.Dropout(0.5), nn.Linear(512, num_classes) ) return model # Download model from Hugging Face Model Hub print("Downloading model from Hugging Face Model Hub...") model_path = hf_hub_download( repo_id="AshProg/bird-classifier-convnext", filename="final_model.pth" ) # Load the trained model model = create_model(num_classes=200) checkpoint = torch.load(model_path, map_location=device) if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) if 'val_acc' in checkpoint: val_acc = checkpoint['val_acc'] print(f"Model loaded! Validation accuracy: {val_acc:.2f}%") else: model.load_state_dict(checkpoint) print("Model loaded!") model = model.to(device) model.eval() # Image preprocessing (same as validation transforms) transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def predict(image): """ Make prediction on uploaded image Args: image: PIL Image Returns: dict: Top 5 predictions with confidence scores """ # Preprocess image img_tensor = transform(image).unsqueeze(0).to(device) # Make prediction with torch.no_grad(): outputs = model(img_tensor) probabilities = torch.nn.functional.softmax(outputs, dim=1) # Get top 5 predictions top5_prob, top5_idx = torch.topk(probabilities, 5) # Format results results = {} for i in range(5): class_id = top5_idx[0][i].item() prob = top5_prob[0][i].item() species_name = class_names.get(str(class_id), f"Class {class_id}") results[species_name] = float(prob) return results # Create Gradio interface title = "🐦 Bird Species Classification" description = """ Upload an image of a bird and the model will predict the species! 
# Create Gradio interface
title = "🐦 Bird Species Classification"

description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

Upload a clear image of a bird to get started!
"""

article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference
"""

examples = [
    # You can add example images here if you have them
    # ["examples/bird1.jpg"],
    # ["examples/bird2.jpg"],
]

# Create interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    title=title,
    description=description,
    article=article,
    examples=examples if examples else None,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
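
# Deployment note (assumptions, not verified against this repo's layout): on
# Hugging Face Spaces this script is typically saved as app.py, with
# class_names.json committed alongside it and the imported packages
# (gradio, torch, torchvision, pillow, huggingface_hub) listed in
# requirements.txt. Spaces runs the script and serves the interface started
# by iface.launch().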