Spaces:
Running
Running
| from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Form | |
| from fastapi.concurrency import run_in_threadpool | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Dict, Any, Optional | |
| import os | |
| import shutil | |
| from datetime import datetime | |
| from dotenv import load_dotenv | |
| from sqlalchemy.orm import Session | |
| load_dotenv("../.env") # Load from root | |
| load_dotenv(".env", override=True) # Load from local backend .env (prioritize) | |
| from agent import app as agent_app, vector_store | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| from langchain_community.document_loaders import PyPDFLoader, TextLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from database import get_db, init_db, Conversation, Message as DBMessage | |
# Module-level ASGI application; the CORS middleware below is attached to it and
# the __main__ guard at the bottom of the file serves it with uvicorn.
app: FastAPI = FastAPI()
async def startup_event():
    """Initialise the database and report whether web search is configured.

    Creates the database tables via ``init_db()`` and logs whether
    ``TAVILY_API_KEY`` is set in the environment.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration is
    visible for this coroutine in this file; confirm it is actually wired up.
    """
    # Initialize database tables
    init_db()

    # Report presence only: echoing even a prefix of a secret into the logs
    # leaks part of the credential to anyone who can read them.
    if os.getenv("TAVILY_API_KEY"):
        print("Startup: TAVILY_API_KEY found.")
    else:
        print("Startup: TAVILY_API_KEY NOT found!")
# CORS policy: accept requests from any origin, with any method and header, and
# allow credentials.
# NOTE(review): the FastAPI CORS docs state that wildcard settings are not meant
# to be combined with allow_credentials=True; consider listing the frontend's
# origin explicitly — TODO confirm the intended deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| # Pydantic models | |
class ChatRequest(BaseModel):
    """Request body for the chat handler."""

    # The user's new message; sent to the agent as both its "task" and its message.
    message: str
    # Prior turns as {"role": ..., "content": ...} dicts.
    # NOTE(review): not forwarded to the agent by chat_endpoint — confirm intent.
    history: List[Dict[str, str]] = []
    # Conversation to persist the exchange under; nothing is saved when missing.
    conversation_id: Optional[str] = None
    # Must own conversation_id for the exchange to be saved.
    user_id: Optional[str] = None
class ConversationCreate(BaseModel):
    """Request body for create_conversation."""

    # Owner of the new conversation.
    user_id: str
    # Display title; defaults to "New Chat" when the client omits it.
    title: str = "New Chat"
class ConversationResponse(BaseModel):
    """Conversation as returned to clients.

    Mirrors the keys of the dicts built by create_conversation and
    get_conversations; it is not referenced as a ``response_model`` in the code
    visible in this file.
    """

    id: str
    user_id: str
    title: str
    # ISO-8601 strings (the handlers serialise with datetime.isoformat()).
    created_at: str
    updated_at: str
    message_count: int = 0
    # Short LLM-generated topic, filled in later by generate_conversation_summary.
    summary: Optional[str] = None
async def upload_file(file: UploadFile = File(...), conversation_id: str = Form(...)):
    """Ingest an uploaded document into the vector store for a conversation.

    The upload is copied to a temporary file in the working directory and handed
    to ``agent.upload_file``. That call returns a sized collection of chunk
    splits. The temporary file is removed afterwards, whether processing
    succeeded or not.

    Args:
        file: The uploaded document.
        conversation_id: Conversation the chunks belong to. The literal strings
            "null" and "undefined" (as sent by a client without an id) count as
            unset.

    Returns:
        ``{"status": "success", "message": "Processed N chunks"}``.

    Raises:
        HTTPException: 400 when no usable conversation id is supplied; 500 with
            the error text when saving or processing the file fails.
    """
    print(f"DEBUG: Uploading file {file.filename} to conversation {conversation_id}")
    if not conversation_id or conversation_id in ("null", "undefined"):
        print("ERROR: Invalid conversation_id received in upload_file")
        raise HTTPException(status_code=400, detail="Please start a conversation first!")

    # Keep only the last path component of the client-supplied name so it
    # cannot steer the write outside the working directory; fall back to a
    # fixed name when the client sent none.
    safe_name = os.path.basename(file.filename or "") or "upload"
    file_path = f"temp_{safe_name}"
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        from agent import upload_file as agent_upload

        # Loading/splitting/embedding is synchronous; run it in the threadpool
        # so the event loop keeps serving other requests meanwhile.
        splits = await run_in_threadpool(agent_upload, file_path, conversation_id)
        return {"status": "success", "message": f"Processed {len(splits)} chunks"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always clean up the temporary copy, including on failure.
        if os.path.exists(file_path):
            os.remove(file_path)
async def clear_vector_store_endpoint():
    """Wipe the vector store via ``agent.clear_vector_store``.

    Returns:
        ``{"status": "success", "message": "Vector store cleared"}``.

    Raises:
        HTTPException: 500 with detail "Failed to clear vector store" when the
            helper reports failure, or with the error text when it raises.
    """
    try:
        from agent import clear_vector_store

        success = clear_vector_store()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try so the generic handler above cannot catch and
    # re-wrap it. Re-wrapping would replace the detail with str() of the
    # exception, which includes the status code.
    if not success:
        raise HTTPException(status_code=500, detail="Failed to clear vector store")
    return {"status": "success", "message": "Vector store cleared"}
async def generate_conversation_summary(conversation_id: str, db: Session):
    """Background task: store a short LLM-generated topic on a conversation.

    Reads up to the first 10 messages (oldest first) and asks Gemini for a
    one-sentence summary. The stripped result is written to
    ``Conversation.summary``. Errors are printed and swallowed, so a failed
    summary never surfaces to the client.

    Args:
        conversation_id: Conversation to summarise.
        db: SQLAlchemy session.
            NOTE(review): chat_endpoint hands over its request-scoped session.
            Whether ``get_db`` has already closed it by the time background tasks
            run depends on its implementation and the FastAPI version. Confirm
            this, or open a dedicated session here.
    """
    try:
        messages = (
            db.query(DBMessage)
            .filter(DBMessage.conversation_id == conversation_id)
            .order_by(DBMessage.created_at)
            .limit(10)  # Limit to first 10 for summary
            .all()
        )
        if not messages:
            return
        conversation_text = "\n".join(f"{msg.role}: {msg.content}" for msg in messages)

        from langchain_google_genai import ChatGoogleGenerativeAI
        from langchain_core.prompts import ChatPromptTemplate
        from langchain_core.output_parsers import StrOutputParser

        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
        system = """You are a helpful assistant. Create a very short, 1-sentence summary (max 10 words) of this conversation topic.
Example: "Python script debugging", "Recipe for chocolate cake", "Travel plans to Japan".
"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Conversation:\n{text}"),
        ])
        chain = prompt | llm | StrOutputParser()

        # This coroutine runs on the event loop; calling the blocking invoke()
        # directly would stall every other request until the model answers.
        summary = await run_in_threadpool(chain.invoke, {"text": conversation_text})

        conversation = db.query(Conversation).filter(Conversation.id == conversation_id).first()
        if conversation:
            conversation.summary = summary.strip()
            db.commit()
            print(f"Generated summary for {conversation_id}: {summary}")
    except Exception as e:
        print(f"Error generating summary: {e}")
        # Leave the session usable after a failed query/commit; stay best-effort
        # even if the rollback itself fails.
        try:
            db.rollback()
        except Exception as rollback_error:
            print(f"Error rolling back after summary failure: {rollback_error}")
def chat_endpoint(request: ChatRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Answer one chat message with the research agent and optionally persist it.

    The agent graph is invoked with the current message. Its ``final_report`` is
    the reply, and its ``steps`` are returned as "thoughts".

    The exchange is saved only when ``db``, ``conversation_id`` and ``user_id``
    are all present and the conversation belongs to that user. In that case:

    - the user message and the reply (with its thoughts) are stored;
    - the conversation's ``updated_at`` is bumped;
    - a summary refresh is queued while the conversation has <= 4 messages or
      still has no summary.

    Database failures are printed and rolled back without failing the request.

    Returns:
        ``{"response": <final report>, "thoughts": [<step dicts>]}``.

    Raises:
        HTTPException: 500 with the error text when the agent call (or anything
            outside the guarded database section) fails.
    """
    try:
        # NOTE(review): request.history is accepted but not forwarded; only the
        # current message reaches the agent. Add prior turns to
        # inputs["messages"] if the graph should see them.
        # Deep Research Graph expects 'task'.
        inputs = {
            "task": request.message,
            "plan": [],
            "content": [],
            "revision_number": 0,
            "max_revisions": 2,
            "steps": [],
            "messages": [HumanMessage(content=request.message)],
            "youtube_url": "",
            "youtube_captions": "",
            "deep_research": False,  # Will be set by router
            "conversation_id": request.conversation_id,
        }
        result = agent_app.invoke(inputs)

        final_response = result.get("final_report", "No report generated.")
        thoughts = [
            {"tool": "agent_step", "input": step, "status": "completed"}
            for step in result.get("steps", [])
        ]

        if db and request.conversation_id and request.user_id:
            try:
                # Only write into a conversation owned by the caller.
                conversation = db.query(Conversation).filter(
                    Conversation.id == request.conversation_id,
                    Conversation.user_id == request.user_id,
                ).first()
                if conversation:
                    db.add(DBMessage(
                        conversation_id=request.conversation_id,
                        role="user",
                        content=request.message,
                    ))
                    db.add(DBMessage(
                        conversation_id=request.conversation_id,
                        role="assistant",
                        content=final_response,
                        thoughts=thoughts or None,
                    ))
                    conversation.updated_at = datetime.utcnow()
                    db.commit()

                    # Refresh the summary while the conversation is young, or
                    # whenever it still has none.
                    message_count = db.query(DBMessage).filter(
                        DBMessage.conversation_id == request.conversation_id
                    ).count()
                    if message_count <= 4 or not conversation.summary:
                        background_tasks.add_task(
                            generate_conversation_summary, request.conversation_id, db
                        )
            except Exception as db_error:
                # Persistence is best-effort: the user still gets the answer.
                print(f"Database error: {db_error}")
                db.rollback()

        return {"response": final_response, "thoughts": thoughts}
    except Exception as e:
        print(f"Error in chat endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
class SummarizeRequest(BaseModel):
    """Request body for summarize_endpoint."""

    # Text to be summarised by the LLM.
    content: str
async def summarize_endpoint(request: SummarizeRequest):
    """Summarise arbitrary text in a few sentences with Gemini.

    Returns:
        ``{"summary": <model output>}``.

    Raises:
        HTTPException: 500 with the error text when building or calling the
            model chain fails.
    """
    try:
        from langchain_google_genai import ChatGoogleGenerativeAI
        from langchain_core.prompts import ChatPromptTemplate
        from langchain_core.output_parsers import StrOutputParser

        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
        system = """You are a professional summarizer. Create a concise summary of the provided content.
Guidelines:
1. Keep it to 3-5 sentences
2. Capture the main points and key takeaways
3. Use clear, simple language
4. Maintain the professional tone
"""
        prompt = ChatPromptTemplate.from_messages([
            ("system", system),
            ("human", "Summarize this content:\n\n{content}"),
        ])
        chain = prompt | llm | StrOutputParser()

        # invoke() is a blocking network call. Running it directly in this
        # coroutine would stall the event loop for every other request.
        summary = await run_in_threadpool(chain.invoke, {"content": request.content})
        return {"summary": summary}
    except Exception as e:
        print(f"Error in summarize endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
| # ==================== | |
| # CONVERSATION ENDPOINTS | |
| # ==================== | |
async def create_conversation(conv: ConversationCreate, db: Session = Depends(get_db)):
    """Create a conversation for ``conv.user_id`` and return it as a plain dict.

    Returns:
        dict with id, user_id, title, ISO-8601 created_at/updated_at,
        message_count=0 and summary=None.

    Raises:
        HTTPException: 503 when no database session is available; 500 (after a
            rollback) when the insert or serialisation fails.
    """
    if not db:
        raise HTTPException(status_code=503, detail="Database not configured")

    try:
        record = Conversation(title=conv.title, user_id=conv.user_id)
        db.add(record)
        db.commit()
        # Re-read the stored row so id and timestamps reflect what was persisted.
        db.refresh(record)
        return dict(
            id=record.id,
            user_id=record.user_id,
            title=record.title,
            created_at=record.created_at.isoformat(),
            updated_at=record.updated_at.isoformat(),
            message_count=0,
            summary=None,
        )
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))
async def get_conversations(user_id: str, db: Session = Depends(get_db)):
    """List a user's conversations, most recently updated first.

    Each entry is a dict with id, user_id, title, ISO-8601 created_at/updated_at,
    message_count and summary. The function is best-effort: it returns [] when no
    database session is available or any query fails (the error is printed).
    """
    if not db:
        return []
    try:
        from sqlalchemy import func

        conversations = (
            db.query(Conversation)
            .filter(Conversation.user_id == user_id)
            .order_by(Conversation.updated_at.desc())
            .all()
        )
        if not conversations:
            return []

        # One grouped COUNT for every conversation, instead of a separate COUNT
        # query per conversation (N+1). Conversations with no messages have no
        # row here, so they default to 0 below.
        count_rows = (
            db.query(DBMessage.conversation_id, func.count(DBMessage.id))
            .filter(DBMessage.conversation_id.in_([conv.id for conv in conversations]))
            .group_by(DBMessage.conversation_id)
            .all()
        )
        counts = {conv_id: total for conv_id, total in count_rows}

        return [
            {
                "id": conv.id,
                "user_id": conv.user_id,
                "title": conv.title,
                "created_at": conv.created_at.isoformat(),
                "updated_at": conv.updated_at.isoformat(),
                "message_count": counts.get(conv.id, 0),
                "summary": conv.summary,
            }
            for conv in conversations
        ]
    except Exception as e:
        print(f"Error fetching conversations: {e}")
        return []
async def get_messages(conversation_id: str, user_id: str, db: Session = Depends(get_db)):
    """Return a conversation's messages in creation order, as plain dicts.

    Only a conversation owned by ``user_id`` is readable.

    Returns:
        list of dicts with id, role, content, thoughts and ISO-8601 created_at;
        [] when no database session is available or a query fails.

    Raises:
        HTTPException: 404 when the conversation does not exist for this user.
    """
    if not db:
        return []
    try:
        owned = (
            db.query(Conversation)
            .filter(Conversation.id == conversation_id, Conversation.user_id == user_id)
            .first()
        )
        if not owned:
            raise HTTPException(status_code=404, detail="Conversation not found")

        rows = (
            db.query(DBMessage)
            .filter(DBMessage.conversation_id == conversation_id)
            .order_by(DBMessage.created_at)
            .all()
        )
        serialized = []
        for row in rows:
            serialized.append({
                "id": row.id,
                "role": row.role,
                "content": row.content,
                "thoughts": row.thoughts,
                "created_at": row.created_at.isoformat(),
            })
        return serialized
    except HTTPException:
        # Let the 404 through instead of turning it into an empty list.
        raise
    except Exception as e:
        print(f"Error fetching messages: {e}")
        return []
async def delete_conversation(conversation_id: str, user_id: str, db: Session = Depends(get_db)):
    """Delete a conversation owned by ``user_id``.

    NOTE(review): only the Conversation row is deleted explicitly. Whether its
    messages go with it depends on the relationship cascade / FK ON DELETE
    settings in ``database``, which are not visible here. Confirm them.

    Raises:
        HTTPException: 503 without a database session; 404 when the user has no
            such conversation; 500 (after a rollback) when the delete fails.
    """
    if not db:
        raise HTTPException(status_code=503, detail="Database not configured")

    try:
        target = (
            db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .filter(Conversation.user_id == user_id)
            .first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="Conversation not found")
        db.delete(target)
        db.commit()
    except HTTPException:
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))

    return {"status": "success", "message": "Conversation deleted"}
async def health_check():
    """Liveness probe: report that the service is up."""
    status_payload = {"status": "ok"}
    return status_payload
async def root():
    """Landing response confirming the backend process is serving requests."""
    return dict(message="RAG Backend is running")
| # Serve static files (Frontend) - to be configured after build | |
| # app.mount("/", StaticFiles(directory="../frontend/out", html=True), name="static") | |
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 7860 as before but can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))