import streamlit as st
from Rag import launch_depression_assistant, depression_assistant
from openai import OpenAI
from together import Together
import time
import os
from dotenv import load_dotenv
from feedback_utils import FeedbackManager
load_dotenv()
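

# st.cache_resource shares one loaded embedding model across Streamlit reruns
# and sessions; only a new embedder_name triggers a fresh load.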
@st.cache_resource
def load_embedding_model_cached(embedder_name):
    print(f"πŸ”„ Loading cached embedding model: {embedder_name}")
    launch_depression_assistant(embedder_name=embedder_name, designated_client=None)
    print(f"βœ… Cached embedding model loaded successfully: {embedder_name}")
    return {
        "embedder_name": embedder_name,
        "status": "loaded",
        "timestamp": time.time(),
    }
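

# NVIDIA's integrate.api.nvidia.com endpoint is OpenAI-compatible, so the
# OpenAI SDK is reused with a custom base_url; Together ships its own client.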
def get_llm_client(client_type, api_key):
    if client_type == "together":
        return Together(api_key=api_key)
    elif client_type == "nvidia":
        return OpenAI(
            base_url="https://integrate.api.nvidia.com/v1",
            api_key=api_key,
        )
    return None


def get_current_model_info():
    if "cached_model_info" in st.session_state and st.session_state.cached_model_info:
        return st.session_state.cached_model_info
    return None
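

# st.cache_resource.clear() wipes every cached resource in the app, not just
# the embedder; dropping the session keys makes the sidebar logic below
# reload the model and rebuild the LLM client on the next rerun.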
def force_reload_model():
    st.cache_resource.clear()
    if "cached_model_info" in st.session_state:
        del st.session_state.cached_model_info
    if "user_llm_client" in st.session_state:
        del st.session_state.user_llm_client


# Initialize feedback manager
if "feedback_manager" not in st.session_state:
    st.session_state.feedback_manager = FeedbackManager()

st.set_page_config(
    page_title="Depression Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="expanded",
)

model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    # "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    # "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    # "Other"
]
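
# "Other" is commented out of model_options above, so the free-text entry
# branch in the sidebar below is currently inactive.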

# --- Sidebar ---
st.sidebar.title("Settings")

with st.sidebar:
    st.subheader("Model Selection")
    embedder_name = st.selectbox(
        "Select embedder model",
        model_options,
        index=0,
    )
    if embedder_name == "Other":
        embedder_name = st.text_input('Enter the embedder model name')

    current_info = get_current_model_info()
    if current_info and current_info["embedder_name"] == embedder_name:
        st.success(f"βœ… Current model: {embedder_name}")
        st.caption(f"Loaded at: {time.strftime('%H:%M:%S', time.localtime(current_info['timestamp']))}")
    elif embedder_name:
        with st.spinner(f"Loading embedding model: {embedder_name}..."):
            try:
                model_info = load_embedding_model_cached(embedder_name=embedder_name)
                st.session_state.cached_model_info = model_info
                st.success(f"βœ… Model {embedder_name} loaded successfully!")
                st.rerun()
            except Exception as e:
                st.error(f"❌ Failed to load model: {str(e)}")
                st.session_state.cached_model_info = None

    if st.button("πŸ”„ Force Reload Model", help="Clear cache and reload the model"):
        force_reload_model()
        st.rerun()

selected_model = st.sidebar.selectbox(
    'Choose a model for generation',
    [
        "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
        "deepseek-ai/deepseek-r1",
        "meta/llama-3.3-70b-instruct",
    ],
    key='selected_model',
)

# The DeepSeek and meta/llama models are served via NVIDIA; the Turbo-Free
# model via Together. Pick the matching API key and output budget.
if selected_model in ["deepseek-ai/deepseek-r1", "meta/llama-3.3-70b-instruct"]:
    max_length_default = 1000
    client_type = "nvidia"
    api_key = os.getenv("NVIDIA_API_KEY")
else:
    max_length_default = 500
    client_type = "together"
    api_key = os.getenv("TOGETHER_API_KEY")
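
# Cache one client per (provider, model) pair so switching models in the
# sidebar rebuilds the client only when the provider or model changes.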
client_key = f"{client_type}_{selected_model}"
if "user_llm_client" not in st.session_state or st.session_state.get("client_key") != client_key:
    if api_key:
        st.session_state.user_llm_client = get_llm_client(client_type, api_key)
        st.session_state.client_key = client_key
        st.sidebar.success(f"βœ… LLM Client: {client_type.upper()}")
    else:
        st.session_state.user_llm_client = None
        st.sidebar.error(f"❌ Missing API key for {client_type.upper()}")
else:
    st.sidebar.info(f"πŸ“± LLM Client: {client_type.upper()} (Ready)")
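
# Sampling controls; the very low default temperature keeps answers nearly
# deterministic.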
temperature = st.sidebar.slider('Temperature', min_value=0.01, max_value=1.0, value=0.05, step=0.01)
top_p = st.sidebar.slider('Top P', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
max_length = st.sidebar.slider('Max Length', min_value=100, max_value=1000, value=max_length_default, step=10)
st.sidebar.markdown("---")
st.sidebar.markdown("**Current Configuration:**")
st.sidebar.caption(f"Embedder: {embedder_name}")
st.sidebar.caption(f"LLM: {selected_model}")
st.sidebar.caption(f"Client: {client_type.upper()}")
st.sidebar.caption(f"Session ID: {st.session_state.get('client_key', 'None')[:20]}...")

# Google Sheets status
if st.session_state.feedback_manager.is_connected():
    st.sidebar.success("πŸ“Š Google Sheets: Connected")
    # Conversation logging requires its own worksheet; check that it exists
    if getattr(st.session_state.feedback_manager, 'conversation_worksheet', None):
        st.sidebar.success("πŸ“ Conversation Logging: Active")
    else:
        st.sidebar.warning("πŸ“ Conversation Logging: Not Available")
else:
    st.sidebar.error("πŸ“Š Google Sheets: Not Connected")

# Show title and description
st.title("πŸ’¬ Depression Assistant Chatbot")

if not get_current_model_info():
    st.warning("⚠️ Please select and load an embedding model from the sidebar first.")
    st.stop()

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = [{
        "role": "assistant",
        "content": "Welcome to a prototype of the open-source, open-weight chatbot for the CANMAT 2023 depression (MDD) guidelines. Please try asking questions that the guidelines can answer. Improvements are ongoing - the visual design will change substantially soon. Please send John-Jose any feedback at [johnjose.nunez@ubc.ca](mailto:johnjose.nunez@ubc.ca). Thanks!"
    }]
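
# st.session_state persists across reruns within a browser session, so the
# chat history, retrieved sources, and feedback flags below survive widget
# interactions.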

# Initialize sources tracking
if "message_sources" not in st.session_state:
    st.session_state.message_sources = {}

# Initialize feedback tracking
if "feedback_submitted" not in st.session_state:
    st.session_state.feedback_submitted = set()

# Display chat messages from history on app rerun
for idx, message in enumerate(st.session_state.messages):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        # Sources and feedback apply to assistant messages only (skip the welcome message)
        if message["role"] == "assistant" and idx > 0:
            # Sources expander for this message
            if idx in st.session_state.message_sources:
                with st.expander("πŸ“š See Sources"):
                    results = st.session_state.message_sources[idx]
                    if results:
                        for i, result in enumerate(results):
                            st.markdown(f"**Source {i+1}:** **Similarity:** {result.get('similarity', 'N/A')}")
                            st.write(f"**Text:** {result['text']}")
                            st.markdown(f"**Section:** {result['section']}")
                            st.markdown("---")
                    else:
                        st.markdown("No relevant sources found.")
            # Show the feedback form unless feedback was already submitted for this message
            feedback_key = f"feedback_submitted_{idx}"
            if feedback_key not in st.session_state.feedback_submitted:
                with st.expander("πŸ“ Provide Feedback"):
                    col1, col2 = st.columns(2)
                    with col1:
                        st.markdown("**⭐ Rating Questions:**")
                        answer_rating = st.selectbox(
                            "Please rate the answer provided. Higher ratings indicate a better answer.",
                            options=[None, 1, 2, 3, 4, 5],
                            format_func=lambda x: "Select rating..." if x is None else f"{x} {'⭐' * x}",
                            key=f"answer_rating_{idx}",
                        )
                        source_rating = st.selectbox(
                            "Please rate the quality of the data source provided - is it sufficient to answer the question? Higher ratings indicate a better source.",
                            options=[None, 1, 2, 3, 4, 5],
                            format_func=lambda x: "Select rating..." if x is None else f"{x} {'⭐' * x}",
                            key=f"source_rating_{idx}",
                        )
                    with col2:
                        st.markdown("**πŸ“ Detailed Feedback Questions:**")
                        feedback_q1 = st.text_area(
                            "If the answer is wrong, why? Does it miss any key information?",
                            placeholder="Please describe any mistakes or missing information in the response...",
                            key=f"feedback_q1_{idx}",
                            height=80,
                        )
                    # Submit feedback
                    if st.button("Submit Feedback", key=f"submit_{idx}"):
                        if source_rating is not None or answer_rating is not None:
                            # Record the embedder that produced this retrieval
                            info = get_current_model_info()
                            current_embedder = info["embedder_name"] if info else "Unknown"
                            # The matching user query is the message just before this one
                            user_query = st.session_state.messages[idx - 1]["content"] if idx > 0 else "Unknown"
                            # Save feedback
                            success = st.session_state.feedback_manager.save_feedback(
                                user_query=user_query,
                                ai_response=message["content"],
                                source_rating=source_rating,
                                answer_rating=answer_rating,
                                feedback_q1=feedback_q1 or "",
                                embedder_model=current_embedder,
                                llm_model=st.session_state.get('last_model_used', 'Unknown'),
                                temperature=st.session_state.get('last_temperature_used', 0),
                                top_p=st.session_state.get('last_top_p_used', 0),
                                max_length=st.session_state.get('last_max_length_used', 0),
                            )
                            if success:
                                st.success("βœ… Thank you for your feedback!")
                                st.session_state.feedback_submitted.add(feedback_key)
                                st.rerun()
                            else:
                                st.error("❌ Failed to save feedback. Please check your Google Sheets configuration.")
                        else:
                            st.warning("⚠️ Please select at least one rating before submitting feedback.")
            else:
                st.success("βœ… Feedback already submitted for this response")
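
# Main input flow: retrieve guideline passages for the new query, stream the
# generated answer, then persist the turn so it appears in the history above.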
if user_input := st.chat_input("Ask me questions about the CANMAT depression guideline!"):
    if not get_current_model_info():
        st.error("❌ Please load an embedding model first from the sidebar.")
        st.stop()
    if not st.session_state.get("user_llm_client"):
        st.error("❌ LLM client not available. Please check your API keys.")
        st.stop()

    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Store the current model parameters so later feedback can reference them
    st.session_state.last_model_used = selected_model
    st.session_state.last_temperature_used = temperature
    st.session_state.last_top_p_used = top_p
    st.session_state.last_max_length_used = max_length

    # Pass the last 10 messages (5 user/assistant rounds), excluding the prompt just appended
    history = st.session_state.messages[:-1][-10:]

    placeholder = st.chat_message("assistant").empty()
    collected = ""
    try:
        t0 = time.perf_counter()

        # Temporarily point the Rag module at this session's LLM client and
        # restore the original afterwards, even if generation fails.
        import Rag
        original_client = Rag.llm_client
        Rag.llm_client = st.session_state.user_llm_client
        try:
            results, response = depression_assistant(
                user_input,
                model_name=selected_model,
                max_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                stream_flag=True,
                chat_history=history,
            )
            # Stream the answer chunk by chunk into the placeholder
            for chunk in response:
                collected += chunk
                placeholder.markdown(collected)
        finally:
            Rag.llm_client = original_client

        t1 = time.perf_counter()
        response_time = t1 - t0
        print(f"[Time] Retriever + generator took {response_time:.2f} seconds in total.")
        print(f"============== Finished RAG generation for query: {user_input} ==============")

        # Save the response and its sources
        st.session_state.messages.append({"role": "assistant", "content": collected})
        # Sources for this message live at its index in the history
        message_idx = len(st.session_state.messages) - 1
        st.session_state.message_sources[message_idx] = results

        # Log the conversation to Google Sheets; a logging failure must not break the chat
        try:
            info = get_current_model_info()
            current_embedder = info["embedder_name"] if info else "Unknown"
            session_id = st.session_state.get('client_key', 'Unknown')
            st.session_state.feedback_manager.log_conversation(
                session_id=session_id,
                user_query=user_input,
                ai_response=collected,
                embedder_model=current_embedder,
                llm_model=selected_model,
                temperature=temperature,
                top_p=top_p,
                max_length=max_length,
                response_time=response_time,
            )
        except Exception as log_error:
            print(f"Warning: failed to log conversation to Google Sheets: {log_error}")

        st.rerun()
    except Exception as e:
        st.error(f"❌ Error generating response: {str(e)}")
        print(f"Error in main loop: {e}")
        import traceback
        traceback.print_exc()