Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- Dockerfile +5 -9
- add_column.py +35 -0
- agent.py +153 -30
- api_test.py +57 -0
- check_yt.py +65 -0
- database.py +1 -0
- main.py +81 -22
- rag_test.py +50 -0
Dockerfile
CHANGED
|
@@ -10,19 +10,15 @@ RUN apt-get update && apt-get install -y \
|
|
| 10 |
# Install uv for faster package installation
|
| 11 |
RUN pip install uv
|
| 12 |
|
| 13 |
-
# Copy requirements and install dependencies
|
| 14 |
-
# Assumes requirements.txt is at the root of the repo
|
| 15 |
COPY requirements.txt .
|
| 16 |
RUN uv pip install --system -r requirements.txt
|
| 17 |
|
| 18 |
-
# Copy
|
| 19 |
-
|
| 20 |
-
COPY main.py .
|
| 21 |
-
COPY agent.py .
|
| 22 |
-
COPY database.py .
|
| 23 |
|
| 24 |
# Expose port 7860 (required by HF Spaces)
|
| 25 |
EXPOSE 7860
|
| 26 |
|
| 27 |
-
# Run the application
|
| 28 |
-
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
|
| 10 |
# Install uv for faster package installation
|
| 11 |
RUN pip install uv
|
| 12 |
|
| 13 |
+
# Copy requirements and install dependencies with uv
|
|
|
|
| 14 |
COPY requirements.txt .
|
| 15 |
RUN uv pip install --system -r requirements.txt
|
| 16 |
|
| 17 |
+
# Copy the entire backend package (preserves module structure)
|
| 18 |
+
COPY backend/ ./backend/
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# Expose port 7860 (required by HF Spaces)
|
| 21 |
EXPOSE 7860
|
| 22 |
|
| 23 |
+
# Run the application – note the module path includes the package name
|
| 24 |
+
CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
add_column.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from sqlalchemy import create_engine, text
|
| 4 |
+
|
| 5 |
+
# Load environment variables
|
| 6 |
+
load_dotenv(".env", override=True)
|
| 7 |
+
load_dotenv("../.env", override=False)
|
| 8 |
+
|
| 9 |
+
DATABASE_URL = os.getenv("DATABASE_URL")
|
| 10 |
+
|
| 11 |
+
if not DATABASE_URL:
|
| 12 |
+
print("Error: DATABASE_URL not found")
|
| 13 |
+
exit(1)
|
| 14 |
+
|
| 15 |
+
def add_column():
|
| 16 |
+
engine = create_engine(DATABASE_URL)
|
| 17 |
+
with engine.connect() as conn:
|
| 18 |
+
try:
|
| 19 |
+
# Check if column exists first to avoid error
|
| 20 |
+
check_sql = text("SELECT column_name FROM information_schema.columns WHERE table_name='conversations' AND column_name='summary';")
|
| 21 |
+
result = conn.execute(check_sql)
|
| 22 |
+
if result.fetchone():
|
| 23 |
+
print("Column 'summary' already exists.")
|
| 24 |
+
return
|
| 25 |
+
|
| 26 |
+
print("Adding 'summary' column to 'conversations' table...")
|
| 27 |
+
sql = text("ALTER TABLE conversations ADD COLUMN summary TEXT;")
|
| 28 |
+
conn.execute(sql)
|
| 29 |
+
conn.commit()
|
| 30 |
+
print("Successfully added 'summary' column.")
|
| 31 |
+
except Exception as e:
|
| 32 |
+
print(f"Error adding column: {e}")
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
add_column()
|
agent.py
CHANGED
|
@@ -21,8 +21,6 @@ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
|
|
| 21 |
from youtube_transcript_api import YouTubeTranscriptApi
|
| 22 |
import yt_dlp
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
# --- Configuration ---
|
| 27 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
|
| 28 |
|
|
@@ -31,6 +29,20 @@ embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
|
|
| 31 |
vector_store = Chroma(embedding_function=embeddings, persist_directory="./chroma_db")
|
| 32 |
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# --- State Definition ---
|
| 35 |
class AgentState(TypedDict):
|
| 36 |
"""The state of our Deep Research Agent."""
|
|
@@ -45,6 +57,29 @@ class AgentState(TypedDict):
|
|
| 45 |
youtube_url: str
|
| 46 |
youtube_captions: str
|
| 47 |
deep_research: bool # Flag to indicate if deep research is requested
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# --- Data Models ---
|
| 50 |
class Plan(BaseModel):
|
|
@@ -63,10 +98,14 @@ def extract_video_id(url):
|
|
| 63 |
|
| 64 |
def get_video_duration(url):
|
| 65 |
"""Gets video duration in seconds using yt-dlp."""
|
| 66 |
-
ydl_opts = {'quiet': True}
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
# --- Nodes ---
|
| 72 |
|
|
@@ -118,29 +157,91 @@ def youtube_node(state: AgentState):
|
|
| 118 |
# Check Duration
|
| 119 |
try:
|
| 120 |
duration = get_video_duration(url)
|
| 121 |
-
if duration >
|
| 122 |
-
return {"final_report": f"Error: Video is too long ({duration//60} mins). Limit is
|
| 123 |
except Exception as e:
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
# Get Captions
|
|
|
|
| 127 |
try:
|
| 128 |
-
print(f"
|
| 129 |
-
|
| 130 |
-
transcript_list = yt.list(video_id)
|
| 131 |
-
# Try to find English, or fallback to first available
|
| 132 |
try:
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
except Exception as e:
|
| 142 |
-
print(f"
|
| 143 |
-
|
|
|
|
|
|
|
| 144 |
|
| 145 |
# Generate Title
|
| 146 |
system = """You are a YouTube Expert. Analyze the provided video transcript and generate 3 catchy, AI-enhanced title options.
|
|
@@ -160,7 +261,11 @@ def youtube_node(state: AgentState):
|
|
| 160 |
)
|
| 161 |
|
| 162 |
chain = prompt | llm | StrOutputParser()
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# Manually construct a beautifully formatted report with MAXIMUM SPACING
|
| 166 |
report = "# YouTube Video Analysis\n\n\n"
|
|
@@ -180,11 +285,21 @@ def youtube_node(state: AgentState):
|
|
| 180 |
if "caption" in task.lower() or "transcript" in task.lower():
|
| 181 |
report += "---\n\n\n"
|
| 182 |
report += "## 📝 Full Captions\n\n\n"
|
| 183 |
-
report += f"```text\n{transcript_text}
|
| 184 |
else:
|
| 185 |
report += "---\n\n\n"
|
| 186 |
report += "> **Note:** Captions are available for this video! Add 'with captions' to your request to see them.\n\n\n"
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
return {
|
| 189 |
"final_report": report,
|
| 190 |
"youtube_captions": transcript_text,
|
|
@@ -203,7 +318,11 @@ def quick_response_node(state: AgentState):
|
|
| 203 |
|
| 204 |
# Try to get relevant context from vector store
|
| 205 |
try:
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
context = "\n\n".join([d.page_content for d in docs]) if docs else ""
|
| 208 |
except Exception as e:
|
| 209 |
print(f"Retriever error: {e}")
|
|
@@ -211,7 +330,7 @@ def quick_response_node(state: AgentState):
|
|
| 211 |
|
| 212 |
# Check if we should do a quick web search (for real-time info)
|
| 213 |
# If context is empty OR if the query implies real-time data
|
| 214 |
-
real_time_keywords = ["price", "current", "news", "latest", "today", "now", "live", "rate", "stock", "weather", "forecast", "score", "result", "vs", "when", "where", "who"]
|
| 215 |
should_search = any(k in task.lower() for k in real_time_keywords)
|
| 216 |
|
| 217 |
web_context = ""
|
|
@@ -246,6 +365,7 @@ def quick_response_node(state: AgentState):
|
|
| 246 |
4. If the user asks for "price", "news", or "current" info, prioritize the Web Search Results.
|
| 247 |
5. Keep responses focused. Do NOT write a long report.
|
| 248 |
6. If the question requires extensive research, suggest the user ask for "deep research".
|
|
|
|
| 249 |
"""
|
| 250 |
|
| 251 |
if full_context:
|
|
@@ -321,7 +441,10 @@ def research_node(state: AgentState):
|
|
| 321 |
steps_log.append(f"Researching: {step}")
|
| 322 |
|
| 323 |
# 1. Try Vector Store first
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
| 325 |
if docs:
|
| 326 |
context = "\n".join([d.page_content for d in docs])
|
| 327 |
content.append(f"Source: Local Documents\nTopic: {step}\nContent: {context}")
|
|
@@ -361,9 +484,9 @@ def writer_node(state: AgentState):
|
|
| 361 |
2. **Headers**: Use headers (##, ###) to organize sections. **IMPORTANT**: Always add a blank line before and after every header.
|
| 362 |
3. **Content**: Synthesize the information. Do not just list facts.
|
| 363 |
4. **Formatting**:
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
5. **Citations**: If the notes mention specific sources, cite them.
|
| 368 |
6. **Conclusion**: End with a strong conclusion.
|
| 369 |
|
|
|
|
| 21 |
from youtube_transcript_api import YouTubeTranscriptApi
|
| 22 |
import yt_dlp
|
| 23 |
|
|
|
|
|
|
|
| 24 |
# --- Configuration ---
|
| 25 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
|
| 26 |
|
|
|
|
| 29 |
vector_store = Chroma(embedding_function=embeddings, persist_directory="./chroma_db")
|
| 30 |
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
|
| 31 |
|
| 32 |
+
def clear_vector_store():
|
| 33 |
+
"""Clears the vector store."""
|
| 34 |
+
global vector_store, retriever
|
| 35 |
+
try:
|
| 36 |
+
# Delete the collection
|
| 37 |
+
vector_store.delete_collection()
|
| 38 |
+
# Re-initialize
|
| 39 |
+
vector_store = Chroma(embedding_function=embeddings, persist_directory="./chroma_db")
|
| 40 |
+
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
|
| 41 |
+
return True
|
| 42 |
+
except Exception as e:
|
| 43 |
+
print(f"Error clearing vector store: {e}")
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
# --- State Definition ---
|
| 47 |
class AgentState(TypedDict):
|
| 48 |
"""The state of our Deep Research Agent."""
|
|
|
|
| 57 |
youtube_url: str
|
| 58 |
youtube_captions: str
|
| 59 |
deep_research: bool # Flag to indicate if deep research is requested
|
| 60 |
+
conversation_id: str # For RAG isolation
|
| 61 |
+
|
| 62 |
+
# --- File Processing ---
|
| 63 |
+
from langchain_community.document_loaders import PyPDFLoader, TextLoader
|
| 64 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 65 |
+
|
| 66 |
+
def upload_file(file_path: str, conversation_id: str):
|
| 67 |
+
"""Process uploaded file and add to vector store with metadata."""
|
| 68 |
+
if file_path.endswith(".pdf"):
|
| 69 |
+
loader = PyPDFLoader(file_path)
|
| 70 |
+
else:
|
| 71 |
+
loader = TextLoader(file_path)
|
| 72 |
+
|
| 73 |
+
docs = loader.load()
|
| 74 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
| 75 |
+
splits = text_splitter.split_documents(docs)
|
| 76 |
+
|
| 77 |
+
# Add metadata
|
| 78 |
+
for split in splits:
|
| 79 |
+
split.metadata["conversation_id"] = conversation_id
|
| 80 |
+
|
| 81 |
+
vector_store.add_documents(splits)
|
| 82 |
+
return splits
|
| 83 |
|
| 84 |
# --- Data Models ---
|
| 85 |
class Plan(BaseModel):
|
|
|
|
| 98 |
|
| 99 |
def get_video_duration(url):
|
| 100 |
"""Gets video duration in seconds using yt-dlp."""
|
| 101 |
+
ydl_opts = {'quiet': True, 'no_warnings': True}
|
| 102 |
+
try:
|
| 103 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
| 104 |
+
info = ydl.extract_info(url, download=False)
|
| 105 |
+
return info.get('duration', 0)
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"Warning: Could not check duration: {e}")
|
| 108 |
+
return 0 # Return 0 to skip duration check on error
|
| 109 |
|
| 110 |
# --- Nodes ---
|
| 111 |
|
|
|
|
| 157 |
# Check Duration
|
| 158 |
try:
|
| 159 |
duration = get_video_duration(url)
|
| 160 |
+
if duration > 1200: # 20 minutes limit (increased)
|
| 161 |
+
return {"final_report": f"Error: Video is too long ({duration//60} mins). Limit is 20 minutes.", "steps": ["Video rejected: Too long"]}
|
| 162 |
except Exception as e:
|
| 163 |
+
print(f"Error checking duration: {e}")
|
| 164 |
+
# Continue anyway if duration check fails (might be network issue)
|
| 165 |
|
| 166 |
# Get Captions
|
| 167 |
+
transcript_text = ""
|
| 168 |
try:
|
| 169 |
+
print(f"Fetching captions for {video_id}")
|
| 170 |
+
# Method 1: Try YouTubeTranscriptApi
|
|
|
|
|
|
|
| 171 |
try:
|
| 172 |
+
transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
|
| 173 |
+
transcript = None
|
| 174 |
+
try:
|
| 175 |
+
transcript = transcript_list.find_transcript(['en'])
|
| 176 |
+
except:
|
| 177 |
+
# Get any available
|
| 178 |
+
for t in transcript_list:
|
| 179 |
+
transcript = t
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
if transcript:
|
| 183 |
+
transcript_data = transcript.fetch()
|
| 184 |
+
# Handle both dictionary and object formats
|
| 185 |
+
transcript_text = " ".join([
|
| 186 |
+
entry.text if hasattr(entry, 'text') else entry['text']
|
| 187 |
+
for entry in transcript_data
|
| 188 |
+
])
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f"YouTubeTranscriptApi failed: {e}. Trying yt-dlp fallback...")
|
| 191 |
|
| 192 |
+
# Method 2: Fallback to yt-dlp
|
| 193 |
+
import requests
|
| 194 |
+
ydl_opts = {
|
| 195 |
+
'skip_download': True,
|
| 196 |
+
'writesubtitles': True,
|
| 197 |
+
'writeautomaticsub': True,
|
| 198 |
+
'subtitleslangs': ['en'],
|
| 199 |
+
'quiet': True
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
| 203 |
+
info = ydl.extract_info(url, download=False)
|
| 204 |
+
subtitles = info.get('subtitles', {})
|
| 205 |
+
auto_captions = info.get('automatic_captions', {})
|
| 206 |
+
|
| 207 |
+
sub_url = None
|
| 208 |
+
if 'en' in subtitles:
|
| 209 |
+
for fmt in subtitles['en']:
|
| 210 |
+
if fmt['ext'] == 'json3':
|
| 211 |
+
sub_url = fmt['url']
|
| 212 |
+
break
|
| 213 |
+
if not sub_url and subtitles['en']:
|
| 214 |
+
sub_url = subtitles['en'][0]['url']
|
| 215 |
+
elif 'en' in auto_captions:
|
| 216 |
+
for fmt in auto_captions['en']:
|
| 217 |
+
if fmt['ext'] == 'json3':
|
| 218 |
+
sub_url = fmt['url']
|
| 219 |
+
break
|
| 220 |
+
if not sub_url and auto_captions['en']:
|
| 221 |
+
sub_url = auto_captions['en'][0]['url']
|
| 222 |
+
|
| 223 |
+
if sub_url:
|
| 224 |
+
print(f"Fetching captions from: {sub_url}")
|
| 225 |
+
r = requests.get(sub_url)
|
| 226 |
+
data = r.json()
|
| 227 |
+
events = data.get('events', [])
|
| 228 |
+
for event in events:
|
| 229 |
+
if 'segs' in event:
|
| 230 |
+
for seg in event['segs']:
|
| 231 |
+
if 'utf8' in seg:
|
| 232 |
+
transcript_text += seg['utf8']
|
| 233 |
+
transcript_text += " "
|
| 234 |
+
else:
|
| 235 |
+
raise Exception("No captions found via yt-dlp")
|
| 236 |
+
|
| 237 |
+
if not transcript_text:
|
| 238 |
+
return {"final_report": "No captions available for this video.", "steps": ["No captions found"]}
|
| 239 |
+
|
| 240 |
except Exception as e:
|
| 241 |
+
print(f"Caption Error: {e}")
|
| 242 |
+
import traceback
|
| 243 |
+
traceback.print_exc()
|
| 244 |
+
return {"final_report": f"Error fetching captions: {e}. \n\nPossible reasons:\n1. Video has no captions.\n2. Network restrictions.\n3. Video is private.", "steps": ["Failed to fetch captions"]}
|
| 245 |
|
| 246 |
# Generate Title
|
| 247 |
system = """You are a YouTube Expert. Analyze the provided video transcript and generate 3 catchy, AI-enhanced title options.
|
|
|
|
| 261 |
)
|
| 262 |
|
| 263 |
chain = prompt | llm | StrOutputParser()
|
| 264 |
+
try:
|
| 265 |
+
raw_titles = chain.invoke({"transcript": transcript_text[:10000], "task": task}) # Increased limit
|
| 266 |
+
except Exception as e:
|
| 267 |
+
raw_titles = "VIRAL: Error generating titles\nSEO: Error\nPROFESSIONAL: Error"
|
| 268 |
+
print(f"Title generation error: {e}")
|
| 269 |
|
| 270 |
# Manually construct a beautifully formatted report with MAXIMUM SPACING
|
| 271 |
report = "# YouTube Video Analysis\n\n\n"
|
|
|
|
| 285 |
if "caption" in task.lower() or "transcript" in task.lower():
|
| 286 |
report += "---\n\n\n"
|
| 287 |
report += "## 📝 Full Captions\n\n\n"
|
| 288 |
+
report += f"```text\n{transcript_text[:5000]}...\n```\n\n(Truncated for display)\n\n"
|
| 289 |
else:
|
| 290 |
report += "---\n\n\n"
|
| 291 |
report += "> **Note:** Captions are available for this video! Add 'with captions' to your request to see them.\n\n\n"
|
| 292 |
|
| 293 |
+
# Add summary of video content
|
| 294 |
+
summary_system = "Summarize the following video transcript in 3-5 bullet points."
|
| 295 |
+
summary_prompt = ChatPromptTemplate.from_messages([("system", summary_system), ("human", "{transcript}")])
|
| 296 |
+
summary_chain = summary_prompt | llm | StrOutputParser()
|
| 297 |
+
try:
|
| 298 |
+
summary = summary_chain.invoke({"transcript": transcript_text[:10000]})
|
| 299 |
+
report += "## 📹 Video Summary\n\n" + summary + "\n\n"
|
| 300 |
+
except:
|
| 301 |
+
pass
|
| 302 |
+
|
| 303 |
return {
|
| 304 |
"final_report": report,
|
| 305 |
"youtube_captions": transcript_text,
|
|
|
|
| 318 |
|
| 319 |
# Try to get relevant context from vector store
|
| 320 |
try:
|
| 321 |
+
conversation_id = state.get("conversation_id")
|
| 322 |
+
filter_dict = {"conversation_id": conversation_id} if conversation_id else None
|
| 323 |
+
|
| 324 |
+
# Use similarity_search directly to support filtering
|
| 325 |
+
docs = vector_store.similarity_search(task, k=3, filter=filter_dict)
|
| 326 |
context = "\n\n".join([d.page_content for d in docs]) if docs else ""
|
| 327 |
except Exception as e:
|
| 328 |
print(f"Retriever error: {e}")
|
|
|
|
| 330 |
|
| 331 |
# Check if we should do a quick web search (for real-time info)
|
| 332 |
# If context is empty OR if the query implies real-time data
|
| 333 |
+
real_time_keywords = ["price", "current", "news", "latest", "today", "now", "live", "rate", "stock", "weather", "forecast", "score", "result", "vs", "when", "where", "who", "what"]
|
| 334 |
should_search = any(k in task.lower() for k in real_time_keywords)
|
| 335 |
|
| 336 |
web_context = ""
|
|
|
|
| 365 |
4. If the user asks for "price", "news", or "current" info, prioritize the Web Search Results.
|
| 366 |
5. Keep responses focused. Do NOT write a long report.
|
| 367 |
6. If the question requires extensive research, suggest the user ask for "deep research".
|
| 368 |
+
7. If you don't know the answer and have no context, use your general knowledge to answer as best as possible.
|
| 369 |
"""
|
| 370 |
|
| 371 |
if full_context:
|
|
|
|
| 441 |
steps_log.append(f"Researching: {step}")
|
| 442 |
|
| 443 |
# 1. Try Vector Store first
|
| 444 |
+
conversation_id = state.get("conversation_id")
|
| 445 |
+
filter_dict = {"conversation_id": conversation_id} if conversation_id else None
|
| 446 |
+
|
| 447 |
+
docs = vector_store.similarity_search(step, k=3, filter=filter_dict)
|
| 448 |
if docs:
|
| 449 |
context = "\n".join([d.page_content for d in docs])
|
| 450 |
content.append(f"Source: Local Documents\nTopic: {step}\nContent: {context}")
|
|
|
|
| 484 |
2. **Headers**: Use headers (##, ###) to organize sections. **IMPORTANT**: Always add a blank line before and after every header.
|
| 485 |
3. **Content**: Synthesize the information. Do not just list facts.
|
| 486 |
4. **Formatting**:
|
| 487 |
+
- Use **bold** for key terms.
|
| 488 |
+
- Use bullet points for lists (ensure there is a blank line before the list starts).
|
| 489 |
+
- Use > Blockquotes for important summaries.
|
| 490 |
5. **Citations**: If the notes mention specific sources, cite them.
|
| 491 |
6. **Conclusion**: End with a strong conclusion.
|
| 492 |
|
api_test.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
BASE_URL = "http://127.0.0.1:7860"
|
| 5 |
+
|
| 6 |
+
def test_api():
|
| 7 |
+
print("--- Starting API Test ---")
|
| 8 |
+
|
| 9 |
+
# 1. Create a dummy file
|
| 10 |
+
test_file = "api_test_doc.txt"
|
| 11 |
+
with open(test_file, "w") as f:
|
| 12 |
+
f.write("The API secret is 99999. Do not share this.")
|
| 13 |
+
|
| 14 |
+
conversation_id = "api_chat_1"
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
# 2. Upload File
|
| 18 |
+
print(f"\n1. Uploading to {conversation_id}...")
|
| 19 |
+
with open(test_file, "rb") as f:
|
| 20 |
+
files = {"file": f}
|
| 21 |
+
data = {"conversation_id": conversation_id}
|
| 22 |
+
response = requests.post(f"{BASE_URL}/api/upload", files=files, data=data)
|
| 23 |
+
|
| 24 |
+
print(f"Upload Status: {response.status_code}")
|
| 25 |
+
print(f"Upload Response: {response.json()}")
|
| 26 |
+
|
| 27 |
+
if response.status_code != 200:
|
| 28 |
+
print("❌ Upload Failed")
|
| 29 |
+
return
|
| 30 |
+
|
| 31 |
+
# 3. Chat (Ask about the file)
|
| 32 |
+
print(f"\n2. Asking about the file in {conversation_id}...")
|
| 33 |
+
chat_data = {
|
| 34 |
+
"message": "What is the API secret?",
|
| 35 |
+
"history": [],
|
| 36 |
+
"conversation_id": conversation_id,
|
| 37 |
+
"user_id": "test_user"
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
response = requests.post(f"{BASE_URL}/api/chat", json=chat_data)
|
| 41 |
+
print(f"Chat Status: {response.status_code}")
|
| 42 |
+
result = response.json()
|
| 43 |
+
print(f"Chat Response: {result.get('response')}")
|
| 44 |
+
|
| 45 |
+
if "99999" in result.get('response', ''):
|
| 46 |
+
print("✅ Success: AI found the secret!")
|
| 47 |
+
else:
|
| 48 |
+
print("❌ Failure: AI did not find the secret.")
|
| 49 |
+
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Test Failed: {e}")
|
| 52 |
+
finally:
|
| 53 |
+
if os.path.exists(test_file):
|
| 54 |
+
os.remove(test_file)
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
|
| 57 |
+
test_api()
|
check_yt.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import yt_dlp
|
| 3 |
+
import requests
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
video_id = "2dGB9Fo4hnU"
|
| 7 |
+
url = f"https://www.youtube.com/watch?v={video_id}"
|
| 8 |
+
|
| 9 |
+
print("\n--- Method 2: yt-dlp ---")
|
| 10 |
+
try:
|
| 11 |
+
ydl_opts = {
|
| 12 |
+
'skip_download': True,
|
| 13 |
+
'writesubtitles': True,
|
| 14 |
+
'writeautomaticsub': True,
|
| 15 |
+
'subtitleslangs': ['en'],
|
| 16 |
+
'quiet': True
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
| 20 |
+
info = ydl.extract_info(url, download=False)
|
| 21 |
+
|
| 22 |
+
subtitles = info.get('subtitles', {})
|
| 23 |
+
auto_captions = info.get('automatic_captions', {})
|
| 24 |
+
|
| 25 |
+
sub_url = None
|
| 26 |
+
if 'en' in subtitles:
|
| 27 |
+
print("Found manual English subtitles")
|
| 28 |
+
# Prefer json3
|
| 29 |
+
for fmt in subtitles['en']:
|
| 30 |
+
if fmt['ext'] == 'json3':
|
| 31 |
+
sub_url = fmt['url']
|
| 32 |
+
break
|
| 33 |
+
if not sub_url:
|
| 34 |
+
sub_url = subtitles['en'][0]['url']
|
| 35 |
+
elif 'en' in auto_captions:
|
| 36 |
+
print("Found auto English captions")
|
| 37 |
+
for fmt in auto_captions['en']:
|
| 38 |
+
if fmt['ext'] == 'json3':
|
| 39 |
+
sub_url = fmt['url']
|
| 40 |
+
break
|
| 41 |
+
if not sub_url:
|
| 42 |
+
sub_url = auto_captions['en'][0]['url']
|
| 43 |
+
|
| 44 |
+
if sub_url:
|
| 45 |
+
print(f"Fetching: {sub_url}")
|
| 46 |
+
r = requests.get(sub_url)
|
| 47 |
+
data = r.json()
|
| 48 |
+
# print(json.dumps(data, indent=2)[:500])
|
| 49 |
+
|
| 50 |
+
# Parse json3
|
| 51 |
+
events = data.get('events', [])
|
| 52 |
+
text = ""
|
| 53 |
+
for event in events:
|
| 54 |
+
if 'segs' in event:
|
| 55 |
+
for seg in event['segs']:
|
| 56 |
+
if 'utf8' in seg:
|
| 57 |
+
text += seg['utf8']
|
| 58 |
+
text += " "
|
| 59 |
+
print(f"Extracted text length: {len(text)}")
|
| 60 |
+
print(f"Preview: {text[:100]}")
|
| 61 |
+
else:
|
| 62 |
+
print("No English subtitles found")
|
| 63 |
+
|
| 64 |
+
except Exception as e:
|
| 65 |
+
print(f"Method 2 failed: {e}")
|
database.py
CHANGED
|
@@ -28,6 +28,7 @@ class Conversation(Base):
|
|
| 28 |
title = Column(String, nullable=False)
|
| 29 |
created_at = Column(DateTime, default=datetime.utcnow)
|
| 30 |
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
|
|
|
| 31 |
|
| 32 |
# Relationship
|
| 33 |
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
|
|
|
| 28 |
title = Column(String, nullable=False)
|
| 29 |
created_at = Column(DateTime, default=datetime.utcnow)
|
| 30 |
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
| 31 |
+
summary = Column(Text, nullable=True) # Short summary of the conversation
|
| 32 |
|
| 33 |
# Relationship
|
| 34 |
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends
|
| 2 |
from fastapi.concurrency import run_in_threadpool
|
| 3 |
from fastapi.staticfiles import StaticFiles
|
| 4 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -58,27 +58,23 @@ class ConversationResponse(BaseModel):
|
|
| 58 |
created_at: str
|
| 59 |
updated_at: str
|
| 60 |
message_count: int = 0
|
|
|
|
| 61 |
|
| 62 |
@app.post("/api/upload")
|
| 63 |
-
async def upload_file(file: UploadFile = File(...)):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
try:
|
| 65 |
# Save file temporarily
|
| 66 |
file_path = f"temp_{file.filename}"
|
| 67 |
with open(file_path, "wb") as buffer:
|
| 68 |
shutil.copyfileobj(file.file, buffer)
|
| 69 |
|
| 70 |
-
#
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
else:
|
| 74 |
-
loader = TextLoader(file_path)
|
| 75 |
-
|
| 76 |
-
docs = loader.load()
|
| 77 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
| 78 |
-
splits = text_splitter.split_documents(docs)
|
| 79 |
-
|
| 80 |
-
# Add to Vector Store
|
| 81 |
-
vector_store.add_documents(splits)
|
| 82 |
|
| 83 |
# Cleanup
|
| 84 |
os.remove(file_path)
|
|
@@ -87,8 +83,61 @@ async def upload_file(file: UploadFile = File(...)):
|
|
| 87 |
except Exception as e:
|
| 88 |
raise HTTPException(status_code=500, detail=str(e))
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
@app.post("/api/chat")
|
| 91 |
-
def chat_endpoint(request: ChatRequest, db: Session = Depends(get_db)):
|
| 92 |
try:
|
| 93 |
# Convert history to LangChain messages
|
| 94 |
messages = []
|
|
@@ -108,13 +157,13 @@ def chat_endpoint(request: ChatRequest, db: Session = Depends(get_db)):
|
|
| 108 |
"plan": [],
|
| 109 |
"content": [],
|
| 110 |
"revision_number": 0,
|
| 111 |
-
"max_revisions":
|
| 112 |
-
"final_report": "",
|
| 113 |
"steps": [],
|
| 114 |
-
"messages":
|
| 115 |
-
"deep_research": False, # Will be set by router
|
| 116 |
"youtube_url": "",
|
| 117 |
-
"youtube_captions": ""
|
|
|
|
|
|
|
| 118 |
}
|
| 119 |
|
| 120 |
result = agent_app.invoke(inputs)
|
|
@@ -163,6 +212,14 @@ def chat_endpoint(request: ChatRequest, db: Session = Depends(get_db)):
|
|
| 163 |
conversation.updated_at = datetime.utcnow()
|
| 164 |
|
| 165 |
db.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
except Exception as db_error:
|
| 167 |
print(f"Database error: {db_error}")
|
| 168 |
db.rollback()
|
|
@@ -233,7 +290,8 @@ async def create_conversation(conv: ConversationCreate, db: Session = Depends(ge
|
|
| 233 |
"title": new_conv.title,
|
| 234 |
"created_at": new_conv.created_at.isoformat(),
|
| 235 |
"updated_at": new_conv.updated_at.isoformat(),
|
| 236 |
-
"message_count": 0
|
|
|
|
| 237 |
}
|
| 238 |
except Exception as e:
|
| 239 |
db.rollback()
|
|
@@ -262,7 +320,8 @@ async def get_conversations(user_id: str, db: Session = Depends(get_db)):
|
|
| 262 |
"title": conv.title,
|
| 263 |
"created_at": conv.created_at.isoformat(),
|
| 264 |
"updated_at": conv.updated_at.isoformat(),
|
| 265 |
-
"message_count": message_count
|
|
|
|
| 266 |
})
|
| 267 |
|
| 268 |
return result
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, BackgroundTasks, Form
|
| 2 |
from fastapi.concurrency import run_in_threadpool
|
| 3 |
from fastapi.staticfiles import StaticFiles
|
| 4 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 58 |
created_at: str
|
| 59 |
updated_at: str
|
| 60 |
message_count: int = 0
|
| 61 |
+
summary: Optional[str] = None
|
| 62 |
|
| 63 |
@app.post("/api/upload")
|
| 64 |
+
async def upload_file(file: UploadFile = File(...), conversation_id: str = Form(...)):
|
| 65 |
+
print(f"DEBUG: Uploading file {file.filename} to conversation {conversation_id}")
|
| 66 |
+
if not conversation_id or conversation_id == "null" or conversation_id == "undefined":
|
| 67 |
+
print("ERROR: Invalid conversation_id received in upload_file")
|
| 68 |
+
raise HTTPException(status_code=400, detail="Please start a conversation first!")
|
| 69 |
try:
|
| 70 |
# Save file temporarily
|
| 71 |
file_path = f"temp_{file.filename}"
|
| 72 |
with open(file_path, "wb") as buffer:
|
| 73 |
shutil.copyfileobj(file.file, buffer)
|
| 74 |
|
| 75 |
+
# Process file
|
| 76 |
+
from agent import upload_file as agent_upload
|
| 77 |
+
splits = agent_upload(file_path, conversation_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
# Cleanup
|
| 80 |
os.remove(file_path)
|
|
|
|
| 83 |
except Exception as e:
|
| 84 |
raise HTTPException(status_code=500, detail=str(e))
|
| 85 |
|
| 86 |
+
@app.delete("/api/vector_store")
|
| 87 |
+
async def clear_vector_store_endpoint():
|
| 88 |
+
try:
|
| 89 |
+
from agent import clear_vector_store
|
| 90 |
+
success = clear_vector_store()
|
| 91 |
+
if success:
|
| 92 |
+
return {"status": "success", "message": "Vector store cleared"}
|
| 93 |
+
else:
|
| 94 |
+
raise HTTPException(status_code=500, detail="Failed to clear vector store")
|
| 95 |
+
except Exception as e:
|
| 96 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 97 |
+
|
| 98 |
+
async def generate_conversation_summary(conversation_id: str, db: Session):
    """Background task: create and persist a one-line summary for a conversation.

    Reads up to the first 10 messages, asks the LLM for a short topic summary,
    and stores it on the Conversation row. Never raises — errors are logged.

    NOTE(review): this receives a request-scoped SQLAlchemy Session; confirm
    the session is still open when the background task actually runs.
    """
    try:
        # Only the earliest few messages are needed to identify the topic.
        msgs = (
            db.query(DBMessage)
            .filter(DBMessage.conversation_id == conversation_id)
            .order_by(DBMessage.created_at)
            .limit(10)
            .all()
        )
        if not msgs:
            return

        transcript = "\n".join(f"{m.role}: {m.content}" for m in msgs)

        # Lazy imports: these heavy dependencies are only needed here.
        from langchain_google_genai import ChatGoogleGenerativeAI
        from langchain_core.prompts import ChatPromptTemplate
        from langchain_core.output_parsers import StrOutputParser

        system = """You are a helpful assistant. Create a very short, 1-sentence summary (max 10 words) of this conversation topic.
Example: "Python script debugging", "Recipe for chocolate cake", "Travel plans to Japan".
"""
        prompt = ChatPromptTemplate.from_messages(
            [("system", system), ("human", "Conversation:\n{text}")]
        )
        model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
        summary = (prompt | model | StrOutputParser()).invoke({"text": transcript})

        # Persist the summary on the conversation row, if it still exists.
        conversation = (
            db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if conversation:
            conversation.summary = summary.strip()
            db.commit()
            print(f"Generated summary for {conversation_id}: {summary}")

    except Exception as e:
        # Background task: swallow and log rather than crash the worker.
        print(f"Error generating summary: {e}")
|
| 138 |
+
|
| 139 |
@app.post("/api/chat")
|
| 140 |
+
def chat_endpoint(request: ChatRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
|
| 141 |
try:
|
| 142 |
# Convert history to LangChain messages
|
| 143 |
messages = []
|
|
|
|
| 157 |
"plan": [],
|
| 158 |
"content": [],
|
| 159 |
"revision_number": 0,
|
| 160 |
+
"max_revisions": 2,
|
|
|
|
| 161 |
"steps": [],
|
| 162 |
+
"messages": [HumanMessage(content=request.message)],
|
|
|
|
| 163 |
"youtube_url": "",
|
| 164 |
+
"youtube_captions": "",
|
| 165 |
+
"deep_research": False, # Will be set by router
|
| 166 |
+
"conversation_id": request.conversation_id
|
| 167 |
}
|
| 168 |
|
| 169 |
result = agent_app.invoke(inputs)
|
|
|
|
| 212 |
conversation.updated_at = datetime.utcnow()
|
| 213 |
|
| 214 |
db.commit()
|
| 215 |
+
|
| 216 |
+
# Trigger summary generation if it's the first few messages or summary is missing
|
| 217 |
+
# We can check message count or just do it periodically
|
| 218 |
+
# For simplicity, let's do it if message count is small (< 5) or summary is None
|
| 219 |
+
message_count = db.query(DBMessage).filter(DBMessage.conversation_id == request.conversation_id).count()
|
| 220 |
+
if message_count <= 4 or not conversation.summary:
|
| 221 |
+
background_tasks.add_task(generate_conversation_summary, request.conversation_id, db)
|
| 222 |
+
|
| 223 |
except Exception as db_error:
|
| 224 |
print(f"Database error: {db_error}")
|
| 225 |
db.rollback()
|
|
|
|
| 290 |
"title": new_conv.title,
|
| 291 |
"created_at": new_conv.created_at.isoformat(),
|
| 292 |
"updated_at": new_conv.updated_at.isoformat(),
|
| 293 |
+
"message_count": 0,
|
| 294 |
+
"summary": None
|
| 295 |
}
|
| 296 |
except Exception as e:
|
| 297 |
db.rollback()
|
|
|
|
| 320 |
"title": conv.title,
|
| 321 |
"created_at": conv.created_at.isoformat(),
|
| 322 |
"updated_at": conv.updated_at.isoformat(),
|
| 323 |
+
"message_count": message_count,
|
| 324 |
+
"summary": conv.summary
|
| 325 |
})
|
| 326 |
|
| 327 |
return result
|
rag_test.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from agent import upload_file, vector_store
|
| 5 |
+
|
| 6 |
+
# Load environment variables
|
| 7 |
+
load_dotenv(".env", override=True)
|
| 8 |
+
load_dotenv("../.env", override=False)
|
| 9 |
+
|
| 10 |
+
def test_rag():
    """Smoke-test that RAG uploads are isolated per conversation.

    Uploads a document under one conversation id, then verifies that a
    similarity search finds it for that conversation and NOT for another.
    Results are printed; the temp file is always removed afterwards.
    """
    print("--- Starting RAG Test ---")

    # Create a dummy test file with a recognizable marker string.
    test_file = "rag_test_doc.txt"
    with open(test_file, "w") as f:
        f.write("The secret code is 12345. This is a confidential document for Project X.")

    conversation_id_1 = "chat_1"
    conversation_id_2 = "chat_2"

    try:
        # 1. Upload to Chat 1
        print(f"\n1. Uploading to {conversation_id_1}...")
        upload_file(test_file, conversation_id_1)

        # 2. Retrieve from Chat 1 (Should find it)
        print(f"\n2. Searching in {conversation_id_1}...")
        hits = vector_store.similarity_search(
            "secret code", k=1, filter={"conversation_id": conversation_id_1}
        )
        if hits:
            print(f"✅ Found: {hits[0].page_content}")
        else:
            print("❌ Not found (Unexpected)")

        # 3. Retrieve from Chat 2 (Should NOT find it)
        print(f"\n3. Searching in {conversation_id_2}...")
        leaked = vector_store.similarity_search(
            "secret code", k=1, filter={"conversation_id": conversation_id_2}
        )
        if leaked:
            print(f"❌ Found (Unexpected): {leaked[0].page_content}")
        else:
            print("✅ Not found (Expected)")

    except Exception as e:
        print(f"Test Failed: {e}")
    finally:
        # Cleanup
        if os.path.exists(test_file):
            os.remove(test_file)

if __name__ == "__main__":
    test_rag()
|