import streamlit as st
from Rag import launch_depression_assistant, depression_assistant
from openai import OpenAI
from together import Together
import time
import os
from dotenv import load_dotenv
from feedback_utils import FeedbackManager

load_dotenv()
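
# The sidebar below expects TOGETHER_API_KEY and NVIDIA_API_KEY in the
# environment (e.g. supplied via a local .env file read by load_dotenv above).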


# Cache the embedding resources across reruns. NOTE: the @st.cache_resource
# decorator is an assumed restoration -- force_reload_model() below calls
# st.cache_resource.clear(), which only makes sense if this function is cached.
@st.cache_resource
def load_embedding_model_cached(embedder_name):
    print(f"🔄 Loading cached embedding model: {embedder_name}")
    launch_depression_assistant(embedder_name=embedder_name, designated_client=None)
    print(f"✅ Cached embedding model loaded successfully: {embedder_name}")
    return {
        "embedder_name": embedder_name,
        "status": "loaded",
        "timestamp": time.time(),
    }
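
# launch_depression_assistant appears to keep the loaded embedder as module
# state inside Rag (nothing from its return value is captured); the dict
# returned above is only metadata for the sidebar status display.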


def get_llm_client(client_type, api_key):
    if client_type == "together":
        return Together(api_key=api_key)
    elif client_type == "nvidia":
        return OpenAI(
            base_url="https://integrate.api.nvidia.com/v1",
            api_key=api_key,
        )
    return None
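
# NVIDIA's integrate.api.nvidia.com endpoint is OpenAI-compatible, so the
# stock OpenAI client can be pointed at it via base_url; only Together needs
# its own SDK here.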


def get_current_model_info():
    if "cached_model_info" in st.session_state and st.session_state.cached_model_info:
        return st.session_state.cached_model_info
    return None


def force_reload_model():
    st.cache_resource.clear()
    if "cached_model_info" in st.session_state:
        del st.session_state.cached_model_info
    if "user_llm_client" in st.session_state:
        del st.session_state.user_llm_client
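
# After force_reload_model() runs, the next rerun finds no cached_model_info
# in session state, so the sidebar falls into its loading branch and recreates
# both the embedder and the LLM client from scratch.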

# Initialize feedback manager (FeedbackManager comes from the local
# feedback_utils module and wraps the Google Sheets connection used below)
if "feedback_manager" not in st.session_state:
    st.session_state.feedback_manager = FeedbackManager()

st.set_page_config(
    page_title="Depression Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="expanded",
)

model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    # "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    # "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    # "Other"
]
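# NOTE: "Other" is commented out of model_options, so the free-text
# model-name branch below (if embedder_name == "Other") is currently
# unreachable.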

# --- Sidebar ---
st.sidebar.title("Settings")
with st.sidebar:
    st.subheader("Model Selection")
    embedder_name = st.sidebar.selectbox(
        "Select embedder model",
        model_options,
        index=0,
    )
    if embedder_name == "Other":
        embedder_name = st.sidebar.text_input("Enter the embedder model name")

    current_info = get_current_model_info()
    if current_info and current_info["embedder_name"] == embedder_name:
        st.success(f"✅ Current model: {embedder_name}")
        st.caption(f"Loaded at: {time.strftime('%H:%M:%S', time.localtime(current_info['timestamp']))}")
    else:
        if embedder_name:
            with st.spinner(f"Loading embedding model: {embedder_name}..."):
                try:
                    model_info = load_embedding_model_cached(embedder_name=embedder_name)
                    st.session_state.cached_model_info = model_info
                    st.success(f"✅ Model {embedder_name} loaded successfully!")
                    st.rerun()
                except Exception as e:
                    st.error(f"❌ Failed to load model: {str(e)}")
                    st.session_state.cached_model_info = None

    if st.button("🔄 Force Reload Model", help="Clear cache and reload the model"):
        force_reload_model()
        st.rerun()
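
# st.rerun() after a successful load restarts the script from the top so the
# next pass takes the cached-model branch and shows the load timestamp.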

selected_model = st.sidebar.selectbox(
    "Choose a model for generation",
    [
        "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        "deepseek-ai/deepseek-r1",
        "meta/llama-3.3-70b-instruct",
    ],
    key="selected_model",
)

# Route each generation model to its hosting provider and API key.
if selected_model in ["deepseek-ai/deepseek-r1", "meta/llama-3.3-70b-instruct"]:
    max_length_default = 1000
    client_type = "nvidia"
    api_key = os.getenv("NVIDIA_API_KEY")
else:
    max_length_default = 500
    client_type = "together"
    api_key = os.getenv("TOGETHER_API_KEY")

# Re-create the client only when the provider/model combination changes.
client_key = f"{client_type}_{selected_model}"
if "user_llm_client" not in st.session_state or st.session_state.get("client_key") != client_key:
    if api_key:
        st.session_state.user_llm_client = get_llm_client(client_type, api_key)
        st.session_state.client_key = client_key
        st.sidebar.success(f"✅ LLM Client: {client_type.upper()}")
    else:
        st.session_state.user_llm_client = None
        st.sidebar.error(f"❌ Missing API key for {client_type.upper()}")
else:
    st.sidebar.info(f"📱 LLM Client: {client_type.upper()} (Ready)")
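
# client_key doubles as the session identifier that is logged to Google Sheets
# and shown (truncated) in the sidebar caption below.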

temperature = st.sidebar.slider("Temperature", min_value=0.01, max_value=1.0, value=0.05, step=0.01)
top_p = st.sidebar.slider("Top P", min_value=0.01, max_value=1.0, value=0.9, step=0.01)
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1000, value=max_length_default, step=10)

st.sidebar.markdown("---")
st.sidebar.markdown("**Current Configuration:**")
st.sidebar.caption(f"Embedder: {embedder_name}")
st.sidebar.caption(f"LLM: {selected_model}")
st.sidebar.caption(f"Client: {client_type.upper()}")
st.sidebar.caption(f"Session ID: {st.session_state.get('client_key', 'None')[:20]}...")

# Google Sheets status
if st.session_state.feedback_manager.is_connected():
    st.sidebar.success("📊 Google Sheets: Connected")
    # Check whether conversation logging is available
    if hasattr(st.session_state.feedback_manager, "conversation_worksheet") and st.session_state.feedback_manager.conversation_worksheet:
        st.sidebar.success("📝 Conversation Logging: Active")
    else:
        st.sidebar.warning("📝 Conversation Logging: Not Available")
else:
    st.sidebar.error("📊 Google Sheets: Not Connected")

# Show title and description
st.title("💬 Depression Assistant Chatbot")

if not get_current_model_info():
    st.warning("⚠️ Please select and load an embedding model from the sidebar first.")
    st.stop()

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = [{
        "role": "assistant",
        "content": "Welcome to a prototype of the open-source and open-weight CANMAT/MDD 2023 depression guideline chatbot. Please try asking it questions that can be answered by the guidelines. Improvements are ongoing - the visual aspect will change substantially soon. Please send any feedback to John-Jose at [johnjose.nunez@ubc.ca](mailto:johnjose.nunez@ubc.ca). Thanks!"
    }]
# Initialize sources tracking (message index -> retrieved passages)
if "message_sources" not in st.session_state:
    st.session_state.message_sources = {}
# Initialize feedback tracking
if "feedback_submitted" not in st.session_state:
    st.session_state.feedback_submitted = set()
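
# Three session-state structures drive the transcript below: messages (the
# chat history), message_sources (retrieval results keyed by message index),
# and feedback_submitted (keys of responses that have already been rated).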

# Display chat messages from history on app rerun
for idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

    # Add feedback section for assistant messages (except the first welcome message)
    if message["role"] == "assistant" and idx > 0:
        # Add sources expander for this message
        if idx in st.session_state.message_sources:
            sources_expander = st.expander("📚 See Sources")
            with sources_expander:
                results = st.session_state.message_sources[idx]
                if results:
                    for i, result in enumerate(results):
                        st.markdown(f"**Source {i+1}:** **Similarity:** {result.get('similarity', 'N/A')}")
                        st.write(f"**TEXT:** {result['text']}")
                        st.markdown(f"**Section:** {result['section']}")
                        st.markdown("---")
                else:
                    st.markdown("No relevant sources found.")
        # Check if feedback was already submitted for this message
        feedback_key = f"feedback_submitted_{idx}"
        if feedback_key not in st.session_state.feedback_submitted:
            # Put feedback in an expander
            feedback_expander = st.expander("📝 Provide Feedback")
            with feedback_expander:
                col1, col2 = st.columns(2)
                with col1:
                    st.markdown("**⭐ Rating Questions:**")
                    # NOTE: the variable names are swapped here so that they
                    # match their question text -- answer_rating rates the
                    # answer, source_rating rates the retrieved source.
                    answer_rating = st.selectbox(
                        "Please rate the answer provided. Higher ratings indicate better quality in the answer.",
                        options=[None, 1, 2, 3, 4, 5],
                        format_func=lambda x: "Select rating..." if x is None else f"{x} {'⭐' * x}",
                        key=f"answer_rating_{idx}",
                    )
                    source_rating = st.selectbox(
                        "Please rate the quality of the data source provided - is it sufficient to answer the question? Higher ratings indicate better quality in the source.",
                        options=[None, 1, 2, 3, 4, 5],
                        format_func=lambda x: "Select rating..." if x is None else f"{x} {'⭐' * x}",
                        key=f"source_rating_{idx}",
                    )
                with col2:
                    # Text feedback questions
                    st.markdown("**📝 Detailed Feedback Questions:**")
                    feedback_q1 = st.text_area(
                        "Why is the answer wrong? Does it miss any key information?",
                        placeholder="Please describe any mistakes or missing information in the response...",
                        key=f"feedback_q1_{idx}",
                        height=80,
                    )

                # Submit feedback button
                if st.button("Submit Feedback", key=f"submit_{idx}"):
                    if source_rating is not None or answer_rating is not None:
                        # Get current model parameters
                        current_embedder = get_current_model_info()["embedder_name"] if get_current_model_info() else "Unknown"
                        # Get the corresponding user query (should be the previous message)
                        user_query = st.session_state.messages[idx - 1]["content"] if idx > 0 else "Unknown"
                        # Save feedback
                        success = st.session_state.feedback_manager.save_feedback(
                            user_query=user_query,
                            ai_response=message["content"],
                            source_rating=source_rating,
                            answer_rating=answer_rating,
                            feedback_q1=feedback_q1 or "",
                            embedder_model=current_embedder,
                            llm_model=getattr(st.session_state, "last_model_used", "Unknown"),
                            temperature=getattr(st.session_state, "last_temperature_used", 0),
                            top_p=getattr(st.session_state, "last_top_p_used", 0),
                            max_length=getattr(st.session_state, "last_max_length_used", 0),
                        )
                        if success:
                            st.success("✅ Thank you for your feedback!")
                            st.session_state.feedback_submitted.add(feedback_key)
                            st.rerun()
                        else:
                            st.error("❌ Failed to save feedback. Please check your Google Sheets configuration.")
                    else:
                        st.warning("⚠️ Please select at least one rating before submitting feedback.")
        else:
            st.success("✅ Feedback already submitted for this response")

# User input
if user_input := st.chat_input("Ask me questions about the CANMAT depression guideline!"):
    if not get_current_model_info():
        st.error("❌ Please load an embedding model first from the sidebar.")
        st.stop()
    if not st.session_state.get("user_llm_client"):
        st.error("❌ LLM client not available. Please check your API keys.")
        st.stop()

    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Store current model parameters for feedback
    st.session_state.last_model_used = selected_model
    st.session_state.last_temperature_used = temperature
    st.session_state.last_top_p_used = top_p
    st.session_state.last_max_length_used = max_length

    # ===== Keep the latest 10 messages (5 rounds) of chat history =====
    history = st.session_state.messages[:-1][-10:]
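
    # Example: with 13 stored messages, messages[:-1] drops the user turn just
    # appended above and [-10:] keeps the 10 messages before it, so the model
    # sees the current question once (as user_input) plus five prior rounds.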

    placeholder = st.chat_message("assistant").empty()
    collected = ""
    try:
        t0 = time.perf_counter()

        # Temporarily point the Rag module's global client at this session's
        # client, restoring it afterwards.
        import Rag
        original_client = Rag.llm_client
        Rag.llm_client = st.session_state.user_llm_client
        try:
            results, response = depression_assistant(
                user_input,
                model_name=selected_model,
                max_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                stream_flag=True,
                chat_history=history,
            )
            # Stream the generator output into the placeholder chunk by chunk.
            for chunk in response:
                collected += chunk
                placeholder.markdown(collected)
        finally:
            Rag.llm_client = original_client
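
        # NOTE: swapping a module-level global is not thread-safe; two sessions
        # generating at the same time could briefly see each other's client.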
        t1 = time.perf_counter()
        response_time = t1 - t0
        print(f"[Time] Retriever + Generator takes: {response_time:.2f} seconds in total.")
        print(f"============== Finish R-A-Generation for Current Query {user_input} ==============")

        # Save the response and sources
        st.session_state.messages.append({"role": "assistant", "content": collected})
        # Save sources for this message (it sits at index len(messages) - 1)
        message_idx = len(st.session_state.messages) - 1
        st.session_state.message_sources[message_idx] = results

        # Log conversation to Google Sheets
        try:
            current_embedder = get_current_model_info()["embedder_name"] if get_current_model_info() else "Unknown"
            session_id = st.session_state.get("client_key", "Unknown")
            st.session_state.feedback_manager.log_conversation(
                session_id=session_id,
                user_query=user_input,
                ai_response=collected,
                embedder_model=current_embedder,
                llm_model=selected_model,
                temperature=temperature,
                top_p=top_p,
                max_length=max_length,
                response_time=response_time,
            )
        except Exception as log_error:
            print(f"Warning: Failed to log conversation to Google Sheets: {log_error}")
    except Exception as e:
        st.error(f"❌ Error generating response: {str(e)}")
        print(f"Error in main loop: {e}")
        import traceback
        traceback.print_exc()
    else:
        # st.rerun() works by raising a control-flow exception, which the broad
        # except Exception above would swallow if it were called inside the try
        # block; the else clause lets it propagate.
        st.rerun()
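
# To try this locally (assuming the file is saved as app.py next to Rag.py and
# feedback_utils.py, with a .env supplying the API keys):
#   streamlit run app.py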