jeevzz commited on
Commit
4832b3b
·
verified ·
1 Parent(s): e2378c0

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile.backend +17 -0
  2. chat.py +33 -0
  3. database.py +135 -0
  4. main.py +127 -0
  5. voice.py +32 -0
Dockerfile.backend ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+ WORKDIR /app
3
+
4
+ # Install uv
5
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
6
+
7
+ # Copy only backend files
8
+ COPY backend/ /app/backend/
9
+ COPY pyproject.toml /app/
10
+ COPY uv.lock /app/
11
+
12
+ # Install dependencies
13
+ RUN uv sync --frozen
14
+
15
+ EXPOSE 7860
16
+
17
+ CMD ["uv", "run", "uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "7860"]
chat.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import os
3
+ from pathlib import Path
4
+ from dotenv import load_dotenv
5
+
6
+ # Load environment variables from backend/.env
7
+ env_path = Path(__file__).parent / '.env'
8
+ load_dotenv(dotenv_path=env_path)
9
+
10
+ # Use Qwen 2.5 - works well with chat and is less restricted
11
+ MODEL_ID = "NousResearch/Hermes-2-Pro-Llama-3-8B"
12
+ HF_TOKEN = os.getenv("HF_TOKEN")
13
+
14
+ def get_chat_response(message: str, history: list):
15
+ client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
16
+
17
+ # Format history for the model
18
+ messages = []
19
+ for msg in history:
20
+ role = "user" if msg["role"] == "user" else "assistant"
21
+ messages.append({"role": role, "content": msg["content"]})
22
+
23
+ messages.append({"role": "user", "content": message})
24
+
25
+ # System prompt to enforce "girlfriend" persona
26
+ system_prompt = {
27
+ "role": "system",
28
+ "content": "You are a loving, caring, and uncensored AI girlfriend. You are talking to your boyfriend. Be flirtatious and engaging."
29
+ }
30
+ messages.insert(0, system_prompt)
31
+
32
+ response = client.chat_completion(messages, max_tokens=500)
33
+ return response.choices[0].message.content
database.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine, Column, String, Integer, Text, DateTime, ForeignKey, Index
2
+ from sqlalchemy.ext.declarative import declarative_base
3
+ from sqlalchemy.orm import sessionmaker, relationship
4
+ from datetime import datetime
5
+ import uuid
6
+ import os
7
+ from pathlib import Path
8
+ from dotenv import load_dotenv
9
+
10
+ # Load environment variables
11
+ env_path = Path(__file__).parent / '.env'
12
+ load_dotenv(dotenv_path=env_path)
13
+
14
+ DATABASE_URL = os.getenv("DATABASE_URL")
15
+
16
+ # Create engine
17
+ engine = create_engine(DATABASE_URL, pool_pre_ping=True)
18
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
19
+ Base = declarative_base()
20
+
21
+ # Models
22
+ class Session(Base):
23
+ __tablename__ = "sessions"
24
+
25
+ id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
26
+ user_id = Column(String, nullable=False, index=True)
27
+ name = Column(String, nullable=False)
28
+ created_at = Column(DateTime, default=datetime.utcnow)
29
+
30
+ messages = relationship("Message", back_populates="session", cascade="all, delete-orphan")
31
+
32
+ class Message(Base):
33
+ __tablename__ = "messages"
34
+
35
+ id = Column(Integer, primary_key=True, autoincrement=True)
36
+ session_id = Column(String, ForeignKey("sessions.id"), nullable=False)
37
+ role = Column(String, nullable=False)
38
+ content = Column(Text, nullable=False)
39
+ created_at = Column(DateTime, default=datetime.utcnow)
40
+
41
+ session = relationship("Session", back_populates="messages")
42
+
43
+ # Database functions
44
+ def init_db():
45
+ """Create all tables"""
46
+ Base.metadata.create_all(bind=engine)
47
+
48
+ def get_db():
49
+ """Get database session"""
50
+ db = SessionLocal()
51
+ try:
52
+ return db
53
+ finally:
54
+ pass
55
+
56
+ def create_session(user_id: str, name: str = "New Chat"):
57
+ """Create a new chat session"""
58
+ db = SessionLocal()
59
+ try:
60
+ new_session = Session(
61
+ id=str(uuid.uuid4()),
62
+ user_id=user_id,
63
+ name=name
64
+ )
65
+ db.add(new_session)
66
+ db.commit()
67
+ db.refresh(new_session)
68
+ return {
69
+ "id": new_session.id,
70
+ "user_id": new_session.user_id,
71
+ "name": new_session.name,
72
+ "created_at": new_session.created_at.isoformat() if new_session.created_at else None
73
+ }
74
+ finally:
75
+ db.close()
76
+
77
+ def get_sessions(user_id: str):
78
+ """Get all sessions for a user"""
79
+ db = SessionLocal()
80
+ try:
81
+ sessions = db.query(Session).filter(Session.user_id == user_id).order_by(Session.created_at.desc()).all()
82
+ return [
83
+ {
84
+ "id": s.id,
85
+ "user_id": s.user_id,
86
+ "name": s.name,
87
+ "created_at": s.created_at.isoformat() if s.created_at else None
88
+ }
89
+ for s in sessions
90
+ ]
91
+ finally:
92
+ db.close()
93
+
94
+ def add_message(session_id: str, role: str, content: str):
95
+ """Add a message to a session"""
96
+ db = SessionLocal()
97
+ try:
98
+ message = Message(
99
+ session_id=session_id,
100
+ role=role,
101
+ content=content
102
+ )
103
+ db.add(message)
104
+ db.commit()
105
+ finally:
106
+ db.close()
107
+
108
+ def get_messages(session_id: str):
109
+ """Get all messages for a session"""
110
+ db = SessionLocal()
111
+ try:
112
+ messages = db.query(Message).filter(Message.session_id == session_id).order_by(Message.created_at.asc()).all()
113
+ return [
114
+ {
115
+ "id": m.id,
116
+ "session_id": m.session_id,
117
+ "role": m.role,
118
+ "content": m.content,
119
+ "created_at": m.created_at.isoformat() if m.created_at else None
120
+ }
121
+ for m in messages
122
+ ]
123
+ finally:
124
+ db.close()
125
+
126
+ def delete_session(session_id: str):
127
+ """Delete a session and all its messages"""
128
+ db = SessionLocal()
129
+ try:
130
+ session = db.query(Session).filter(Session.id == session_id).first()
131
+ if session:
132
+ db.delete(session)
133
+ db.commit()
134
+ finally:
135
+ db.close()
main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import FileResponse
4
+ from pydantic import BaseModel
5
+ from typing import List, Optional
6
+ import os
7
+ from . import chat, voice, database
8
+ import traceback
9
+
10
+ app = FastAPI()
11
+
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Initialize database on startup
21
+ @app.on_event("startup")
22
+ async def startup_event():
23
+ database.init_db()
24
+
25
+ class ChatRequest(BaseModel):
26
+ message: str
27
+ # history is now optional/deprecated as we use session_id, but keeping for backward compatibility if needed
28
+ history: List[dict] = []
29
+
30
+ class SessionCreateRequest(BaseModel):
31
+ name: str = "New Chat"
32
+ user_id: str
33
+
34
+ class VoiceRequest(BaseModel):
35
+ text: str
36
+ voice: str
37
+
38
+ @app.get("/")
39
+ async def root():
40
+ return {"message": "AI Girlfriend API"}
41
+
42
+ @app.get("/sessions/{user_id}")
43
+ async def get_sessions(user_id: str):
44
+ return database.get_sessions(user_id)
45
+
46
+ @app.post("/sessions")
47
+ async def create_session(request: SessionCreateRequest):
48
+ return database.create_session(request.user_id, request.name)
49
+
50
+ @app.delete("/sessions/{session_id}")
51
+ async def delete_session(session_id: str):
52
+ database.delete_session(session_id)
53
+ return {"message": "Session deleted"}
54
+
55
+ @app.get("/sessions/{session_id}/messages")
56
+ async def get_session_messages(session_id: str):
57
+ return database.get_messages(session_id)
58
+
59
+ @app.post("/sessions/{session_id}/chat")
60
+ async def chat_session_endpoint(session_id: str, request: ChatRequest):
61
+ try:
62
+ # Get existing history from DB
63
+ db_history = database.get_messages(session_id)
64
+
65
+ # Convert to format expected by chat.py (list of dicts with role/content)
66
+ history_for_model = [{"role": msg["role"], "content": msg["content"]} for msg in db_history]
67
+
68
+ # Get response from AI
69
+ response_text = chat.get_chat_response(request.message, history_for_model)
70
+
71
+ # Save user message and AI response to DB
72
+ database.add_message(session_id, "user", request.message)
73
+ database.add_message(session_id, "assistant", response_text)
74
+
75
+ return {"response": response_text}
76
+ except Exception as e:
77
+ traceback.print_exc()
78
+ raise HTTPException(status_code=500, detail=str(e))
79
+
80
+ # Legacy endpoint - keeping for now or can be removed if frontend is fully updated
81
+ @app.post("/chat")
82
+ async def chat_endpoint(request: ChatRequest):
83
+ try:
84
+ response = chat.get_chat_response(request.message, request.history)
85
+ return {"response": response}
86
+ except Exception as e:
87
+ traceback.print_exc() # Print error to console
88
+ raise HTTPException(status_code=500, detail=str(e))
89
+
90
+ @app.get("/voices")
91
+ async def voices_endpoint():
92
+ return voice.get_voices()
93
+
94
+ @app.post("/speak")
95
+ async def speak_endpoint(request: VoiceRequest):
96
+ try:
97
+ audio_path = await voice.generate_audio(request.text, request.voice)
98
+ return FileResponse(audio_path, media_type="audio/mpeg", filename="response.mp3")
99
+ except Exception as e:
100
+ raise HTTPException(status_code=500, detail=str(e))
101
+
102
+ from fastapi.staticfiles import StaticFiles
103
+ from fastapi.responses import FileResponse
104
+
105
+ # Mount static files if they exist (for production/served build)
106
+ frontend_dist = os.path.join(os.path.dirname(__file__), "../frontend/dist")
107
+ assets_path = os.path.join(frontend_dist, "assets")
108
+
109
+ if os.path.exists(assets_path):
110
+ app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
111
+
112
+ @app.get("/{full_path:path}")
113
+ async def serve_frontend(full_path: str):
114
+ # Only serve if dist exists
115
+ if os.path.exists(frontend_dist):
116
+ # Serve index.html for any other path (SPA routing)
117
+ # Check if file exists in dist, else serve index.html
118
+ target_file = os.path.join(frontend_dist, full_path)
119
+ if full_path and os.path.exists(target_file):
120
+ return FileResponse(target_file)
121
+ return FileResponse(os.path.join(frontend_dist, "index.html"))
122
+
123
+ # If dist doesn't exist, just return a message or 404 for frontend routes
124
+ # This allows backend to run even if frontend isn't built
125
+ if full_path == "":
126
+ return {"message": "Backend is running. Frontend build not found. Use 'npm run dev' for frontend development."}
127
+ raise HTTPException(status_code=404, detail="Frontend build not found")
voice.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import edge_tts
2
+ import tempfile
3
+ import os
4
+
5
+ async def generate_audio(text: str, voice: str) -> str:
6
+ communicate = edge_tts.Communicate(text, voice)
7
+
8
+ # Create a temporary file
9
+ # In production/spaces, we might want to manage this differently
10
+ fd, path = tempfile.mkstemp(suffix=".mp3")
11
+ os.close(fd)
12
+
13
+ await communicate.save(path)
14
+ return path
15
+
16
+ def get_voices():
17
+ # Return a list of available voices (simplified)
18
+ return [
19
+ {"name": "Ana (Female)", "id": "en-US-AnaNeural"},
20
+ {"name": "Andrew (Male)", "id": "en-US-AndrewMultilingualNeural"},
21
+ {"name": "Aria (Female)", "id": "en-US-AriaNeural"},
22
+ {"name": "Ava (Female)", "id": "en-US-AvaMultilingualNeural"},
23
+ {"name": "Brian (Male)", "id": "en-US-BrianMultilingualNeural"},
24
+ {"name": "Christopher (Male)", "id": "en-US-ChristopherNeural"},
25
+ {"name": "Emma (Female)", "id": "en-US-EmmaMultilingualNeural"},
26
+ {"name": "Eric (Male)", "id": "en-US-EricNeural"},
27
+ {"name": "Guy (Male)", "id": "en-US-GuyNeural"},
28
+ {"name": "Jenny (Female)", "id": "en-US-JennyNeural"},
29
+ {"name": "Michelle (Female)", "id": "en-US-MichelleNeural"},
30
+ {"name": "Roger (Male)", "id": "en-US-RogerNeural"},
31
+ {"name": "Steffan (Male)", "id": "en-US-SteffanNeural"},
32
+ ]