# rag_korean_manufacturing_docs/src/rag_demo_standalone.py
# Uploaded by A7m0d via huggingface_hub ("Upload folder using huggingface_hub"), commit 7dfe46c (verified).
import streamlit as st
import os
import sys
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from dotenv import load_dotenv
import requests
import json
import time
# Load variables from a local .env file (python-dotenv) into the process
# environment; SILICONFLOW_API_KEY is read from there further below.
load_dotenv()

# Put the parent of this file's directory on the import path so that the
# optional project-local ``logger`` package can be imported when this file
# is executed directly as a script.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(_project_root)

try:
    # Prefer the project's own logging setup when it is importable.
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_demo_standalone")
except ImportError:
    # Project logger unavailable: use a plain standard-library logger.
    logger = logging.getLogger("rag_demo_standalone")
@dataclass
class Document:
    """One piece of text held by the RAG index.

    Carries the raw text, free-form metadata about it, and (optionally) the
    embedding vector computed for that text.
    """

    # Raw text that is embedded and later handed to the chat model as context.
    content: str
    # Arbitrary per-document attributes, e.g. a "source" label.
    metadata: Dict[str, Any]
    # Embedding vector for ``content``; stays None when none was produced.
    embedding: Optional[List[float]] = None
@dataclass
class RAGResult:
    """Everything produced by answering one question.

    Holds the question, the generated answer, the documents used as context,
    and how long the whole query took.
    """

    # The question exactly as it was asked.
    query: str
    # Generated answer, or a fixed fallback message when the pipeline failed.
    answer: str
    # Documents supplied to the model as context (empty when none were found).
    relevant_documents: List[Document]
    # Elapsed seconds for the query.
    processing_time: float
class SimpleVectorStore:
    """Simple in-memory vector store ranking documents by cosine similarity.

    ``documents`` and ``embeddings`` are kept strictly parallel: entry ``i`` of
    ``embeddings`` is the vector of ``documents[i]`` (``None``/empty when that
    document has no embedding). Documents without an embedding are stored but
    never returned by ``similarity_search``.
    """

    def __init__(self):
        self.documents: List["Document"] = []
        # Parallel to ``documents``; falsy entries mark documents without a vector.
        self.embeddings: List[Optional[List[float]]] = []

    def add_documents(self, documents: List["Document"]):
        """Append ``documents`` and their embeddings to the store.

        Args:
            documents: Objects exposing ``.embedding``. Any iterable is accepted;
                it is materialised once so both internal lists see every item.
        """
        documents = list(documents)
        self.documents.extend(documents)
        # Record every embedding, even a missing one, so that index i in
        # ``embeddings`` always refers to ``documents[i]``. Skipping falsy
        # embeddings here used to shift the indices and pair vectors with the
        # wrong documents.
        self.embeddings.extend(doc.embedding for doc in documents)

    def similarity_search(self, query_embedding: List[float], top_k: int = 5) -> List["Document"]:
        """Return up to ``top_k`` stored documents most similar to ``query_embedding``.

        Args:
            query_embedding: Query vector; an empty or None value yields ``[]``.
            top_k: Maximum number of results; values <= 0 yield ``[]``.

        Returns:
            Documents ordered from highest to lowest cosine similarity. Documents
            without an embedding are skipped; equal scores keep insertion order.
        """
        if not query_embedding or top_k <= 0:
            return []
        scored = [
            (self._cosine_similarity(query_embedding, embedding), doc)
            for doc, embedding in zip(self.documents, self.embeddings)
            if embedding
        ]
        # Sort on the score only (stable), so ties keep insertion order and
        # document objects are never compared with each other.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in scored[:top_k]]

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Cosine similarity of ``a`` and ``b``.

        Returns 0.0 when the vectors differ in length or either has zero norm.
        """
        if len(a) != len(b):
            return 0.0
        dot_product = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return dot_product / (norm_a * norm_b)
class EmbeddingSystem:
    """SiliconFlow API client for embeddings, reranking and chat completion.

    Every public call is best-effort. On an HTTP error or any exception, the
    problem is logged and an "empty" result is returned instead of raising:
    ``[]`` for embeddings and rerank, and a fixed apology string for chat.
    """

    # Returned by chat_completion whenever no answer could be produced.
    _CHAT_FALLBACK = "μ£„μ†‘ν•©λ‹ˆλ‹€. 응닡을 생성할 수 μ—†μŠ΅λ‹ˆλ‹€."

    def __init__(self, api_key: str, base_url: str = "https://api.siliconflow.cn/v1"):
        """
        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; any trailing slash is stripped.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

    def _post(self, endpoint: str, payload: Dict[str, Any], timeout: int,
              label: str) -> Optional[Dict[str, Any]]:
        """POST ``payload`` as JSON to ``{base_url}/{endpoint}``.

        Returns:
            The decoded JSON body on HTTP 200. Otherwise, logs
            ``"<label> API error: <status> - <body>"`` and returns None.

        Raises:
            Whatever ``requests.post`` or ``response.json()`` raise; the public
            methods catch these.
        """
        response = requests.post(
            f"{self.base_url}/{endpoint}",
            json=payload,
            headers=self.headers,
            timeout=timeout
        )
        if response.status_code != 200:
            logger.error(f"{label} API error: {response.status_code} - {response.text}")
            return None
        return response.json()

    def generate_embeddings(self, texts: List[str],
                            model: str = "BAAI/bge-large-zh-v1.5") -> List[List[float]]:
        """Embed every string in ``texts``.

        Returns:
            The returned vectors, ordered by the response's ``index`` field so
            that vector i belongs to ``texts[i]``. Returns ``[]`` on failure, or
            when ``texts`` is empty (no request is sent in that case).
        """
        if not texts:
            return []
        try:
            data = self._post("embeddings", {
                "model": model,
                "input": texts,
                "encoding_format": "float"
            }, timeout=30, label="Embedding")
            if data is None:
                return []
            items = data.get('data', [])
            # OpenAI-style responses tag each vector with its input 'index'.
            # The sort is stable, so items without one keep their order.
            items = sorted(items, key=lambda item: item.get('index', 0))
            return [item['embedding'] for item in items]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            return []

    def rerank_documents(self, query: str, documents: List[str],
                         model: str = "BAAI/bge-reranker-large",
                         top_k: int = 5) -> List[Dict]:
        """Score ``documents`` against ``query`` with a reranker model.

        Args:
            query: Search query.
            documents: Candidate texts; result ``index`` values point into this list.
            model: Reranker model name.
            top_k: Maximum number of results to request.

        Returns:
            The API's ``results`` list, or ``[]`` on failure or when
            ``documents`` is empty.
        """
        if not documents:
            return []
        try:
            data = self._post("rerank", {
                "model": model,
                "query": query,
                "documents": documents,
                # SiliconFlow's rerank endpoint (Cohere/Jina-compatible) names
                # the result-count field ``top_n``; ``top_k`` is not one of its
                # documented parameters.
                "top_n": top_k,
                "return_documents": True
            }, timeout=30, label="Rerank")
            if data is None:
                return []
            return data.get('results', [])
        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            return []

    def chat_completion(self, messages: List[Dict[str, str]],
                        model: str = "Qwen/Qwen2.5-7B-Instruct") -> str:
        """Return the assistant message for ``messages``.

        Uses a fixed sampling configuration (temperature 0.7, max 1000 tokens).
        Returns a fixed Korean apology string on any failure.
        """
        try:
            data = self._post("chat/completions", {
                "model": model,
                "messages": messages,
                "temperature": 0.7,
                "max_tokens": 1000
            }, timeout=60, label="Chat completion")
            if data is None:
                return self._CHAT_FALLBACK
            return data['choices'][0]['message']['content']
        except Exception as e:
            logger.error(f"Chat completion failed: {e}")
            return self._CHAT_FALLBACK
class RAGSystem:
    """Complete retrieval-augmented QA pipeline using SiliconFlow.

    Texts are embedded with ``EmbeddingSystem`` and held in a
    ``SimpleVectorStore``. Retrieved candidates are optionally reranked and then
    passed as context to a chat model, which is instructed in Korean to answer
    from that context.
    """

    def __init__(self, api_key: str):
        """Create the API client and an empty vector store.

        Args:
            api_key: SiliconFlow API key forwarded to ``EmbeddingSystem``.
        """
        self.client = EmbeddingSystem(api_key)
        self.vector_store = SimpleVectorStore()
        logger.info("RAG System initialized")

    def add_documents(self, texts: List[str], metadatas: Optional[List[Dict]] = None):
        """Embed ``texts`` and store them as documents.

        Args:
            texts: Document bodies to index.
            metadatas: Optional per-text metadata, matched by position. A text
                with no entry gets ``{"source": "doc_<i>"}`` instead of being
                silently dropped. This covers ``None``, an empty list, and a
                list shorter than ``texts``.
        """
        texts = list(texts)
        # Copy so the caller's list is never extended in place.
        metadatas = list(metadatas or [])
        if len(metadatas) < len(texts):
            metadatas.extend(
                {"source": f"doc_{i}"} for i in range(len(metadatas), len(texts))
            )
        logger.info(f"Adding {len(texts)} documents...")

        embeddings = self.client.generate_embeddings(texts) if texts else []
        if len(embeddings) != len(texts):
            logger.error(
                f"Received {len(embeddings)} embeddings for {len(texts)} documents; "
                "some documents will be stored without an embedding"
            )

        documents = [
            Document(
                content=text,
                metadata=metadata,
                embedding=embeddings[i] if i < len(embeddings) else None,
            )
            for i, (text, metadata) in enumerate(zip(texts, metadatas))
        ]
        self.vector_store.add_documents(documents)
        logger.info(f"Successfully added {len(documents)} documents")

    def query(self, query: str, top_k: int = 5, use_reranking: bool = True) -> RAGResult:
        """Answer ``query`` from the stored documents.

        The query is embedded and ``2 * top_k`` nearest candidates are fetched.
        The candidates are optionally reranked, and the best ``top_k`` are used
        as context for the chat model.

        Args:
            query: The user's question.
            top_k: Number of context documents to keep.
            use_reranking: Reorder candidates with the reranker when more than
                ``top_k`` were retrieved.

        Returns:
            A ``RAGResult``. If the query embedding fails or nothing is
            retrieved, the answer is a fixed Korean message and
            ``relevant_documents`` is empty.
        """
        start_time = time.perf_counter()

        query_embeddings = self.client.generate_embeddings([query])
        query_embedding = query_embeddings[0] if query_embeddings else []
        if not query_embedding:
            logger.error("Failed to generate query embedding")
            return RAGResult(
                query=query,
                answer="μ£„μ†‘ν•©λ‹ˆλ‹€. 쿼리λ₯Ό μ²˜λ¦¬ν•  수 μ—†μŠ΅λ‹ˆλ‹€.",
                relevant_documents=[],
                processing_time=time.perf_counter() - start_time
            )

        # Over-fetch so the reranker has extra candidates to choose from.
        similar_docs = self.vector_store.similarity_search(query_embedding, top_k * 2)
        if not similar_docs:
            return RAGResult(
                query=query,
                answer="κ΄€λ ¨ λ¬Έμ„œλ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.",
                relevant_documents=[],
                processing_time=time.perf_counter() - start_time
            )

        if use_reranking and len(similar_docs) > top_k:
            similar_docs = self._rerank(query, similar_docs, top_k)

        relevant_docs = similar_docs[:top_k]
        context = "\n\n".join(doc.content for doc in relevant_docs)
        answer = self._generate_answer(query, context)
        return RAGResult(
            query=query,
            answer=answer,
            relevant_documents=relevant_docs,
            processing_time=time.perf_counter() - start_time
        )

    def _rerank(self, query: str, candidates: List[Document], top_k: int) -> List[Document]:
        """Reorder ``candidates`` by reranker relevance.

        Returns ``candidates`` unchanged if the reranker yields nothing usable.
        """
        rerank_results = self.client.rerank_documents(
            query, [doc.content for doc in candidates], top_k=top_k
        )
        reranked = []
        for result in rerank_results:
            doc_idx = result.get('index')
            # Skip entries whose index is missing or out of range. Defaulting
            # to 0 would duplicate the first candidate.
            if isinstance(doc_idx, int) and 0 <= doc_idx < len(candidates):
                reranked.append(candidates[doc_idx])
        # Fall back to vector-similarity order rather than answering from an
        # empty context.
        return reranked or candidates

    def _generate_answer(self, query: str, context: str) -> str:
        """Ask the chat model to answer ``query`` in Korean, using ``context``.

        The system prompt tells the model to say the context is insufficient
        rather than guess.
        """
        system_prompt = """당신은 ν•œκ΅­μ–΄λ‘œ λ‹΅λ³€ν•˜λŠ” 도움이 λ˜λŠ” μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€.
μ£Όμ–΄μ§„ μ»¨ν…μŠ€νŠΈλ₯Ό λ°”νƒ•μœΌλ‘œ μ§ˆλ¬Έμ— μ •ν™•ν•˜κ³  μœ μš©ν•œ 닡변을 μ œκ³΅ν•΄μ£Όμ„Έμš”.
μ»¨ν…μŠ€νŠΈμ— 정보가 μ—†μœΌλ©΄ 'μ£Όμ–΄μ§„ μ •λ³΄λ‘œλŠ” λ‹΅λ³€ν•˜κΈ° μ–΄λ ΅μŠ΅λ‹ˆλ‹€'라고 λ§ν•΄μ£Όμ„Έμš”."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"μ»¨ν…μŠ€νŠΈ:\n{context}\n\n질문: {query}"}
        ]
        return self.client.chat_completion(messages)
# Sample Korean manufacturing records used to seed the demo index.
# Each string becomes one Document. The callers in this file pass no metadata,
# so RAGSystem.add_documents labels them "doc_0" .. "doc_7".
# NOTE(review): the Korean text in this file looks mis-encoded (UTF-8 bytes
# decoded with a Greek code page); confirm against the original source.
SAMPLE_DOCUMENTS = [
    "TAB S10 도μž₯ κ³΅μ •μ˜ μˆ˜μœ¨μ€ ν˜„μž¬ 95.2%μž…λ‹ˆλ‹€. λͺ©ν‘œ 수율 94%λ₯Ό μƒνšŒν•˜κ³  있으며, μ§€λ‚œλ‹¬ λŒ€λΉ„ 1.3% ν–₯μƒλ˜μ—ˆμŠ΅λ‹ˆλ‹€.",
    "도μž₯ λΌμΈμ—μ„œ λΆˆλŸ‰λ₯ μ΄ 4.8% λ°œμƒν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. μ£Όμš” λΆˆλŸ‰ 원인은 μ˜¨λ„ 편차(45%)와 μŠ΅λ„ λ³€ν™”(30%)μž…λ‹ˆλ‹€.",
    "S10 λͺ¨λΈμ˜ 전체 생산 μˆ˜μœ¨μ€ 89.5%둜 λͺ©ν‘œμΉ˜ 88%λ₯Ό μƒνšŒν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. μ›”κ°„ μƒμ‚°λŸ‰μ€ 15,000λŒ€μž…λ‹ˆλ‹€.",
    "도μž₯ 라인의 μ˜¨λ„λŠ” 22Β±2℃, μŠ΅λ„λŠ” 45Β±5%둜 μœ μ§€λ˜μ–΄μ•Ό ν•©λ‹ˆλ‹€. ν˜„μž¬ μžλ™ μ œμ–΄ μ‹œμŠ€ν…œμœΌλ‘œ κ΄€λ¦¬λ˜κ³  μžˆμŠ΅λ‹ˆλ‹€.",
    "ν’ˆμ§ˆκ΄€λ¦¬ λΆ€μ„œμ—μ„œλŠ” 맀일 3회 μƒ˜ν”Œλ§ 검사λ₯Ό μ‹€μ‹œν•˜κ³  μžˆμŠ΅λ‹ˆλ‹€. 검사 ν•­λͺ©μ€ 색상, 광택, λ‘κ»˜μž…λ‹ˆλ‹€.",
    "예방 보전 κ³„νšμ— 따라 도μž₯ μ„€λΉ„λŠ” μ£Ό 1회 μ •κΈ° 점검을 μ‹€μ‹œν•©λ‹ˆλ‹€. λ‹€μŒ μ •κΈ° 보전은 λ‹€μŒ μ£Ό ν™”μš”μΌμž…λ‹ˆλ‹€.",
    "μ‹ κ·œ 도μž₯ 재료 적용 ν›„ μ ‘μ°©λ ₯이 15% ν–₯μƒλ˜μ—ˆμŠ΅λ‹ˆλ‹€. λΉ„μš©μ€ 10% μ¦κ°€ν–ˆμ§€λ§Œ ν’ˆμ§ˆ κ°œμ„  νš¨κ³Όκ°€ ν½λ‹ˆλ‹€.",
    "μž‘μ—…μž κ΅μœ‘μ€ μ›” 2회 μ‹€μ‹œλ˜λ©°, μ•ˆμ „κ΅μœ‘κ³Ό ν’ˆμ§ˆκ΅μœ‘μ„ ν¬ν•¨ν•©λ‹ˆλ‹€. ꡐ윑 참석λ₯ μ€ 98.5%μž…λ‹ˆλ‹€."
]
# Session-state key of the question text box; sample-question buttons write it.
_QUERY_INPUT_KEY = "query_input"


def run_streamlit_app():
    """Run the Streamlit web application.

    Shows a configuration sidebar, then one of two views:
    - the Q&A view, once a RAGSystem is stored in
      ``st.session_state['rag_system']``;
    - otherwise, a welcome screen listing ``SAMPLE_DOCUMENTS``.
    """
    st.set_page_config(page_title="RAG Demo - Korean QA", page_icon="πŸ€–", layout="wide")
    st.title("πŸ€– Korean RAG Demo with SiliconFlow")
    st.markdown("*Retrieval-Augmented Generation for Korean Manufacturing Q&A*")

    use_reranking, top_k = _render_sidebar()

    if 'rag_system' in st.session_state:
        _render_question_answering(use_reranking, top_k)
    else:
        _render_welcome()


def _render_sidebar():
    """Draw the configuration sidebar and return ``(use_reranking, top_k)``.

    "Initialize RAG System" builds a RAGSystem over SAMPLE_DOCUMENTS. The
    system is stored in ``st.session_state['rag_system']`` so later reruns
    reuse it.
    """
    with st.sidebar:
        st.header("βš™οΈ Configuration")
        api_key = st.text_input(
            "SiliconFlow API Key",
            value=os.getenv("SILICONFLOW_API_KEY", ""),
            type="password",
            help="Enter your SiliconFlow API key"
        )
        use_reranking = st.checkbox("Use Reranking", value=True, help="Use reranking for better results")
        top_k = st.slider("Top K Results", min_value=1, max_value=10, value=3)

        if st.button("Initialize RAG System"):
            if not api_key:
                st.error("Please provide SiliconFlow API key")
            else:
                with st.spinner("Initializing RAG system..."):
                    try:
                        rag_system = RAGSystem(api_key)
                        rag_system.add_documents(SAMPLE_DOCUMENTS)
                        st.session_state['rag_system'] = rag_system
                        st.success("RAG system initialized successfully!")
                    except Exception as e:
                        st.error(f"Failed to initialize RAG system: {e}")
    return use_reranking, top_k


def _fill_query(question):
    """Button callback: copy ``question`` into the question text box.

    Callbacks run before the script reruns, so the text box shows the new
    value when it is drawn. Assigning the key directly after the text box has
    rendered in the current run would raise StreamlitAPIException.
    """
    st.session_state[_QUERY_INPUT_KEY] = question


def _render_question_answering(use_reranking, top_k):
    """Draw the question box and sample-question buttons, then the answer and sources."""
    st.header("πŸ’¬ Ask Questions")

    sample_questions = [
        "TAB S10 도μž₯ 곡정 수율이 μ–΄λ–»κ²Œ λ˜λ‚˜μš”?",
        "도μž₯ 라인의 λΆˆλŸ‰λ₯ κ³Ό μ£Όμš” 원인은?",
        "ν’ˆμ§ˆ κ²€μ‚¬λŠ” μ–΄λ–»κ²Œ μ§„ν–‰λ˜λ‚˜μš”?",
        "예방 보전 κ³„νšμ— λŒ€ν•΄ μ•Œλ €μ£Όμ„Έμš”"
    ]

    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input(
            "μ§ˆλ¬Έμ„ μž…λ ₯ν•˜μ„Έμš”:",
            placeholder="예: TAB S10 수율이 μ–΄λ–»κ²Œ λ˜λ‚˜μš”?",
            key=_QUERY_INPUT_KEY,
        )
    with col2:
        st.markdown("**μƒ˜ν”Œ 질문:**")
        for i, sample in enumerate(sample_questions):
            # These buttons previously only called st.rerun(), so clicking one
            # never put its question into the text box.
            st.button(
                f"Q{i+1}",
                key=f"sample_{i}",
                help=sample,
                on_click=_fill_query,
                args=(sample,),
            )

    if not query:
        return

    with st.spinner("Searching and generating answer..."):
        try:
            result = st.session_state['rag_system'].query(
                query,
                top_k=top_k,
                use_reranking=use_reranking
            )

            st.header("πŸ“‹ Answer")
            st.write(result.answer)

            st.header("πŸ“„ Relevant Documents")
            for i, doc in enumerate(result.relevant_documents):
                with st.expander(f"Document {i+1} - {doc.metadata.get('source', 'Unknown')}"):
                    st.write(doc.content)

            st.sidebar.metric("Processing Time", f"{result.processing_time:.2f}s")
            st.sidebar.metric("Documents Found", len(result.relevant_documents))
        except Exception as e:
            st.error(f"Query failed: {e}")


def _render_welcome():
    """Shown before initialisation: a hint plus the bundled sample documents."""
    st.info("πŸ‘ˆ Please initialize the RAG system using the sidebar")
    st.header("πŸ“š Sample Documents")
    st.markdown("The system includes these sample manufacturing documents:")
    for i, doc in enumerate(SAMPLE_DOCUMENTS, 1):
        st.markdown(f"**{i}.** {doc}")
def run_cli_demo():
    """Run the command-line version of the demo.

    The API key comes from ``SILICONFLOW_API_KEY``. If that is unset, it is
    prompted for without echo. The demo then indexes ``SAMPLE_DOCUMENTS`` and
    answers questions until the user types quit/exit/μ’…λ£Œ or ends input
    (Ctrl-D / Ctrl-C).
    """
    import getpass

    print("πŸ€– Korean RAG Demo - CLI Mode")
    print("=" * 50)

    api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key:
        try:
            # getpass keeps the secret off the screen, like the web UI's
            # password field.
            api_key = getpass.getpass("Enter your SiliconFlow API key: ").strip()
        except (EOFError, KeyboardInterrupt):
            api_key = ""
    if not api_key:
        print("❌ API key is required")
        return

    try:
        print("πŸ”„ Initializing RAG system...")
        rag_system = RAGSystem(api_key)
        rag_system.add_documents(SAMPLE_DOCUMENTS)
        print("βœ… RAG system ready!")
    except Exception as e:
        print(f"❌ Error: {e}")
        return

    while True:
        print("\n" + "-" * 50)
        try:
            query = input("μ§ˆλ¬Έμ„ μž…λ ₯ν•˜μ„Έμš” (μ’…λ£Œν•˜λ €λ©΄ 'quit'): ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C ends the session cleanly.
            print()
            break

        if query.lower() in ['quit', 'exit', 'μ’…λ£Œ']:
            break
        if not query:
            continue

        print("πŸ” Searching...")
        try:
            result = rag_system.query(query, top_k=3, use_reranking=True)
        except Exception as e:
            # Report the failure but keep the session alive for the next question.
            print(f"❌ Error: {e}")
            continue

        print("\nπŸ“‹ λ‹΅λ³€:")
        print(result.answer)
        print("\nπŸ“„ κ΄€λ ¨ λ¬Έμ„œλ“€:")
        for i, doc in enumerate(result.relevant_documents, 1):
            print(f"{i}. {doc.content}")
        print(f"\n⏱️ 처리 μ‹œκ°„: {result.processing_time:.2f}초")
def running_in_streamlit() -> bool:
    """Return True when this script is being executed by ``streamlit run``.

    Evaluating ``st.session_state`` does not raise outside a Streamlit run, so
    it cannot tell the two modes apart. Ask the Streamlit runtime directly
    instead.
    """
    try:
        from streamlit import runtime
        return bool(runtime.exists())
    except (ImportError, AttributeError):
        # Streamlit is missing, or too old to expose runtime.exists():
        # treat this as a plain Python run.
        return False


if __name__ == "__main__":
    # Use an explicit check instead of a bare ``except:``. That handler also
    # swallowed KeyboardInterrupt/SystemExit and fell back to the CLI whenever
    # the web app itself raised.
    if running_in_streamlit():
        run_streamlit_app()
    else:
        run_cli_demo()