# Author: AshProg
# Last commit: "Update app.py" (827d168, verified)
"""
Gradio App for Bird Species Classification
Deployed on Hugging Face Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
# Load class names. predict() looks entries up with str(class_id), so this is
# presumably a JSON object keyed by class index as a string — TODO confirm.
# JSON is UTF-8 by spec; pass the encoding explicitly instead of relying on the
# platform's locale default.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
# Device configuration: use the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create model architecture (same as training)
def create_model(num_classes=200):
    """Build a ConvNeXt-Base with the custom classification head used in training.

    The backbone is created without pretrained weights (``weights=None``); the
    trained parameters are loaded into it afterwards with ``load_state_dict``.
    The head stays an unnamed ``nn.Sequential`` so its index-based state_dict
    keys (e.g. ``classifier.1.weight``, ``classifier.3.weight``,
    ``classifier.6.weight``) line up with the saved checkpoint.

    Args:
        num_classes: Number of output logits (200 for CUB-200-2011).

    Returns:
        The ConvNeXt-Base module with its ``classifier`` replaced.
    """
    backbone = convnext_base(weights=None)
    # Width of the features entering the stock final Linear layer.
    feature_dim = backbone.classifier[2].in_features
    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone
# Load the trained model
print("Loading model...")
# The model is built and its trained weights restored further down, once the
# checkpoint has been downloaded from the Hub. Constructing a ConvNeXt-Base here
# as well would only allocate a throwaway copy that is immediately replaced.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# Load class names. predict() looks entries up with str(class_id), so this is
# presumably a JSON object keyed by class index as a string — TODO confirm.
# JSON is UTF-8 by spec; pass the encoding explicitly instead of relying on the
# platform's locale default.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
# Device configuration: use the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create model architecture (same as training)
def create_model(num_classes=200):
    """Build a ConvNeXt-Base with the custom classification head used in training.

    The backbone is created without pretrained weights (``weights=None``); the
    trained parameters are loaded into it afterwards with ``load_state_dict``.
    The head stays an unnamed ``nn.Sequential`` so its index-based state_dict
    keys (e.g. ``classifier.1.weight``, ``classifier.3.weight``,
    ``classifier.6.weight``) line up with the saved checkpoint.

    Args:
        num_classes: Number of output logits (200 for CUB-200-2011).

    Returns:
        The ConvNeXt-Base module with its ``classifier`` replaced.
    """
    backbone = convnext_base(weights=None)
    # Width of the features entering the stock final Linear layer.
    feature_dim = backbone.classifier[2].in_features
    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone
# Download the trained weights from the Hugging Face Model Hub; the returned
# value is a local file path that is handed to torch.load below.
print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id="AshProg/bird-classifier-convnext",
    filename="final_model.pth",
)

# Rebuild the training architecture, then restore the trained weights into it.
model = create_model(num_classes=200)
# NOTE(review): torch.load unpickles the downloaded file, so only point this at
# a trusted repo. Recent PyTorch releases default to weights_only=True, which
# rejects arbitrary pickled objects — confirm this checkpoint still loads after
# a PyTorch upgrade.
checkpoint = torch.load(model_path, map_location=device)

if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    # Full training checkpoint: the weights are nested under 'model_state_dict'.
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'val_acc' in checkpoint:
        val_acc = checkpoint['val_acc']
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Previously this case loaded silently; always confirm the load.
        print("Model loaded!")
else:
    # Bare state_dict checkpoint.
    model.load_state_dict(checkpoint)
    print("Model loaded!")

model = model.to(device)
model.eval()  # inference mode for the Dropout layers in the head
# Preprocessing applied to every uploaded image, kept identical to the
# validation transforms used in training: a fixed 224x224 resize, conversion to
# a float tensor, then per-channel normalization with the standard ImageNet
# mean/std values.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def predict(image):
    """
    Classify an uploaded bird image and return the most likely species.

    Args:
        image: PIL Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict: Up to 5 ``{species_name: probability}`` entries (softmax
        probabilities), in the format expected by a ``gr.Label`` output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a friendly message in the UI instead of a TypeError raised
        # from inside the transform pipeline.
        raise gr.Error("Please upload an image of a bird.")

    # Normalize() and the model expect exactly 3 channels, but uploads can be
    # RGBA (PNG with transparency), grayscale ('L') or palette ('P') images.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add a batch dimension of 1.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Never ask topk for more classes than the model outputs.
    k = min(5, probabilities.shape[1])
    top_prob, top_idx = torch.topk(probabilities, k)

    results = {}
    for prob, class_id in zip(top_prob[0].tolist(), top_idx[0].tolist()):
        species_name = class_names.get(str(class_id), f"Class {class_id}")
        results[species_name] = float(prob)
    return results
# Create Gradio interface: text shown above (title/description) and below
# (article) the app. Both are rendered as Markdown, so blank lines are needed to
# separate paragraphs from lists.
title = "🐦 Bird Species Classification"
description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

Upload a clear image of a bird to get started!
"""
article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference
"""
# Example inputs shown under the upload box. Each entry is a one-element list
# (one value per input component). While this stays empty, no examples are shown.
examples = [
    # You can add example images here if you have them
    # ["examples/bird1.jpg"],
    # ["examples/bird2.jpg"],
]
# Create interface: one PIL image in, the top-5 species labels out, via predict().
_image_input = gr.Image(type="pil", label="Upload Bird Image")
_label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title=title,
    description=description,
    article=article,
    # Pass None rather than an empty list when there are no examples.
    examples=examples or None,
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version pinned for this Space.
    allow_flagging="never",
)
def main():
    """Start the Gradio server hosting the bird-classifier interface."""
    iface.launch()


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()