import torch
import gradio as gr
import librosa
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor
import os

# 1. CONFIGURATION
MODEL_ID = "facebook/wav2vec2-xls-r-300m"
QUANTIZED_MODEL_PATH = "quantized_model.pth"

# 2. LOAD MODEL
print("Loading model architecture...")
# Load architecture
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)

# Apply the same dynamic quantization structure that was used when the weights were saved
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
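
# For reference, the weights loaded below are assumed to have been produced by
# applying the same quantize_dynamic call to the fine-tuned model and saving its
# state dict. A minimal sketch (the original export script is not part of this
# file; `trained_model` is a placeholder name):
#
#   quantized = torch.quantization.quantize_dynamic(
#       trained_model, {torch.nn.Linear}, dtype=torch.qint8
#   )
#   torch.save(quantized.state_dict(), "quantized_model.pth")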

# Load weights
# Check if the quantized model file exists to avoid immediate crash
if os.path.exists(QUANTIZED_MODEL_PATH):
    print("Loading quantized weights...")
    model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH, map_location=torch.device('cpu')))
else:
    print(f"Warning: {QUANTIZED_MODEL_PATH} not found. Using random weights (Model will not work correctly).")

model.eval()

# 3. PREDICTION FUNCTION
def predict_audio(audio_path):
    # Gradio passes None if the user clears the input
    if audio_path is None:
        return "No Audio Provided"
    
    try:
        # Load and resample using librosa (handles filepath from upload OR mic)
        speech_array, sr = librosa.load(audio_path, sr=16000)
        
        inputs = feature_extractor(
            speech_array, 
            sampling_rate=16000, 
            return_tensors="pt", 
            padding=True
        )
        
        with torch.no_grad():
            logits = model(**inputs).logits
            probs = torch.nn.functional.softmax(logits, dim=-1)
            
        # Label 0 = Real, Label 1 = Deepfake 
        fake_prob = probs[0][1].item()
        real_prob = probs[0][0].item()
        
        return {
            "Deepfake": fake_prob, 
            "Real": real_prob
        }
    except Exception as e:
        return f"Error processing audio: {str(e)}"

# 4. CREATE INTERFACE
iface = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(
        sources=["upload", "microphone"], # <--- MODIFIED HERE
        type="filepath",                  # Keep as filepath so librosa can load it
        label="Upload or Record Audio"
    ), 
    outputs=gr.Label(num_top_classes=2),
    title="Deepfake Audio Detection API",
    description="Upload an audio file or record your voice to check if it's real or fake."
)

iface.launch()
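
# Once the app is running, it can also be queried programmatically. A hedged
# sketch using the gradio_client package (assumes the default local URL and the
# default /predict endpoint of gr.Interface; "sample.wav" is a hypothetical file):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(handle_file("sample.wav"), api_name="/predict")
#   print(result)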