File size: 5,047 Bytes
bb672b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
"""
Gradio App for Bird Species Classification
Deployed on Hugging Face Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
# Load class names
# JSON text is UTF-8 by specification; naming the encoding keeps the load
# independent of the host's default locale encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Device configuration: run on the GPU when CUDA is available, else on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create model architecture (same as training)
def create_model(num_classes=200):
    """Build a ConvNeXt-Base network carrying the custom classification head.

    The backbone is created without pretrained weights (``weights=None``) and
    its stock classifier is replaced by: Flatten -> LayerNorm -> Dropout(0.6)
    -> Linear(in, 512) -> GELU -> Dropout(0.5) -> Linear(512, num_classes).
    Per the original author's note, this is meant to mirror the head used
    during training so that saved weights load cleanly.

    Args:
        num_classes: Output width of the final linear layer (default 200).

    Returns:
        The assembled model with the replacement classifier attached.
    """
    net = convnext_base(weights=None)

    # Feature width is read from the stock head's layer at index 2.
    in_features = net.classifier[2].in_features

    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((in_features,)),
        nn.Dropout(0.6),
        nn.Linear(in_features, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
# Announce start-up. The network is instantiated exactly once, in the later
# "Load the trained model" section that follows the checkpoint download;
# building it here as well would allocate a second ConvNeXt-Base that is
# overwritten before it is ever used.
print("Loading model...")
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# Load class names
# JSON text is UTF-8 by specification; naming the encoding keeps the load
# independent of the host's default locale encoding.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Device configuration: run on the GPU when CUDA is available, else on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create model architecture (same as training)
def create_model(num_classes=200):
    """Build a ConvNeXt-Base network carrying the custom classification head.

    The backbone is created without pretrained weights (``weights=None``) and
    its stock classifier is replaced by: Flatten -> LayerNorm -> Dropout(0.6)
    -> Linear(in, 512) -> GELU -> Dropout(0.5) -> Linear(512, num_classes).
    Per the original author's note, this is meant to mirror the head used
    during training so that saved weights load cleanly.

    Args:
        num_classes: Output width of the final linear layer (default 200).

    Returns:
        The assembled model with the replacement classifier attached.
    """
    net = convnext_base(weights=None)

    # Feature width is read from the stock head's layer at index 2.
    in_features = net.classifier[2].in_features

    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((in_features,)),
        nn.Dropout(0.6),
        nn.Linear(in_features, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
# Download model from Hugging Face Model Hub
print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id="AshProg/bird-classifier-convnext",
    filename="final_model.pth",
)

# Load the trained model
model = create_model(num_classes=200)

# SECURITY NOTE(review): torch.load deserialises the file with pickle unless
# weights_only=True is in effect, so only point this at a trusted repository.
# Consider passing weights_only=True once it is confirmed that the checkpoint
# holds nothing but tensors and plain Python values.
checkpoint = torch.load(model_path, map_location=device)

# Two checkpoint layouts are accepted: a training checkpoint dict that wraps
# the weights under 'model_state_dict' (optionally with 'val_acc'), or a bare
# state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    val_acc = checkpoint.get('val_acc')
    if val_acc is not None:
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Confirm the load even when no accuracy was stored alongside it.
        print("Model loaded!")
else:
    model.load_state_dict(checkpoint)
    print("Model loaded!")

model = model.to(device)
# Inference mode: disables dropout in the classifier head.
model.eval()
# Image preprocessing (noted by the original author as matching the validation
# transforms): resize to 224x224, convert to a tensor, then normalise each of
# the three channels with the mean / std values below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocess_steps)
def predict(image):
    """Classify an uploaded bird image and return the top predictions.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the form is submitted without an upload.

    Returns:
        dict: Species name mapped to its softmax probability for the (up to)
        five highest-scoring classes.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image of a bird first.")

    # Normalize is configured with three channel statistics, so grayscale,
    # palette or RGBA inputs must be converted before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add the batch dimension the model expects.
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    # Take the five best classes (fewer if the model has fewer outputs).
    k = min(5, probabilities.shape[1])
    top_prob, top_idx = torch.topk(probabilities, k)

    # class_names is keyed by the stringified class index; fall back to a
    # generic label when an index has no entry.
    return {
        class_names.get(str(class_id), f"Class {class_id}"): float(prob)
        for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist())
    }
# Create Gradio interface
# Page heading and the Markdown blurb rendered above the upload widget.
# Blank lines separate the Markdown paragraphs from the bullet list so they
# are not folded into a single paragraph when rendered.
title = "🐦 Bird Species Classification"
description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**

- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

Upload a clear image of a bird to get started!
"""
# Markdown footer rendered below the interface. Each feature bullet carries
# its check-mark emoji on the same line as its text so it renders as one
# list item.
article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.

**Key Features:**

- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference
"""
# Optional example inputs for the UI; the list is currently empty.
examples = [
    # You can add example images here if you have them
    # ["examples/bird1.jpg"],
    # ["examples/bird2.jpg"],
]

# Create interface
# The input and output components are named separately to keep the Interface
# call short.
image_input = gr.Image(type="pil", label="Upload Bird Image")
label_output = gr.Label(num_top_classes=5, label="Top 5 Predictions")

iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    article=article,
    # An empty examples list is passed on as None.
    examples=examples or None,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|