Update app.py

app.py CHANGED
@@ -11,8 +11,8 @@ from langchain.text_splitter import (
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from …
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_huggingface import HuggingFaceEndpoint
 from langchain.memory import ConversationBufferMemory
 from sentence_transformers import SentenceTransformer, util
 import torch
@@ -48,53 +48,46 @@ class RAGEvaluator:
         self.test_samples = []
 
     def load_dataset(self, dataset_name: str, num_samples: int = 10):
-        """Load …
+        """Load dataset with proper error handling"""
         try:
             if dataset_name == "squad":
                 dataset = load_dataset("squad_v2", split="validation")
-                # Select diverse questions
                 samples = dataset.select(range(0, 1000, 100))[:num_samples]
 
                 self.test_samples = []
                 for sample in samples:
-                    # …
-
+                    # Handle SQuAD format
+                    answers = sample["answers"]
+                    if answers["text"]:  # Check if there are answers
                         self.test_samples.append({
                             "question": sample["question"],
-                            "ground_truth": …
+                            "ground_truth": answers["text"][0],
                             "context": sample["context"]
                         })
 
             elif dataset_name == "msmarco":
-                dataset = load_dataset("ms_marco", "v2.1", split="…
+                dataset = load_dataset("ms_marco", "v2.1", split="test")  # Changed from dev to test
                 samples = dataset.select(range(0, 1000, 100))[:num_samples]
 
                 self.test_samples = []
                 for sample in samples:
-                    # Check …
-                    if sample.get("answers") and sample["answers"]:
+                    if sample["answers"]:  # Check if answers exist
                         self.test_samples.append({
                             "question": sample["query"],
                             "ground_truth": sample["answers"][0],
-                            "context": sample["passages"][…
-                                if isinstance(sample["passages"], list)
-                                else sample["passages"]["passage_text"][0]
+                            "context": sample["passages"]["passage_text"][0]
                         })
 
             self.current_dataset = dataset_name
-
-            # Return dataset info
             return {
                 "dataset": dataset_name,
-                "…
-                "…
-                "status": "success"
+                "samples_loaded": len(self.test_samples),
+                "example_questions": [s["question"] for s in self.test_samples[:3]]
             }
 
         except Exception as e:
             print(f"Error loading dataset: {str(e)}")
             return {
-                "dataset": dataset_name,
                 "error": str(e),
                 "status": "failed"
             }
@@ -205,36 +198,58 @@ def create_db(splits, db_choice: str = "faiss"):
     return db_creators[db_choice]()
 
 def initialize_database(list_file_obj, splitting_strategy, chunk_size, db_choice, progress=gr.Progress()):
-
-
-
-
+    """Initialize vector database with error handling"""
+    try:
+        if not list_file_obj:
+            return None, "No files uploaded. Please upload PDF documents first."
+
+        list_file_path = [x.name for x in list_file_obj if x is not None]
+        if not list_file_path:
+            return None, "No valid files found. Please upload PDF documents."
+
+        doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)
+        if not doc_splits:
+            return None, "No content extracted from documents."
+
+        vector_db = create_db(doc_splits, db_choice)
+        return vector_db, f"Database created successfully using {splitting_strategy} splitting and {db_choice} vector database!"
+
+    except Exception as e:
+        return None, f"Error creating database: {str(e)}"
 
 def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Initialize LLM chain with error handling"""
+    try:
+        if vector_db is None:
+            return None, "Please create vector database first."
+
+        llm_model = list_llm[llm_choice]
+
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k
+        )
+
+        memory = ConversationBufferMemory(
+            memory_key="chat_history",
+            output_key='answer',
+            return_messages=True
+        )
 
-
-
-
-
-
-
-
-
+        retriever = vector_db.as_retriever()
+        qa_chain = ConversationalRetrievalChain.from_llm(
+            llm,
+            retriever=retriever,
+            memory=memory,
+            return_source_documents=True
+        )
+        return qa_chain, "LLM initialized successfully!"
+
+    except Exception as e:
+        return None, f"Error initializing LLM: {str(e)}"
 
 def conversation(qa_chain, message, history):
     """Fixed conversation function returning all required outputs"""
@@ -424,12 +439,26 @@ def demo():
         initialize_database,
         inputs=[document, splitting_strategy, chunk_size, db_choice],
         outputs=[vector_db, db_progress]
+    ).then(
+        lambda x: gr.update(interactive=True) if x[0] is not None else gr.update(interactive=False),
+        inputs=[vector_db],
+        outputs=[init_llm_btn]
     )
 
     init_llm_btn.click(
         initialize_llmchain,
         inputs=[llm_choice, temperature, max_tokens, top_k, vector_db],
         outputs=[qa_chain, llm_progress]
+    ).then(
+        lambda x: gr.update(interactive=True) if x[0] is not None else gr.update(interactive=False),
+        inputs=[qa_chain],
+        outputs=[msg]
+    )
+
+    load_dataset_btn.click(
+        lambda x: evaluator.load_dataset(x),
+        inputs=[dataset_choice],
+        outputs=[dataset_info]
     )
 
     msg.submit(
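The import change above moves HuggingFaceEndpoint to the maintained langchain_huggingface package (the langchain_community copy is deprecated). A minimal sketch of the new class in isolation, using a placeholder repo id and the standard token environment variable (neither comes from this app):

import os
from langchain_huggingface import HuggingFaceEndpoint

# Placeholder model; any hosted text-generation repo id works here.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
    temperature=0.5,
    max_new_tokens=256,
    top_k=3,
)
print(llm.invoke("What is retrieval-augmented generation?"))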
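In the SQuAD branch, answers is a dict of parallel lists ({'text': [...], 'answer_start': [...]}), and SQuAD v2 includes unanswerable questions whose text list is empty; the new answers["text"] guard skips those so every test sample carries a ground truth, just as the msmarco branch guards on an empty answers list. One caveat: slicing a datasets.Dataset, as in select(...)[:num_samples], returns a dict of columns rather than a list of rows, so iterating the slice yields column names. Keeping everything as a Dataset avoids that; a sketch of the intended sampling, assuming only the datasets library:

from datasets import load_dataset

num_samples = 10
dataset = load_dataset("squad_v2", split="validation")

# Every 100th row for variety, then trim to num_samples; both select()
# calls return a Dataset, whose rows iterate as plain dicts.
samples = dataset.select(range(0, 1000, 100)).select(range(num_samples))

for sample in samples:
    answers = sample["answers"]   # {'text': [...], 'answer_start': [...]}
    if answers["text"]:           # empty for unanswerable SQuAD v2 questions
        print(sample["question"], "->", answers["text"][0])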
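The rewritten initialize_database validates its inputs step by step and returns a (database, message) pair on every path, matching the two Gradio outputs wired to it (vector_db and db_progress), so failures surface in the status box instead of raising. The early-return branches can be exercised without the UI; the argument values below are arbitrary examples, not the app's defaults:

# Arbitrary example arguments; with no files, the first guard fires.
db, status = initialize_database(None, "recursive", 600, "faiss")
assert db is None
print(status)  # No files uploaded. Please upload PDF documents first.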
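The new .then() calls chain a follow-up step after each click handler finishes, enabling the next control only once the previous step produced a value. Note that with inputs=[vector_db] the lambda receives the state's value itself, not a tuple, so the x[0] test would index into the database object; x is not None is the safer check. A self-contained sketch of the same gating pattern, with hypothetical component names:

import gradio as gr

with gr.Blocks() as demo:
    state = gr.State(None)
    status = gr.Textbox(label="Status")
    build_btn = gr.Button("Build database")
    next_btn = gr.Button("Initialize LLM", interactive=False)

    build_btn.click(
        lambda: ("db-handle", "Database ready"),  # stands in for initialize_database
        outputs=[state, status],
    ).then(
        lambda x: gr.update(interactive=x is not None),  # x is the state value itself
        inputs=[state],
        outputs=[next_btn],
    )

demo.launch()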