Spaces:
Runtime error
Runtime error
shaocongma
Add a generator wrapper using configuration file. Edit the logic of searching references. Add Gradio UI for testing Knowledge database.
94dc00e
| from utils.knowledge import Knowledge | |
| from langchain.vectorstores import FAISS | |
| from utils.file_operations import list_folders | |
| from huggingface_hub import snapshot_download | |
| import gradio as gr | |
| import os | |
| import json | |
| from models import EMBEDDINGS | |
| from utils.gpt_interaction import GPTModel | |
| from utils.prompts import SYSTEM | |
| import openai | |
| llm = GPTModel(model="gpt-3.5-turbo") | |
| openai.api_key = os.getenv("OPENAI_API_KEY") | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| REPO_ID = os.getenv("KDB_REPO") | |
| if HF_TOKEN is not None and REPO_ID is not None: | |
| snapshot_download(REPO_ID, repo_type="dataset", local_dir="knowledge_databases/", | |
| local_dir_use_symlinks=False, token=HF_TOKEN) | |
| ALL_KDB = ["(None)"] + list_folders("knowledge_databases") | |
| ANNOUNCEMENT = """ | |
| # Evaluate the quality of retrieved date from the FAISS database | |
| Use this space test the performance of some pre-constructed vector databases hosted at `shaocongma/kdb`. To use this space for your own FAISS database, follow this instruction: | |
| 1. Duplicate this space. | |
| 2. Add the secret key `HF_TOKEN` with your own Huggingface User Access Token. | |
| 3. Create a Huggingface Dataset. Put your FAISS database to it. | |
| 4. Add the secret key `REPO_ID` as your dataset's address. | |
| """ | |
| AUTODRAFT = """ | |
| AutoDraft is a GPT-based project to generate an academic paper using the title and contributions. When generating specific sections, AutoDraft will query some necessary backgrounds in related fields from the pre-constructed vector database. | |
| """ | |
| def query_from_kdb(input, kdb, query_counts): | |
| if kdb == "(None)": | |
| return {"knowledge_database": "(None)", "input": input, "output": ""}, "" | |
| db_path = f"knowledge_databases/{kdb}" | |
| db_config_path = os.path.join(db_path, "db_meta.json") | |
| db_index_path = os.path.join(db_path, "faiss_index") | |
| if os.path.isdir(db_path): | |
| # load configuration file | |
| with open(db_config_path, "r", encoding="utf-8") as f: | |
| db_config = json.load(f) | |
| model_name = db_config["embedding_model"] | |
| embeddings = EMBEDDINGS[model_name] | |
| db = FAISS.load_local(db_index_path, embeddings) | |
| knowledge = Knowledge(db=db) | |
| knowledge.collect_knowledge({input: query_counts}, max_query=query_counts) | |
| domain_knowledge = knowledge.to_json() | |
| else: | |
| raise RuntimeError(f"Failed to query from FAISS.") | |
| return domain_knowledge, "" | |
| def query_from_kdb_llm(title, contributions, kdb, query_counts): | |
| if kdb == "(None)": | |
| return {"knowledge_database": "(None)", "title": title, "contributions": contributions, "output": ""}, "", {} | |
| db_path = f"knowledge_databases/{kdb}" | |
| db_config_path = os.path.join(db_path, "db_meta.json") | |
| db_index_path = os.path.join(db_path, "faiss_index") | |
| if os.path.isdir(db_path): | |
| # load configuration file | |
| with open(db_config_path, "r", encoding="utf-8") as f: | |
| db_config = json.load(f) | |
| model_name = db_config["embedding_model"] | |
| embeddings = EMBEDDINGS[model_name] | |
| db = FAISS.load_local(db_index_path, embeddings) | |
| knowledge = Knowledge(db=db) | |
| prompts = f"Title: {title}\n Contributions: {contributions}" | |
| preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts, return_json=True) | |
| knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts) | |
| domain_knowledge = knowledge.to_json() | |
| else: | |
| raise RuntimeError(f"Failed to query from FAISS.") | |
| return domain_knowledge, "", preliminaries_kw | |
| with gr.Blocks() as demo: | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown(ANNOUNCEMENT) | |
| kdb_dropdown = gr.Dropdown(choices=ALL_KDB, value="(None)", label="Knowledge Databases", | |
| info="Pre-defined knowledge databases utilized to aid in the generation of academic writing content. " | |
| "Hosted at `shaocongma/kdb`.") | |
| with gr.Tab("User's Input"): | |
| user_input = gr.Textbox(label="Input", info="Input anything you like to test what will be retrived from the vector database.") | |
| with gr.Row(): | |
| button_clear = gr.Button("Clear") | |
| button_retrieval = gr.Button("Retrieve", variant="primary") | |
| with gr.Tab("AutoDraft"): | |
| gr.Markdown(AUTODRAFT) | |
| title_input = gr.Textbox(label="Title") | |
| contribution_input = gr.Textbox(label="Contributions", lines=5) | |
| with gr.Row(): | |
| button_clear_2 = gr.Button("Clear") | |
| button_retrieval_2 = gr.Button("Retrieve", variant="primary") | |
| with gr.Accordion("Advanced Setting", open=False): | |
| query_counts_slider = gr.Slider(minimum=1, maximum=50, value=10, step=1, | |
| interactive=True, label="QUERY_COUNTS", | |
| info="How many contents will be retrieved from the vector database.") | |
| with gr.Column(): | |
| retrieval_output = gr.JSON(label="Output") | |
| llm_kws = gr.JSON(label="Keywords generated by LLM") | |
| button_retrieval.click(fn=query_from_kdb, | |
| inputs=[user_input, kdb_dropdown, query_counts_slider], | |
| outputs=[retrieval_output, user_input]) | |
| button_retrieval_2.click(fn=query_from_kdb_llm, | |
| inputs=[title_input, contribution_input, kdb_dropdown, query_counts_slider], | |
| outputs=[retrieval_output, user_input, llm_kws]) | |
| demo.queue(concurrency_count=1, max_size=5, api_open=False) | |
| demo.launch(show_error=True) | |