File size: 5,047 Bytes
bb672b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Gradio App for Bird Species Classification
Deployed on Hugging Face Spaces
"""

# NOTE: setup (imports, class names, device, model definition and weights) is
# performed exactly once, in the section below. A stale duplicate of it used to
# live here and built an extra, never-used ConvNeXt-Base at import time.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_base
from PIL import Image
import json
from huggingface_hub import hf_hub_download

# Load the class-index -> species-name mapping.
# Read explicitly as UTF-8 so decoding does not depend on the host locale.
# NOTE(review): predict() looks names up with str(class_id) keys, so this is
# presumably a JSON object like {"0": "...", "1": "..."} — confirm against
# the file.
with open('class_names.json', 'r', encoding='utf-8') as f:
    class_names = json.load(f)

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create model architecture (same as training)
def create_model(num_classes=200):
    """Build ConvNeXt-Base with the custom classification head used in training.

    The backbone is created with ``weights=None`` (no pretrained weights);
    the trained parameters are loaded into the returned module afterwards.
    The stock classifier is replaced by::

        Flatten -> LayerNorm(C) -> Dropout(0.6) -> Linear(C, 512) -> GELU
        -> Dropout(0.5) -> Linear(512, num_classes)

    where C is the input width of the stock classifier's final Linear layer.
    Layer order must match training exactly: the checkpoint's state-dict keys
    (e.g. ``classifier.1.weight``) are position-based.

    Args:
        num_classes: Number of output logits (one per species).

    Returns:
        The network as an ``nn.Module``, with freshly initialised weights.
    """
    backbone = convnext_base(weights=None)

    # Width of the features feeding the stock head's final Linear layer.
    feature_dim = backbone.classifier[2].in_features

    head_layers = [
        nn.Flatten(1),
        nn.LayerNorm((feature_dim,)),
        nn.Dropout(0.6),
        nn.Linear(feature_dim, 512),
        nn.GELU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes),
    ]
    backbone.classifier = nn.Sequential(*head_layers)
    return backbone

# Download model from Hugging Face Model Hub.
# NOTE(review): no `revision=` is pinned, so whatever the repo's default branch
# holds at startup is fetched; consider pinning a commit hash for reproducible
# deployments.
print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id="AshProg/bird-classifier-convnext",
    filename="final_model.pth"
)

# Load the trained model.
model = create_model(num_classes=200)
# SECURITY(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the remote repo is ever compromised. Passing
# weights_only=True restricts loading to tensors and plain containers (it is
# the default from torch 2.6); verify the checkpoint loads that way first.
checkpoint = torch.load(model_path, map_location=device)

# Two layouts are accepted: a training checkpoint dict that stores the weights
# under 'model_state_dict' (optionally with 'val_acc'), or a bare state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    val_acc = checkpoint.get('val_acc')
    if val_acc is not None:
        print(f"Model loaded! Validation accuracy: {val_acc:.2f}%")
    else:
        # Previously nothing was printed in this case.
        print("Model loaded!")
else:
    model.load_state_dict(checkpoint)
    print("Model loaded!")

# Move to the target device and switch off the head's Dropout for inference.
model = model.to(device)
model.eval()

# Image preprocessing, mirroring the validation transforms:
#   1. resize straight to 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a tensor,
#   3. normalise each channel with the usual ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict(image):
    """
    Classify an uploaded bird image.

    Args:
        image: PIL Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict: Up to 5 ``{species_name: probability}`` pairs, highest first.
        Unknown class ids are labelled ``"Class <id>"``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image of a bird.")

    # Uploads may be RGBA (PNG with alpha), greyscale ("L") or palette ("P").
    # The 3-channel Normalize and the ConvNeXt stem both require RGB input.
    image = image.convert("RGB")

    # Preprocess and add a batch dimension of 1.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        # topk raises if k exceeds the number of classes, so clamp it.
        k = min(5, probabilities.size(1))
        top_prob, top_idx = torch.topk(probabilities, k)

    # class_names is keyed by the stringified class index.
    return {
        class_names.get(str(class_id), f"Class {class_id}"): prob
        for class_id, prob in zip(top_idx[0].tolist(), top_prob[0].tolist())
    }

# Create Gradio interface text: Markdown strings passed to gr.Interface as its
# title, description and article, plus the (currently empty) example list.
title = "🐦 Bird Species Classification"
description = """
Upload an image of a bird and the model will predict the species!

**Model Details:**
- Architecture: ConvNeXt-Base (87M parameters)
- Dataset: CUB-200-2011 (200 bird species)
- Test Accuracy: 83.64%
- Average Per-Class Accuracy: 83.29%

Upload a clear image of a bird to get started!
"""

# The check-mark bullets were previously mojibake ("βœ…", i.e. the UTF-8
# bytes of U+2705 mis-decoded); restored to the intended emoji.
article = """
### About This Model

This bird classifier was trained on the CUB-200-2011 dataset containing 200 North American bird species.

**Key Features:**
- ✅ 200 bird species classification
- ✅ State-of-the-art ConvNeXt architecture
- ✅ 83.64% test accuracy
- ✅ Real-time inference

"""

examples = [
    # You can add example images here if you have them
    # ["examples/bird1.jpg"],
    # ["examples/bird2.jpg"],
]

# Assemble the Gradio UI: a single PIL image goes in; the top-5 label panel
# comes out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Bird Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    title=title,
    description=description,
    article=article,
    # Pass None rather than an empty list when no examples are configured.
    examples=examples or None,
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; confirm against the Gradio version the Space pins.
    allow_flagging="never",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()