"""
Gradio App for Bird Species Classification
Deployed on Hugging Face Spaces
"""
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
from torchvision.models import convnext_base |
|
|
from PIL import Image |
|
|
import json |
|
|
|
|
|
|
|
|
# Label table used by predict(), which looks entries up by the stringified
# class index. JSON is UTF-8 by specification, so decode it explicitly rather
# than relying on the platform's default text encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
|
|
|
def create_model(num_classes=200):
    """Build a ConvNeXt-Base network carrying the custom classification head.

    The backbone is created without pretrained weights (``weights=None``) and
    its stock classifier is swapped for the head assembled below, so the
    layout has to match whatever produced the checkpoint that is loaded later.

    Args:
        num_classes: Width of the final linear layer (number of output logits).

    Returns:
        The assembled, randomly initialised model.
    """
    net = convnext_base(weights=None)

    # Input width of the linear layer at index 2 of the stock classifier.
    feature_dim = net.classifier[2].in_features

    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
|
|
|
|
|
|
|
|
print("Loading model...")
# The network is built once, further down, immediately before the checkpoint
# weights are loaded into it. Constructing an additional ConvNeXt-Base at this
# point only allocated a randomly initialised model that was rebound before it
# was ever used, costing start-up time and memory for nothing.
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
from torchvision.models import convnext_base |
|
|
from PIL import Image |
|
|
import json |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Label table used by predict(), which looks entries up by the stringified
# class index. JSON is UTF-8 by specification, so decode it explicitly rather
# than relying on the platform's default text encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
|
|
|
def create_model(num_classes=200):
    """Build a ConvNeXt-Base network carrying the custom classification head.

    The backbone is created without pretrained weights (``weights=None``) and
    its stock classifier is swapped for the head assembled below, so the
    layout has to match whatever produced the checkpoint that is loaded later.

    Args:
        num_classes: Width of the final linear layer (number of output logits).

    Returns:
        The assembled, randomly initialised model.
    """
    net = convnext_base(weights=None)

    # Input width of the linear layer at index 2 of the stock classifier.
    feature_dim = net.classifier[2].in_features

    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
|
|
|
|
|
|
|
|
# Fetch the trained checkpoint from the Hugging Face Hub; the returned value
# is the path that torch.load reads from below.
print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id="AshProg/bird-classifier-convnext",
    filename="final_model.pth",
)

model = create_model(num_classes=200)

# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the repository contents are ever tampered with. Consider
# passing weights_only=True once it is confirmed that the checkpoint holds
# only tensors and plain Python values.
checkpoint = torch.load(model_path, map_location=device)

# Two layouts are accepted: a training checkpoint that wraps the weights under
# 'model_state_dict' (optionally alongside 'val_acc'), or a bare state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'val_acc' in checkpoint:
        val_acc = checkpoint['val_acc']
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Still confirm the load when no accuracy was stored with the weights.
        print("Model loaded!")
else:
    model.load_state_dict(checkpoint)
    print("Model loaded!")

model = model.to(device)
# Switch to inference behaviour (turns off the dropout layers in the head).
model.eval()
|
|
|
|
|
|
|
|
# Preprocessing applied to every input image: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std values below (the
# commonly used ImageNet channel statistics).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def predict(image):
    """Classify a bird image and return the most likely species.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when
            nothing has been uploaded (or the input was cleared).

    Returns:
        dict: Species name mapped to its softmax probability for the
        highest-scoring classes (at most five). Empty when ``image`` is
        ``None``.
    """
    # An empty image slot arrives as None; report "no prediction" instead of
    # failing inside the preprocessing pipeline.
    if image is None:
        return {}

    # Normalize is configured for three channels, so bring grayscale, RGBA or
    # palette images to RGB before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Add the batch dimension the model expects and move to the active device.
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Never ask for more entries than the model has output classes.
    k = min(5, probabilities.shape[1])
    top_prob, top_idx = torch.topk(probabilities, k)

    results = {}
    for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist()):
        # class_names is keyed by the stringified class index; fall back to a
        # generic label when an index has no entry.
        species_name = class_names.get(str(class_id), f"Class {class_id}")
        results[species_name] = float(prob)

    return results
|
|
|
|
|
|
|
|
# Text shown in the Gradio interface. The emoji are written as escape
# sequences (U+1F426 bird, U+2705 check mark) so this source stays ASCII-only
# and the characters cannot be corrupted by a wrong file-encoding guess.
title = "\U0001F426 Bird Species Classification"

description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

Upload a clear image of a bird to get started!
"""

article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.

**Key Features:**
- \u2705 200 bird species classification
- \u2705 State-of-the-art ConvNeXt architecture
- \u2705 83.64% test accuracy
- \u2705 Real-time inference

"""
|
|
|
|
|
# Sample inputs offered below the interface; none are configured here.
examples = []

# Single-function UI: one PIL image in, a label widget with the top classes out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    examples=examples or None,  # an empty list is passed on as None
    title=title,
    description=description,
    article=article,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|
|
|