File size: 17,013 Bytes
7dfe46c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
import streamlit as st
import os
import sys
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from dotenv import load_dotenv
import requests
import json
import time

# Load environment variables (python-dotenv) before the os.getenv lookups
# further down in this file read SILICONFLOW_API_KEY.
load_dotenv()
# Put the parent of this file's directory on sys.path; presumably this is what
# makes the project-local `logger` package below importable — TODO confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


try:
    # Prefer the project's custom logger when it can be imported.
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_demo_standalone")

except ImportError:
    # Fallback to standard logging if custom logger not available
    logger = logging.getLogger("rag_demo_standalone")


@dataclass
class Document:
    """A piece of text held by the RAG system, with metadata and an optional embedding."""
    # Raw text; this is what gets embedded, reranked and placed in the LLM context.
    content: str
    # Free-form metadata; this file stores and reads a "source" key.
    metadata: Dict[str, Any]
    # Embedding vector for `content`; None when no embedding was produced.
    embedding: Optional[List[float]] = None

@dataclass
class RAGResult:
    """Outcome of a single RAG query: the answer plus the documents behind it."""
    # The user's question, as passed to the query call.
    query: str
    # Generated answer text, or a fixed apology/"not found" message on failure.
    answer: str
    # Documents used as context for the answer (empty on failure).
    relevant_documents: List[Document]
    # Elapsed time for the query, computed as a difference of time.time() values.
    processing_time: float

class SimpleVectorStore:
    """Simple in-memory vector store for demonstration.

    ``documents`` and ``embeddings`` are parallel lists: ``embeddings[i]`` is
    the vector of ``documents[i]``, or an empty list when that document has no
    embedding. Search is a brute-force cosine-similarity scan.
    """

    def __init__(self):
        self.documents: List["Document"] = []
        self.embeddings: List[List[float]] = []

    def add_documents(self, documents: List["Document"]):
        """Add documents to the vector store.

        Args:
            documents: Objects exposing an ``embedding`` attribute. Documents
                without an embedding are stored but never returned by
                ``similarity_search``.
        """
        for doc in documents:
            self.documents.append(doc)
            # Always append a slot (empty when there is no vector) so that
            # index i means the same document in both lists. Skipping
            # embedding-less documents here would shift every later vector
            # onto the wrong document.
            self.embeddings.append(doc.embedding or [])

    def similarity_search(self, query_embedding: List[float], top_k: int = 5) -> List["Document"]:
        """Find the most similar documents using cosine similarity.

        Args:
            query_embedding: Vector to compare against the stored vectors.
            top_k: Maximum number of documents to return.

        Returns:
            Up to ``top_k`` documents ordered from most to least similar.
            Empty when the store or the query vector is empty, or when
            ``top_k`` is not positive.
        """
        if not self.embeddings or not query_embedding or top_k <= 0:
            return []

        scored = [
            (self._cosine_similarity(query_embedding, doc_embedding), doc)
            for doc, doc_embedding in zip(self.documents, self.embeddings)
            if doc_embedding  # placeholder slots have nothing to compare
        ]

        # Sort on the score only; documents themselves are not orderable.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in scored[:top_k]]

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Calculate cosine similarity between two vectors.

        Returns 0.0 when the vectors differ in length or either has zero norm.
        """
        if len(a) != len(b):
            return 0.0

        dot_product = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot_product / (norm_a * norm_b)

class EmbeddingSystem:
    """SiliconFlow API client for embeddings, reranking and chat completion.

    Every public method is best-effort: HTTP errors and exceptions are logged
    and turned into an empty result (or a fixed apology string for chat), so
    callers never see an exception from this class.
    """

    # Returned by chat_completion when no answer could be produced. Korean for
    # "Sorry, a response could not be generated."
    _FALLBACK_ANSWER = "죄송합니다. 응답을 생성할 수 없습니다."

    def __init__(self, api_key: str, base_url: str = "https://api.siliconflow.cn/v1"):
        """Store credentials and build the headers sent with every request.

        Args:
            api_key: Bearer token for the API.
            base_url: API root; a trailing slash is removed.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }

    def _post(self, endpoint: str, payload: Dict, timeout: int):
        """POST ``payload`` as JSON to ``{base_url}/{endpoint}`` and return the response."""
        return requests.post(
            f"{self.base_url}/{endpoint}",
            json=payload,
            headers=self.headers,
            timeout=timeout
        )

    def generate_embeddings(self, texts: List[str],
                          model: str = "BAAI/bge-large-zh-v1.5") -> List[List[float]]:
        """Generate embeddings for texts.

        Args:
            texts: Strings to embed.
            model: Embedding model name sent to the API.

        Returns:
            One vector per item of the response's ``data`` list, or ``[]`` when
            ``texts`` is empty or the request fails.
        """
        if not texts:
            # Nothing to embed; skip the network round trip.
            return []

        try:
            payload = {
                "model": model,
                "input": texts,
                "encoding_format": "float"
            }

            response = self._post("embeddings", payload, timeout=30)

            if response.status_code == 200:
                data = response.json()
                return [item['embedding'] for item in data.get('data', [])]

            logger.error(f"Embedding API error: {response.status_code} - {response.text}")
            return []

        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            return []

    def rerank_documents(self, query: str, documents: List[str],
                        model: str = "BAAI/bge-reranker-large",
                        top_k: int = 5) -> List[Dict]:
        """Rerank documents based on query relevance.

        Args:
            query: Text the documents are ranked against.
            documents: Candidate texts.
            model: Reranker model name sent to the API.
            top_k: Requested number of results.

        Returns:
            The response's ``results`` list, or ``[]`` when ``documents`` is
            empty or the request fails.
        """
        if not documents:
            # Nothing to rank; skip the network round trip.
            return []

        try:
            # NOTE(review): confirm against the provider's rerank reference
            # that the limit field is really named "top_k"; similar rerank
            # APIs call it "top_n".
            payload = {
                "model": model,
                "query": query,
                "documents": documents,
                "top_k": top_k,
                "return_documents": True
            }

            response = self._post("rerank", payload, timeout=30)

            if response.status_code == 200:
                data = response.json()
                return data.get('results', [])

            logger.error(f"Rerank API error: {response.status_code} - {response.text}")
            return []

        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            return []

    def chat_completion(self, messages: List[Dict[str, str]],
                       model: str = "Qwen/Qwen2.5-7B-Instruct") -> str:
        """Generate chat completion.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            model: Chat model name sent to the API.

        Returns:
            The content of the first choice, or a fixed Korean apology when
            the request fails or the response lacks the expected fields.
        """
        try:
            payload = {
                "model": model,
                "messages": messages,
                "temperature": 0.7,
                "max_tokens": 1000
            }

            # Generation is slower than embedding, hence the longer timeout.
            response = self._post("chat/completions", payload, timeout=60)

            if response.status_code == 200:
                data = response.json()
                return data['choices'][0]['message']['content']

            logger.error(f"Chat completion API error: {response.status_code} - {response.text}")
            return self._FALLBACK_ANSWER

        except Exception as e:
            logger.error(f"Chat completion failed: {e}")
            return self._FALLBACK_ANSWER

class RAGSystem:
    """Complete RAG system using SiliconFlow.

    Documents are embedded and held in an in-memory vector store. A query is
    embedded, matched by cosine similarity, optionally reranked, and the
    selected documents are passed to the chat model as context.
    """

    def __init__(self, api_key: str):
        """Create the API client and an empty vector store."""
        self.client = EmbeddingSystem(api_key)
        self.vector_store = SimpleVectorStore()
        logger.info("RAG System initialized")

    def add_documents(self, texts: List[str], metadatas: Optional[List[Dict]] = None):
        """Add documents to the RAG system.

        Args:
            texts: Document texts to embed and store.
            metadatas: Optional metadata dicts, one per text. When omitted, or
                shorter than ``texts``, the missing entries default to
                ``{"source": "doc_<index>"}``.
        """
        if not texts:
            logger.info("No documents to add")
            return

        if not metadatas:
            metadatas = [{"source": f"doc_{i}"} for i in range(len(texts))]
        elif len(metadatas) < len(texts):
            # zip() below stops at the shorter input, which would silently
            # drop the trailing texts; pad with default metadata instead.
            logger.warning(
                f"Got {len(metadatas)} metadata entries for {len(texts)} texts; "
                "using default metadata for the rest"
            )
            metadatas = list(metadatas) + [
                {"source": f"doc_{i}"} for i in range(len(metadatas), len(texts))
            ]

        logger.info(f"Adding {len(texts)} documents...")

        # Generate embeddings
        embeddings = self.client.generate_embeddings(texts)
        if len(embeddings) < len(texts):
            # The client returns [] on failure; such documents are stored
            # without a vector and cannot be found by similarity search.
            logger.error(
                f"Only {len(embeddings)} of {len(texts)} documents received embeddings"
            )

        # Create document objects
        documents = []
        for i, (text, metadata) in enumerate(zip(texts, metadatas)):
            embedding = embeddings[i] if i < len(embeddings) else None
            documents.append(Document(content=text, metadata=metadata, embedding=embedding))

        # Add to vector store
        self.vector_store.add_documents(documents)
        logger.info(f"Successfully added {len(documents)} documents")

    def query(self, query: str, top_k: int = 5, use_reranking: bool = True) -> RAGResult:
        """Query the RAG system.

        Args:
            query: The user's question.
            top_k: Maximum number of documents used as context.
            use_reranking: Rerank the similarity hits before selecting
                ``top_k``. Only applied when more than ``top_k`` hits exist.

        Returns:
            A ``RAGResult``. When the query cannot be embedded or nothing
            matches, the answer is a fixed Korean message and the document
            list is empty.
        """
        start_time = time.time()

        # Generate query embedding
        query_embeddings = self.client.generate_embeddings([query])
        query_embedding = query_embeddings[0] if query_embeddings else []

        if not query_embedding:
            logger.error("Failed to generate query embedding")
            return RAGResult(
                query=query,
                # "Sorry, the query could not be processed."
                answer="죄송합니다. 쿼리를 처리할 수 없습니다.",
                relevant_documents=[],
                processing_time=time.time() - start_time
            )

        # Fetch twice as many candidates as needed so reranking has a choice.
        similar_docs = self.vector_store.similarity_search(query_embedding, top_k * 2)

        if not similar_docs:
            return RAGResult(
                query=query,
                # "No related documents were found."
                answer="관련 문서를 찾을 수 없습니다.",
                relevant_documents=[],
                processing_time=time.time() - start_time
            )

        # Optional reranking
        if use_reranking and len(similar_docs) > top_k:
            similar_docs = self._rerank(query, similar_docs, top_k)

        # Limit to top_k
        relevant_docs = similar_docs[:top_k]

        # Generate answer using context
        context = "\n\n".join(doc.content for doc in relevant_docs)
        answer = self._generate_answer(query, context)

        return RAGResult(
            query=query,
            answer=answer,
            relevant_documents=relevant_docs,
            processing_time=time.time() - start_time
        )

    def _rerank(self, query: str, candidates: List[Document], top_k: int) -> List[Document]:
        """Reorder ``candidates`` by reranker relevance, keeping the input order on failure."""
        rerank_results = self.client.rerank_documents(
            query, [doc.content for doc in candidates], top_k=top_k
        )

        reranked_docs = []
        for result in rerank_results:
            doc_idx = result.get('index')
            # Skip entries without a usable index: defaulting to 0 or letting
            # a negative index wrap around would select the wrong document.
            if isinstance(doc_idx, int) and 0 <= doc_idx < len(candidates):
                reranked_docs.append(candidates[doc_idx])

        # An empty outcome means reranking failed or returned nothing usable.
        return reranked_docs or candidates

    def _generate_answer(self, query: str, context: str) -> str:
        """Generate answer using context and query.

        Sends a Korean system prompt (answer in Korean from the given context;
        say so when the context lacks the information) followed by a user
        message containing the context and the question.
        """
        system_prompt = (
            "당신은 한국어로 답변하는 도움이 되는 어시스턴트입니다. \n"
            "주어진 컨텍스트를 바탕으로 질문에 정확하고 유용한 답변을 제공해주세요.\n"
            "컨텍스트에 정보가 없으면 '주어진 정보로는 답변하기 어렵습니다'라고 말해주세요."
        )

        messages = [
            {"role": "system", "content": system_prompt},
            # "컨텍스트" = context, "질문" = question.
            {"role": "user", "content": f"컨텍스트:\n{context}\n\n질문: {query}"}
        ]

        return self.client.chat_completion(messages)

# Sample Korean manufacturing data: eight short statements about a painting
# line (yield, defect rate, process conditions, inspection, maintenance,
# materials and training). Both demo front ends load these into the RAG
# system. The text must stay valid Korean; it is what gets embedded.
SAMPLE_DOCUMENTS = [
    "TAB S10 도장 공정의 수율은 현재 95.2%입니다. 목표 수율 94%를 상회하고 있으며, 지난달 대비 1.3% 향상되었습니다.",
    "도장 라인에서 불량률이 4.8% 발생하고 있습니다. 주요 불량 원인은 온도 편차(45%)와 습도 변화(30%)입니다.",
    "S10 모델의 전체 생산 수율은 89.5%로 목표치 88%를 상회하고 있습니다. 월간 생산량은 15,000대입니다.",
    "도장 라인의 온도는 22±2℃, 습도는 45±5%로 유지되어야 합니다. 현재 자동 제어 시스템으로 관리되고 있습니다.",
    "품질관리 부서에서는 매일 3회 샘플링 검사를 실시하고 있습니다. 검사 항목은 색상, 광택, 두께입니다.",
    "예방 보전 계획에 따라 도장 설비는 주 1회 정기 점검을 실시합니다. 다음 정기 보전은 다음 주 화요일입니다.",
    "신규 도장 재료 적용 후 접착력이 15% 향상되었습니다. 비용은 10% 증가했지만 품질 개선 효과가 큽니다.",
    "작업자 교육은 월 2회 실시되며, 안전교육과 품질교육을 포함합니다. 교육 참석률은 98.5%입니다."
]

def run_streamlit_app():
    """Run the Streamlit web application.

    The sidebar collects the API key and retrieval settings and builds the
    RAG system, which is kept in ``st.session_state['rag_system']``. Once it
    exists, the main area accepts a question (typed, or picked with one of
    the sample-question buttons) and shows the answer and its source
    documents. Until then the sample documents are listed.
    """
    st.set_page_config(page_title="RAG Demo - Korean QA", page_icon="🤖", layout="wide")

    st.title("🤖 Korean RAG Demo with SiliconFlow")
    st.markdown("*Retrieval-Augmented Generation for Korean Manufacturing Q&A*")

    # Sidebar configuration
    with st.sidebar:
        st.header("⚙️ Configuration")

        api_key = st.text_input(
            "SiliconFlow API Key",
            value=os.getenv("SILICONFLOW_API_KEY", ""),
            type="password",
            help="Enter your SiliconFlow API key"
        )

        use_reranking = st.checkbox("Use Reranking", value=True, help="Use reranking for better results")
        top_k = st.slider("Top K Results", min_value=1, max_value=10, value=3)

        if st.button("Initialize RAG System"):
            if not api_key:
                st.error("Please provide SiliconFlow API key")
            else:
                with st.spinner("Initializing RAG system..."):
                    try:
                        rag_system = RAGSystem(api_key)
                        rag_system.add_documents(SAMPLE_DOCUMENTS)
                        st.session_state['rag_system'] = rag_system
                        st.success("RAG system initialized successfully!")
                    except Exception as e:
                        st.error(f"Failed to initialize RAG system: {e}")

    # Main interface
    if 'rag_system' not in st.session_state:
        st.info("👈 Please initialize the RAG system using the sidebar")

        # Show sample documents
        st.header("📚 Sample Documents")
        st.markdown("The system includes these sample manufacturing documents:")
        for i, doc in enumerate(SAMPLE_DOCUMENTS, 1):
            st.markdown(f"**{i}.** {doc}")
        return

    st.header("💬 Ask Questions")

    # Sample questions
    sample_questions = [
        "TAB S10 도장 공정 수율이 어떻게 되나요?",
        "도장 라인의 불량률과 주요 원인은?",
        "품질 검사는 어떻게 진행되나요?",
        "예방 보전 계획에 대해 알려주세요"
    ]

    def _use_sample_question(question: str) -> None:
        # Button callbacks run before the next script run creates the text
        # input, which is when Streamlit allows a widget's state to be set.
        st.session_state["query_input"] = question

    col1, col2 = st.columns([3, 1])

    with col1:
        query = st.text_input(
            "질문을 입력하세요:",
            placeholder="예: TAB S10 수율이 어떻게 되나요?",
            key="query_input"
        )

    with col2:
        st.markdown("**샘플 질문:**")
        for i, sample in enumerate(sample_questions):
            # Clicking copies the sample into the text input. The click itself
            # triggers the rerun, so no explicit st.rerun() is needed.
            st.button(
                f"Q{i+1}",
                key=f"sample_{i}",
                help=sample,
                on_click=_use_sample_question,
                args=(sample,)
            )

    # Ignore empty and whitespace-only input instead of sending it to the API.
    if not query or not query.strip():
        return

    with st.spinner("Searching and generating answer..."):
        try:
            result = st.session_state['rag_system'].query(
                query,
                top_k=top_k,
                use_reranking=use_reranking
            )

            # Display results
            st.header("📋 Answer")
            st.write(result.answer)

            st.header("📄 Relevant Documents")
            for i, doc in enumerate(result.relevant_documents):
                with st.expander(f"Document {i+1} - {doc.metadata.get('source', 'Unknown')}"):
                    st.write(doc.content)

            # Stats
            st.sidebar.metric("Processing Time", f"{result.processing_time:.2f}s")
            st.sidebar.metric("Documents Found", len(result.relevant_documents))

        except Exception as e:
            st.error(f"Query failed: {e}")

def run_cli_demo():
    """Run command line interface demo.

    Reads the API key from ``SILICONFLOW_API_KEY`` (prompting when it is
    unset), loads the sample documents, then answers questions in a loop
    until the user enters ``quit``, ``exit`` or ``종료``, or closes the
    input (Ctrl-D / Ctrl-C).
    """
    print("🤖 Korean RAG Demo - CLI Mode")
    print("=" * 50)

    # Get API key
    api_key = os.getenv("SILICONFLOW_API_KEY")
    if not api_key:
        try:
            api_key = input("Enter your SiliconFlow API key: ").strip()
        except (EOFError, KeyboardInterrupt):
            # No interactive input available; fall through to the check below.
            api_key = ""

    if not api_key:
        print("❌ API key is required")
        return

    try:
        # Initialize RAG system
        print("🔄 Initializing RAG system...")
        rag_system = RAGSystem(api_key)
        rag_system.add_documents(SAMPLE_DOCUMENTS)
        print("✅ RAG system ready!")

        # Interactive loop
        while True:
            print("\n" + "-" * 50)
            try:
                # Prompt: "Enter a question ('quit' to exit)".
                query = input("질문을 입력하세요 (종료하려면 'quit'): ").strip()
            except (EOFError, KeyboardInterrupt):
                # Treat a closed stdin or Ctrl-C at the prompt as a quit request.
                print()
                break

            # Stripped above, so " quit " also ends the session.
            if query.lower() in ('quit', 'exit', '종료'):
                break

            if not query:
                continue

            print("🔍 Searching...")
            result = rag_system.query(query, top_k=3, use_reranking=True)

            print("\n📋 답변:")
            print(result.answer)

            print("\n📄 관련 문서들:")
            for i, doc in enumerate(result.relevant_documents, 1):
                print(f"{i}. {doc.content}")

            print(f"\n⏱️  처리 시간: {result.processing_time:.2f}초")

    except Exception as e:
        print(f"❌ Error: {e}")

if __name__ == "__main__":
    # Ask Streamlit whether a runtime is active rather than probing inside a
    # bare `except:`. The old probe also wrapped run_streamlit_app(), so any
    # error raised by the web app silently started the CLI demo instead.
    from streamlit import runtime as _streamlit_runtime

    if _streamlit_runtime.exists():
        # Launched with `streamlit run <this file>`.
        run_streamlit_app()
    else:
        # Launched with plain `python <this file>`.
        run_cli_demo()