Spaces:
Running
Running
Upload 5 files
Browse files- Dockerfile.backend +17 -0
- chat.py +33 -0
- database.py +135 -0
- main.py +127 -0
- voice.py +32 -0
Dockerfile.backend
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
WORKDIR /app
|
| 3 |
+
|
| 4 |
+
# Install uv
|
| 5 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
|
| 6 |
+
|
| 7 |
+
# Copy only backend files
|
| 8 |
+
COPY backend/ /app/backend/
|
| 9 |
+
COPY pyproject.toml /app/
|
| 10 |
+
COPY uv.lock /app/
|
| 11 |
+
|
| 12 |
+
# Install dependencies
|
| 13 |
+
RUN uv sync --frozen
|
| 14 |
+
|
| 15 |
+
EXPOSE 7860
|
| 16 |
+
|
| 17 |
+
CMD ["uv", "run", "uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
chat.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import InferenceClient
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
# Load environment variables from backend/.env
|
| 7 |
+
env_path = Path(__file__).parent / '.env'
|
| 8 |
+
load_dotenv(dotenv_path=env_path)
|
| 9 |
+
|
| 10 |
+
# Use Qwen 2.5 - works well with chat and is less restricted
|
| 11 |
+
MODEL_ID = "NousResearch/Hermes-2-Pro-Llama-3-8B"
|
| 12 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 13 |
+
|
| 14 |
+
def get_chat_response(message: str, history: list):
|
| 15 |
+
client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
|
| 16 |
+
|
| 17 |
+
# Format history for the model
|
| 18 |
+
messages = []
|
| 19 |
+
for msg in history:
|
| 20 |
+
role = "user" if msg["role"] == "user" else "assistant"
|
| 21 |
+
messages.append({"role": role, "content": msg["content"]})
|
| 22 |
+
|
| 23 |
+
messages.append({"role": "user", "content": message})
|
| 24 |
+
|
| 25 |
+
# System prompt to enforce "girlfriend" persona
|
| 26 |
+
system_prompt = {
|
| 27 |
+
"role": "system",
|
| 28 |
+
"content": "You are a loving, caring, and uncensored AI girlfriend. You are talking to your boyfriend. Be flirtatious and engaging."
|
| 29 |
+
}
|
| 30 |
+
messages.insert(0, system_prompt)
|
| 31 |
+
|
| 32 |
+
response = client.chat_completion(messages, max_tokens=500)
|
| 33 |
+
return response.choices[0].message.content
|
database.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import create_engine, Column, String, Integer, Text, DateTime, ForeignKey, Index
|
| 2 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 3 |
+
from sqlalchemy.orm import sessionmaker, relationship
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import uuid
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
# Load environment variables
|
| 11 |
+
env_path = Path(__file__).parent / '.env'
|
| 12 |
+
load_dotenv(dotenv_path=env_path)
|
| 13 |
+
|
| 14 |
+
DATABASE_URL = os.getenv("DATABASE_URL")
|
| 15 |
+
|
| 16 |
+
# Create engine
|
| 17 |
+
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
|
| 18 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 19 |
+
Base = declarative_base()
|
| 20 |
+
|
| 21 |
+
# Models
|
| 22 |
+
class Session(Base):
|
| 23 |
+
__tablename__ = "sessions"
|
| 24 |
+
|
| 25 |
+
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
| 26 |
+
user_id = Column(String, nullable=False, index=True)
|
| 27 |
+
name = Column(String, nullable=False)
|
| 28 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 29 |
+
|
| 30 |
+
messages = relationship("Message", back_populates="session", cascade="all, delete-orphan")
|
| 31 |
+
|
| 32 |
+
class Message(Base):
|
| 33 |
+
__tablename__ = "messages"
|
| 34 |
+
|
| 35 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 36 |
+
session_id = Column(String, ForeignKey("sessions.id"), nullable=False)
|
| 37 |
+
role = Column(String, nullable=False)
|
| 38 |
+
content = Column(Text, nullable=False)
|
| 39 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 40 |
+
|
| 41 |
+
session = relationship("Session", back_populates="messages")
|
| 42 |
+
|
| 43 |
+
# Database functions
|
| 44 |
+
def init_db():
|
| 45 |
+
"""Create all tables"""
|
| 46 |
+
Base.metadata.create_all(bind=engine)
|
| 47 |
+
|
| 48 |
+
def get_db():
|
| 49 |
+
"""Get database session"""
|
| 50 |
+
db = SessionLocal()
|
| 51 |
+
try:
|
| 52 |
+
return db
|
| 53 |
+
finally:
|
| 54 |
+
pass
|
| 55 |
+
|
| 56 |
+
def create_session(user_id: str, name: str = "New Chat"):
|
| 57 |
+
"""Create a new chat session"""
|
| 58 |
+
db = SessionLocal()
|
| 59 |
+
try:
|
| 60 |
+
new_session = Session(
|
| 61 |
+
id=str(uuid.uuid4()),
|
| 62 |
+
user_id=user_id,
|
| 63 |
+
name=name
|
| 64 |
+
)
|
| 65 |
+
db.add(new_session)
|
| 66 |
+
db.commit()
|
| 67 |
+
db.refresh(new_session)
|
| 68 |
+
return {
|
| 69 |
+
"id": new_session.id,
|
| 70 |
+
"user_id": new_session.user_id,
|
| 71 |
+
"name": new_session.name,
|
| 72 |
+
"created_at": new_session.created_at.isoformat() if new_session.created_at else None
|
| 73 |
+
}
|
| 74 |
+
finally:
|
| 75 |
+
db.close()
|
| 76 |
+
|
| 77 |
+
def get_sessions(user_id: str):
|
| 78 |
+
"""Get all sessions for a user"""
|
| 79 |
+
db = SessionLocal()
|
| 80 |
+
try:
|
| 81 |
+
sessions = db.query(Session).filter(Session.user_id == user_id).order_by(Session.created_at.desc()).all()
|
| 82 |
+
return [
|
| 83 |
+
{
|
| 84 |
+
"id": s.id,
|
| 85 |
+
"user_id": s.user_id,
|
| 86 |
+
"name": s.name,
|
| 87 |
+
"created_at": s.created_at.isoformat() if s.created_at else None
|
| 88 |
+
}
|
| 89 |
+
for s in sessions
|
| 90 |
+
]
|
| 91 |
+
finally:
|
| 92 |
+
db.close()
|
| 93 |
+
|
| 94 |
+
def add_message(session_id: str, role: str, content: str):
|
| 95 |
+
"""Add a message to a session"""
|
| 96 |
+
db = SessionLocal()
|
| 97 |
+
try:
|
| 98 |
+
message = Message(
|
| 99 |
+
session_id=session_id,
|
| 100 |
+
role=role,
|
| 101 |
+
content=content
|
| 102 |
+
)
|
| 103 |
+
db.add(message)
|
| 104 |
+
db.commit()
|
| 105 |
+
finally:
|
| 106 |
+
db.close()
|
| 107 |
+
|
| 108 |
+
def get_messages(session_id: str):
|
| 109 |
+
"""Get all messages for a session"""
|
| 110 |
+
db = SessionLocal()
|
| 111 |
+
try:
|
| 112 |
+
messages = db.query(Message).filter(Message.session_id == session_id).order_by(Message.created_at.asc()).all()
|
| 113 |
+
return [
|
| 114 |
+
{
|
| 115 |
+
"id": m.id,
|
| 116 |
+
"session_id": m.session_id,
|
| 117 |
+
"role": m.role,
|
| 118 |
+
"content": m.content,
|
| 119 |
+
"created_at": m.created_at.isoformat() if m.created_at else None
|
| 120 |
+
}
|
| 121 |
+
for m in messages
|
| 122 |
+
]
|
| 123 |
+
finally:
|
| 124 |
+
db.close()
|
| 125 |
+
|
| 126 |
+
def delete_session(session_id: str):
|
| 127 |
+
"""Delete a session and all its messages"""
|
| 128 |
+
db = SessionLocal()
|
| 129 |
+
try:
|
| 130 |
+
session = db.query(Session).filter(Session.id == session_id).first()
|
| 131 |
+
if session:
|
| 132 |
+
db.delete(session)
|
| 133 |
+
db.commit()
|
| 134 |
+
finally:
|
| 135 |
+
db.close()
|
main.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import FileResponse
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
import os
|
| 7 |
+
from . import chat, voice, database
|
| 8 |
+
import traceback
|
| 9 |
+
|
| 10 |
+
app = FastAPI()
|
| 11 |
+
|
| 12 |
+
app.add_middleware(
|
| 13 |
+
CORSMiddleware,
|
| 14 |
+
allow_origins=["*"],
|
| 15 |
+
allow_credentials=True,
|
| 16 |
+
allow_methods=["*"],
|
| 17 |
+
allow_headers=["*"],
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# Initialize database on startup
|
| 21 |
+
@app.on_event("startup")
|
| 22 |
+
async def startup_event():
|
| 23 |
+
database.init_db()
|
| 24 |
+
|
| 25 |
+
class ChatRequest(BaseModel):
|
| 26 |
+
message: str
|
| 27 |
+
# history is now optional/deprecated as we use session_id, but keeping for backward compatibility if needed
|
| 28 |
+
history: List[dict] = []
|
| 29 |
+
|
| 30 |
+
class SessionCreateRequest(BaseModel):
|
| 31 |
+
name: str = "New Chat"
|
| 32 |
+
user_id: str
|
| 33 |
+
|
| 34 |
+
class VoiceRequest(BaseModel):
|
| 35 |
+
text: str
|
| 36 |
+
voice: str
|
| 37 |
+
|
| 38 |
+
@app.get("/")
|
| 39 |
+
async def root():
|
| 40 |
+
return {"message": "AI Girlfriend API"}
|
| 41 |
+
|
| 42 |
+
@app.get("/sessions/{user_id}")
|
| 43 |
+
async def get_sessions(user_id: str):
|
| 44 |
+
return database.get_sessions(user_id)
|
| 45 |
+
|
| 46 |
+
@app.post("/sessions")
|
| 47 |
+
async def create_session(request: SessionCreateRequest):
|
| 48 |
+
return database.create_session(request.user_id, request.name)
|
| 49 |
+
|
| 50 |
+
@app.delete("/sessions/{session_id}")
|
| 51 |
+
async def delete_session(session_id: str):
|
| 52 |
+
database.delete_session(session_id)
|
| 53 |
+
return {"message": "Session deleted"}
|
| 54 |
+
|
| 55 |
+
@app.get("/sessions/{session_id}/messages")
|
| 56 |
+
async def get_session_messages(session_id: str):
|
| 57 |
+
return database.get_messages(session_id)
|
| 58 |
+
|
| 59 |
+
@app.post("/sessions/{session_id}/chat")
|
| 60 |
+
async def chat_session_endpoint(session_id: str, request: ChatRequest):
|
| 61 |
+
try:
|
| 62 |
+
# Get existing history from DB
|
| 63 |
+
db_history = database.get_messages(session_id)
|
| 64 |
+
|
| 65 |
+
# Convert to format expected by chat.py (list of dicts with role/content)
|
| 66 |
+
history_for_model = [{"role": msg["role"], "content": msg["content"]} for msg in db_history]
|
| 67 |
+
|
| 68 |
+
# Get response from AI
|
| 69 |
+
response_text = chat.get_chat_response(request.message, history_for_model)
|
| 70 |
+
|
| 71 |
+
# Save user message and AI response to DB
|
| 72 |
+
database.add_message(session_id, "user", request.message)
|
| 73 |
+
database.add_message(session_id, "assistant", response_text)
|
| 74 |
+
|
| 75 |
+
return {"response": response_text}
|
| 76 |
+
except Exception as e:
|
| 77 |
+
traceback.print_exc()
|
| 78 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 79 |
+
|
| 80 |
+
# Legacy endpoint - keeping for now or can be removed if frontend is fully updated
|
| 81 |
+
@app.post("/chat")
|
| 82 |
+
async def chat_endpoint(request: ChatRequest):
|
| 83 |
+
try:
|
| 84 |
+
response = chat.get_chat_response(request.message, request.history)
|
| 85 |
+
return {"response": response}
|
| 86 |
+
except Exception as e:
|
| 87 |
+
traceback.print_exc() # Print error to console
|
| 88 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 89 |
+
|
| 90 |
+
@app.get("/voices")
|
| 91 |
+
async def voices_endpoint():
|
| 92 |
+
return voice.get_voices()
|
| 93 |
+
|
| 94 |
+
@app.post("/speak")
|
| 95 |
+
async def speak_endpoint(request: VoiceRequest):
|
| 96 |
+
try:
|
| 97 |
+
audio_path = await voice.generate_audio(request.text, request.voice)
|
| 98 |
+
return FileResponse(audio_path, media_type="audio/mpeg", filename="response.mp3")
|
| 99 |
+
except Exception as e:
|
| 100 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 101 |
+
|
| 102 |
+
from fastapi.staticfiles import StaticFiles
|
| 103 |
+
from fastapi.responses import FileResponse
|
| 104 |
+
|
| 105 |
+
# Mount static files if they exist (for production/served build)
|
| 106 |
+
frontend_dist = os.path.join(os.path.dirname(__file__), "../frontend/dist")
|
| 107 |
+
assets_path = os.path.join(frontend_dist, "assets")
|
| 108 |
+
|
| 109 |
+
if os.path.exists(assets_path):
|
| 110 |
+
app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
|
| 111 |
+
|
| 112 |
+
@app.get("/{full_path:path}")
|
| 113 |
+
async def serve_frontend(full_path: str):
|
| 114 |
+
# Only serve if dist exists
|
| 115 |
+
if os.path.exists(frontend_dist):
|
| 116 |
+
# Serve index.html for any other path (SPA routing)
|
| 117 |
+
# Check if file exists in dist, else serve index.html
|
| 118 |
+
target_file = os.path.join(frontend_dist, full_path)
|
| 119 |
+
if full_path and os.path.exists(target_file):
|
| 120 |
+
return FileResponse(target_file)
|
| 121 |
+
return FileResponse(os.path.join(frontend_dist, "index.html"))
|
| 122 |
+
|
| 123 |
+
# If dist doesn't exist, just return a message or 404 for frontend routes
|
| 124 |
+
# This allows backend to run even if frontend isn't built
|
| 125 |
+
if full_path == "":
|
| 126 |
+
return {"message": "Backend is running. Frontend build not found. Use 'npm run dev' for frontend development."}
|
| 127 |
+
raise HTTPException(status_code=404, detail="Frontend build not found")
|
voice.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import edge_tts
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
async def generate_audio(text: str, voice: str) -> str:
|
| 6 |
+
communicate = edge_tts.Communicate(text, voice)
|
| 7 |
+
|
| 8 |
+
# Create a temporary file
|
| 9 |
+
# In production/spaces, we might want to manage this differently
|
| 10 |
+
fd, path = tempfile.mkstemp(suffix=".mp3")
|
| 11 |
+
os.close(fd)
|
| 12 |
+
|
| 13 |
+
await communicate.save(path)
|
| 14 |
+
return path
|
| 15 |
+
|
| 16 |
+
def get_voices():
|
| 17 |
+
# Return a list of available voices (simplified)
|
| 18 |
+
return [
|
| 19 |
+
{"name": "Ana (Female)", "id": "en-US-AnaNeural"},
|
| 20 |
+
{"name": "Andrew (Male)", "id": "en-US-AndrewMultilingualNeural"},
|
| 21 |
+
{"name": "Aria (Female)", "id": "en-US-AriaNeural"},
|
| 22 |
+
{"name": "Ava (Female)", "id": "en-US-AvaMultilingualNeural"},
|
| 23 |
+
{"name": "Brian (Male)", "id": "en-US-BrianMultilingualNeural"},
|
| 24 |
+
{"name": "Christopher (Male)", "id": "en-US-ChristopherNeural"},
|
| 25 |
+
{"name": "Emma (Female)", "id": "en-US-EmmaMultilingualNeural"},
|
| 26 |
+
{"name": "Eric (Male)", "id": "en-US-EricNeural"},
|
| 27 |
+
{"name": "Guy (Male)", "id": "en-US-GuyNeural"},
|
| 28 |
+
{"name": "Jenny (Female)", "id": "en-US-JennyNeural"},
|
| 29 |
+
{"name": "Michelle (Female)", "id": "en-US-MichelleNeural"},
|
| 30 |
+
{"name": "Roger (Male)", "id": "en-US-RogerNeural"},
|
| 31 |
+
{"name": "Steffan (Male)", "id": "en-US-SteffanNeural"},
|
| 32 |
+
]
|