Spaces:
Build error
Build error
| import os | |
# Redirect the Hugging Face caches into the current working directory. This
# runs before the third-party/project imports further down (pandas, streamlit,
# utils) — presumably so anything they load picks these locations up.
cwd = os.getcwd()
_hf_transformers_dir = os.path.join(cwd, 'huggingface/transformers/')
os.environ.update({
    'PYTORCH_TRANSFORMERS_CACHE': _hf_transformers_dir,
    'TRANSFORMERS_CACHE': _hf_transformers_dir,
    'HF_HOME': os.path.join(cwd, 'huggingface/'),
})
| # import sys | |
| import logging | |
| from json import JSONDecodeError | |
| from pathlib import Path | |
| # import zipfile | |
| import pandas as pd | |
| import streamlit as st | |
| from markdown import markdown | |
| from utils import get_backlink, get_pipelines, query, send_feedback, upload_doc | |
# Adjust to a question that you would like users to see in the search bar when they load the UI:
DEFAULT_QUESTION_AT_STARTUP = os.getenv(
    "DEFAULT_QUESTION_AT_STARTUP", "How to get TPS?")
DEFAULT_ANSWER_AT_STARTUP = os.getenv(
    "DEFAULT_ANSWER_AT_STARTUP", "You must file a Form I-765")

# Sliders
DEFAULT_DOCS_FROM_RETRIEVER = int(
    os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "5"))
DEFAULT_NUMBER_OF_ANSWERS = int(os.getenv("DEFAULT_NUMBER_OF_ANSWERS", "1"))

# Environment-variable spellings that env_flag() treats as "off"
# (compared after stripping whitespace and lower-casing).
_FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def env_flag(name, default="True"):
    """Read a boolean switch from the environment variable ``name``.

    ``bool(os.getenv(...))`` is wrong for this: any non-empty string,
    including "False" and "0", is truthy. Instead the value is parsed:
    empty, "0", "false", "no" and "off" (case-insensitive) mean False,
    anything else means True.

    Args:
        name: Environment variable to read.
        default: Raw string used when the variable is unset.

    Returns:
        The parsed boolean.
    """
    return os.getenv(name, default).strip().lower() not in _FALSY_ENV_VALUES


# Whether the file upload should be disabled (disabled unless the env var says otherwise)
DISABLE_FILE_UPLOAD = env_flag("DISABLE_FILE_UPLOAD", "True")
# NOTE(review): only three of the languages offered in main()'s selectbox are
# mapped here; confirm whether anything outside this view relies on LANG_MAP.
LANG_MAP = {"English": "English", "Ukrainian": "Ukrainian", "russian": "russian"}
# Pipeline object(s) from the project's `utils` module, passed to `query()` in main().
# NOTE(review): this runs at module level, so under Streamlit it presumably executes
# on every script rerun — confirm get_pipelines() is cheap or cache its result.
pipelines = get_pipelines()
def set_state_if_absent(key, value):
    """Seed ``st.session_state[key]`` with ``value`` only if the key is missing.

    An entry that already exists is left untouched, so this can be called on
    every run of the script without clobbering user state.
    """
    if key in st.session_state:
        return
    st.session_state[key] = value
def main():
    """Render the AI-advisor question-answering UI and run a query when requested.

    Values that must survive between runs of the script (question, answer,
    results, raw JSON, already-uploaded files) are kept in ``st.session_state``.
    """
    st.set_page_config(page_title="AI advisor")

    # Persistent state
    set_state_if_absent("question", DEFAULT_QUESTION_AT_STARTUP)
    set_state_if_absent("answer", DEFAULT_ANSWER_AT_STARTUP)
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_question_requested", False)
    # (name, size) of every file already sent to upload_doc(). Without this, each
    # rerun (e.g. pressing "Run") would upload every file in the uploader again.
    set_state_if_absent("uploaded_files", set())

    def reset_results(*args):
        """Clear the previous answer and results (used as an on_change callback)."""
        st.session_state.answer = None
        st.session_state.results = None
        st.session_state.raw_json = None

    # Title
    st.write("# AI Immigration advisor")

    # Sidebar
    st.sidebar.header("Options")
    language = st.sidebar.selectbox(
        "Select language: ",
        ("English", "Ukrainian", "Spanish", "French", "Italian", "Arabic", "Hindi",
         "Portuguese", "Mandarin Chinese", "Japanese", "russian"),
    )

    # Debug controls are switched off; set this to True (or restore a
    # st.sidebar.checkbox("Show debug info")) to expose the sliders and raw JSON.
    debug = False
    if debug:
        top_k_reader = st.sidebar.slider(
            "Max. number of answers",
            min_value=1,
            max_value=100,
            value=DEFAULT_NUMBER_OF_ANSWERS,
            step=1,
            on_change=reset_results,
        )
        top_k_retriever = st.sidebar.slider(
            "Max. number of documents from retriever",
            min_value=1,
            max_value=100,
            value=DEFAULT_DOCS_FROM_RETRIEVER,
            step=1,
            on_change=reset_results,
        )
    else:
        top_k_reader = DEFAULT_NUMBER_OF_ANSWERS
        top_k_retriever = DEFAULT_DOCS_FROM_RETRIEVER

    # File upload block
    if not DISABLE_FILE_UPLOAD:
        st.sidebar.write("## File Upload:")
        data_files = st.sidebar.file_uploader(
            "", type=["pdf", "txt", "docx"], accept_multiple_files=True)
        # `or []` guards against the uploader returning None.
        for data_file in data_files or []:
            file_key = (data_file.name, data_file.size)
            if file_key not in st.session_state.uploaded_files:
                raw_json = upload_doc(data_file)
                st.session_state.uploaded_files.add(file_key)
                if debug:
                    st.sidebar.subheader("REST API JSON response")
                    st.sidebar.write(raw_json)
            st.sidebar.write(str(data_file.name) + " \N{WHITE HEAVY CHECK MARK} ")

    # Search bar
    question = st.text_input(
        "", value=st.session_state.question, max_chars=100, on_change=reset_results)
    col1, col2 = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)

    # Run button: also re-run automatically when the question text changed.
    run_pressed = col1.button("Run")
    run_query = (
        run_pressed or question != st.session_state.question
    ) and not st.session_state.random_question_requested

    # Get results for query
    if run_query and question:
        reset_results()
        st.session_state.question = question
        with st.spinner("\N{BRAIN} Performing neural search on documents... \n "):
            try:
                st.session_state.results, st.session_state.raw_json = query(
                    pipelines, question, top_k_reader=top_k_reader,
                    top_k_retriever=top_k_retriever, language=language,
                )
            except JSONDecodeError:
                st.error(
                    "\N{LADY BEETLE} An error occurred reading the results. "
                    "Is the document store working?")
                return
            except Exception as e:
                logging.exception("Query request failed")
                if "The server is busy processing requests" in str(e) or "503" in str(e):
                    st.error(
                        "\N{ADULT}\N{ZERO WIDTH JOINER}\N{EAR OF RICE} "
                        "All our workers are busy! Try again later.")
                else:
                    st.error("\N{LADY BEETLE} An error occurred during the request.")
                return

    if st.session_state.results:
        st.write("## Results:")
        for result in st.session_state.results:
            if result["answer"]:
                answer = result["answer"]
                # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
                st.write(markdown(f"**Answer:** {answer}"), unsafe_allow_html=True)
                # Link to the source document when get_backlink() yields both parts,
                # otherwise fall back to the raw `source` field of the result.
                url, title = get_backlink(result)
                if url and title:
                    source = f"[{title}]({url})"
                else:
                    source = f"{result['source']}"
                st.markdown(f"**Source:** {source}")
            else:
                st.info(
                    "\N{THINKING FACE} Unsure whether any of the documents contain an "
                    "answer to your question. Try to reformulate it!"
                )
            st.write("___")
        if debug:
            st.subheader("REST API JSON response")
            st.write(st.session_state.raw_json)
# `streamlit run` executes this script with __name__ == "__main__", so the UI still
# renders there; the guard only stops the UI from rendering on a plain import.
if __name__ == "__main__":
    main()