import os

import torch
import gradio as gr
import librosa
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

# 1. CONFIGURATION
MODEL_ID = "facebook/wav2vec2-xls-r-300m"
QUANTIZED_MODEL_PATH = "quantized_model.pth"
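# Note: MODEL_ID points at the generic XLS-R 300M encoder; the task-specific
# (real vs. fake) knowledge is expected to come entirely from the quantized
# weights loaded below.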

# 2. LOAD MODEL
print("Loading model architecture...")
# Load architecture
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
# Apply the dynamic quantization structure (must match how the weights were saved)
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
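
# For reference, a minimal sketch of how such a checkpoint could be produced
# (an assumed workflow, not taken from this script): quantize first, then save
# the state_dict, so its keys match the structure rebuilt above.
#
#   trained = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
#   # ... fine-tune `trained` on real/fake audio ...
#   trained = torch.quantization.quantize_dynamic(trained, {torch.nn.Linear}, dtype=torch.qint8)
#   torch.save(trained.state_dict(), QUANTIZED_MODEL_PATH)
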
# Load weights
# Check that the quantized model file exists so the app does not crash at startup
if os.path.exists(QUANTIZED_MODEL_PATH):
    print("Loading quantized weights...")
    model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH, map_location=torch.device('cpu')))
else:
    print(f"Warning: {QUANTIZED_MODEL_PATH} not found. Using random weights (model will not work correctly).")
model.eval()

# 3. PREDICTION FUNCTION
def predict_audio(audio_path):
    # Gradio passes None if the user clears the input
    if audio_path is None:
        return "No Audio Provided"
    try:
        # Load and resample to 16 kHz using librosa (handles filepaths from upload OR mic)
        speech_array, sr = librosa.load(audio_path, sr=16000)
        inputs = feature_extractor(
            speech_array,
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Label 0 = Real, Label 1 = Deepfake
        fake_prob = probs[0][1].item()
        real_prob = probs[0][0].item()
        return {
            "Deepfake": fake_prob,
            "Real": real_prob
        }
    except Exception as e:
        return f"Error processing audio: {str(e)}"

# 4. CREATE INTERFACE
iface = gr.Interface(
    fn=predict_audio,
    inputs=gr.Audio(
        sources=["upload", "microphone"],  # accept both file uploads and mic recordings
        type="filepath",  # keep as filepath so librosa can load it
        label="Upload or Record Audio"
    ),
    outputs=gr.Label(num_top_classes=2),
    title="Deepfake Audio Detection API",
    description="Upload an audio file or record your voice to check whether it is real or fake."
)

iface.launch()
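
# If an external bind is needed (e.g. when running inside a container; an
# assumption, not part of the original script), launch() accepts standard
# options such as:
#   iface.launch(server_name="0.0.0.0", server_port=7860)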