import json
import os
import threading
import time

import faiss
import numpy as np
import requests
import torch
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer, models
from together import Together

# Global state, populated by launch_depression_assistant()
db = None
referenced_tables_db = None
embedder = None
index = None
llm_client = None

_model_lock = threading.Lock()
def load_json_to_db(file_path):
    """Load a JSON chunk database from the given file path."""
    with open(file_path) as f:
        db = json.load(f)
    return db
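# For reference, each chunk in the guideline database is expected to look
# roughly like the sketch below. This is inferred from how the fields are
# accessed elsewhere in this module; the authoritative schema lives in the
# preprocessing step that produces guideline_db.json.
#
# {
#     "text": "Paragraph or table text...",
#     "metadata": {
#         "section": "Section title",
#         "chunk_id": "unique-chunk-id",
#         "referee_id": None,            # set when the chunk itself is a table
#         "referenced_tables": ["t1"],   # optional: ids of tables this chunk cites
#     },
# }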
#------Embedding and FAISS Indexing Functions------
def make_embeddings(embedder, embedder_name, db):
    """
    Make embeddings for the given database of chunks.

    Note: embedder_name is kept for interface symmetry with the save/load
    helpers but is not used here; the already-loaded embedder does the encoding.
    """
    texts = [chunk['text'] for chunk in db]
    embeddings = embedder.encode(texts, convert_to_numpy=True)
    return embeddings
def save_embeddings(embedder_name, embeddings):
"""
Save embeddings to a .npy file.
"""
file_path = os.path.join("data", "embeddings", f"{embedder_name.replace('/', '_')}.npy")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
np.save(file_path, embeddings)
print(f"Saved embeddings for {embedder_name}...")
def load_embeddings(embedder_name):
"""
Load embeddings from a .npy file.
"""
global embedder, db
# if embeddings already exist, load them, else make new embeddings
try:
file_path = os.path.join("data", "embeddings", f"{embedder_name.replace('/', '_')}.npy")
embeddings = np.load(file_path, allow_pickle=True)
print(f"Embeddings for {embedder_name} already exist. Loading them...")
except FileNotFoundError:
print(f"Embeddings for {embedder_name} not found. Making new embeddings...")
# print the current runtime files
print(f"Current runtime files: {os.listdir('.')}")
        embeddings = make_embeddings(embedder, embedder_name, db)
save_embeddings(embedder_name, embeddings)
return embeddings
def load_embedder_with_fallbacks(embedder_name):
    """
    Try loading a SentenceTransformer model with multiple fallback strategies.
    Returns the loaded model if successful. Raises RuntimeError if all strategies fail.
    """
    print("=========Entering load_embedder_with_fallbacks()=========")

    def get_best_device():
        if torch.cuda.is_available():
            return torch.device("cuda")  # NVIDIA GPU
        return torch.device("cpu")

    device = get_best_device()
    print(f"Using device: {device}")
    strategies = [
        {"trust_remote_code": False, "device": device, "description": "default sentence transformer", "class": "SentenceTransformer"},
        {"trust_remote_code": True, "device": device, "description": "sentence transformer with trust_remote_code=True", "class": "SentenceTransformer"},
        {"description": "manual Transformer + Pooling modules wrapped in SentenceTransformer", "class": "Manual"},
    ]
for i, strategy in enumerate(strategies):
try:
print(f"[Attempt {i+1}] Loading embedder '{embedder_name}' with {strategy['description']}")
if strategy["class"] == "SentenceTransformer":
kwargs = {}
if strategy.get("trust_remote_code"):
kwargs["trust_remote_code"] = True
if strategy.get("device"):
kwargs["device"] = strategy["device"]
print(f"Using device: {strategy['device']}")
model = SentenceTransformer(embedder_name, **kwargs)
elif strategy["class"] == "Manual":
word_embedding_model = models.Transformer(embedder_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
print(f"=========[Success] Loaded embedder with strategy: {strategy['description']} and Exit=========")
return model
except Exception as e:
print(f"=========[Failure] '{strategy['description']}' failed: {e}=========")
raise RuntimeError(f"=========All strategies failed to load embedder '{embedder_name}'=========")
# --------------Faiss index functions-------------------
def build_faiss_cosine_similarity_index(embeddings):
    """
    Build a FAISS index using cosine similarity (via normalized inner product).
    """
    print("Building FAISS index (cosine similarity)...")
    # FAISS requires float32, C-contiguous input
    embeddings = np.ascontiguousarray(embeddings, dtype="float32")
    # Step 1: Normalize embeddings in place to unit vectors (L2 norm = 1)
    faiss.normalize_L2(embeddings)
    # Step 2: Use an inner-product index (dot product == cosine after normalization)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index
def load_faiss_index(embedder_name):
    """
    Load the FAISS index from a file, building and saving it if it does not exist.
    """
    index_file = os.path.join("data", "faiss_index", f"{embedder_name.replace('/', '_')}_index.faiss")
    if os.path.exists(index_file):
        print(f"FAISS index for {embedder_name} already exists. Loading it...")
        index = faiss.read_index(index_file)
        print(f"Loaded FAISS index from {index_file}.")
    else:
        print(f"FAISS index for {embedder_name} not found. Building new index...")
        embeddings = load_embeddings(embedder_name)
        index = build_faiss_cosine_similarity_index(embeddings)
        save_faiss_index(embedder_name, index)
        print(f"FAISS index for {embedder_name} built and saved.")
    return index
def save_faiss_index(embedder_name, index):
"""
Save the FAISS index to a file.
"""
index_file = f"{embedder_name.replace('/', '_')}_index.faiss"
index_file = os.path.join("data", "faiss_index", index_file)
    # Ensure the directory exists
os.makedirs(os.path.dirname(index_file), exist_ok=True)
faiss.write_index(index, index_file)
print(f"Saved FAISS index to {index_file}.")
# ---------------------------------
def faiss_search(query, embedder, db, index, referenced_table_db, k=6):
    """
    Search for relevant chunks in the database using FAISS.

    Args:
        query (str): user query to search for in the database
        embedder (SentenceTransformer): loaded in launch_depression_assistant(), used to encode the query
        db (list of dict): guideline database, a list of chunks with metadata, loaded from a JSON file
        index (faiss.Index): built from the embeddings of db, used to answer the query
        referenced_table_db (list of dict): table chunks, used to append any tables referenced by the retrieved chunks
        k (int, optional): number of documents to retrieve. Defaults to 6.

    Returns:
        list of dict: the top-k chunks most relevant to the query, each with keys
        text, section, chunk_id, and similarity, plus any referenced table chunks
        (appended without a similarity score).
    """
    query_embedding = embedder.encode([query], convert_to_numpy=True)
    # Normalize the query as well so inner-product scores are true cosine similarities
    query_embedding = np.ascontiguousarray(query_embedding, dtype="float32")
    faiss.normalize_L2(query_embedding)
    distances, indices = index.search(query_embedding, k)
    results = []
    referenced_tables = set()
    existed_tables = set()
    for i in range(k):
        idx = indices[0][i]
        if idx == -1:  # FAISS pads missing results with -1
            continue
        similarity = float(distances[0][i])
        # Only include results with similarity >= 0.4
        if similarity < 0.4:
            continue
        chunk = db[idx]
        results.append({
            "text": chunk['text'],
            "section": chunk['metadata']['section'],
            "chunk_id": chunk['metadata']['chunk_id'],
            # return the similarity score as well
            "similarity": similarity,
        })
        # If this chunk has a referee_id, it is itself a table, so it does not
        # need to be added again later
        if chunk['metadata']['referee_id']:
            existed_tables.add(chunk['metadata']['referee_id'])
        # Collect the tables this chunk references (the key may be absent)
        if chunk['metadata'].get('referenced_tables'):
            referenced_tables.update(chunk['metadata']['referenced_tables'])
# existed_tables = tables that already exist in the retrieved results
# referenced_tables = tables that are referenced by chunks in the retrieved results
# table_to_add = referenced_tables - existed_tables
table_to_add = [table for table in referenced_tables if table not in existed_tables]
print(f"existed tables: {existed_tables}")
print(f"referenced tables: {referenced_tables}")
print(f"Tables to add: {table_to_add}")
    # Append the table chunks from referenced_table_db whose referee_id is in table_to_add
    if table_to_add:
        added = 0
        for chunk in referenced_table_db:
            if chunk['metadata']['referee_id'] in table_to_add:
                results.append({
                    "text": chunk['text'],
                    "section": chunk['metadata']['section'],
                    "chunk_id": chunk['metadata']['chunk_id'],
                })
                added += 1
                if added == len(table_to_add):
                    break
return results
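# Illustrative shape of faiss_search() output (the values are made up):
# [
#     {"text": "...", "section": "Treatment", "chunk_id": "c12", "similarity": 0.63},
#     {"text": "...", "section": "Tables", "chunk_id": "t03"},  # appended referenced table, no score
# ]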
def load_together_llm_client():
"""
Load the Together LLM client with the provided API key.
"""
load_dotenv() # Load environment variables from .env file
return Together(api_key=os.getenv("TOGETHER_API_KEY"))
# ---------- Prompt ----------
def construct_prompt(query, faiss_results):
    """Build a single-turn prompt from the system prompt, user query, and retrieved context."""
    with open("src/system_prompt.txt", "r") as f:
        system_prompt = f.read().strip()
    prompt = f"""### System Prompt
{system_prompt}

### User Query
{query}

### Clinical Guidelines Context
"""
    for res in faiss_results:
        prompt += f"- reference section: {res['section']}\n- paragraph text: {res['text']}\n"
    return prompt
# ===== new feature: memory =====
def construct_prompt_with_memory(query, faiss_results, chat_history=None, history_limit=4):
print("=============Constructing prompt with memory===========")
with open("src/system_prompt.txt", "r") as f:
system_prompt = f.read().strip()
prompt = f"### System Prompt\n{system_prompt}\n\n"
if chat_history:
prompt += "### Chat History\n"
for m in chat_history[-history_limit:]:
prompt += f"{m['role'].title()}: {m['content']}\n"
prompt += "\n"
prompt += f"### User Query\n{query}\n\n"
prompt += "### Clinical Guidelines Context\n"
    for res in faiss_results:
        prompt += f"- reference section: {res['section']}\n- paragraph text: {res['text']}\n"
return prompt
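# chat_history is expected to be a list of role/content message dicts, e.g.
# (an illustrative sketch, not a fixed schema):
#
# chat_history = [
#     {"role": "user", "content": "What are first-line treatments?"},
#     {"role": "assistant", "content": "According to the guideline..."},
# ]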
def call_llm(llm_client, prompt, stream_flag=False, max_tokens=500, temperature=0.05, top_p=0.9, model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"):
print(f"Calling LLM with model: {model_name}")
print(f"With parameters: max_tokens={max_tokens}, temperature={temperature}, top_p={top_p}")
try:
if stream_flag:
# For streaming mode, return a generator
def stream_generator():
response = llm_client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=True,
)
print("Streaming response received from API")
for chunk in response:
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield content
return stream_generator()
else:
# For non-streaming mode, return content directly
response = llm_client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=False,
)
content = response.choices[0].message.content
return content
except Exception as e:
print("Error in call_llm:", str(e))
print("Error type:", type(e))
import traceback
traceback.print_exc()
raise
def launch_depression_assistant(embedder_name, designated_client=None):
    """
    Launch the depression assistant: 1. load the databases, 2. load the
    embedding model and FAISS index, 3. load the LLM client.
    """
global db, referenced_tables_db, embedder, index, llm_client
db = load_json_to_db("data/processed/guideline_db.json")
referenced_tables_db = load_json_to_db("data/processed/referenced_table_chunks.json")
t0 = time.perf_counter()
embedder = load_embedder_with_fallbacks(embedder_name)
t1 = time.perf_counter()
print(f"[Time] Embedding model loaded in {t1 - t0:.2f} seconds.")
index = load_faiss_index(embedder_name)
t2 = time.perf_counter()
print(f"[Time] FAISS index loaded in {t2 - t1:.2f} seconds.")
if designated_client is None:
print("No LLM client provided. Loading Together LLM client...")
try:
llm_client = load_together_llm_client()
        except Exception as e:
            print(f"------------Failed to load Together LLM client ({e}). This might be related to user access. Please manually configure your LLM client API key.------------")
else:
print("------------Using provided LLM client.------------")
llm_client = designated_client
t3 = time.perf_counter()
print(f"[Time] LLM client initiated in {t3 - t2:.2f} seconds.")
print("---------Depression Assistant is ready to use!--------------\n\n")
def depression_assistant(query, model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free", max_tokens=500, temperature=0.05, top_p=0.9, stream_flag=False, chat_history=None):
print(f"=========Entering depression_assistant with query: {query}=========")
global db, referenced_tables_db, embedder, index, llm_client
t1 = time.perf_counter()
results = faiss_search(query, embedder, db, index, referenced_tables_db)
t2 = time.perf_counter()
print(f"[Time] FAISS search done in {t2 - t1:.2f} seconds.")
prompt = construct_prompt_with_memory(query, results, chat_history=chat_history)
    # llm_client may be the string sentinel "Run Ollama Locally" (set by the caller)
    # instead of an API client object; route to the local Ollama helper in that case.
    if llm_client == "Run Ollama Locally":
        print(f"Running Ollama locally with model: {model_name}. Make sure you have enough memory to run the model.")
        response = call_ollama(prompt, model_name, stream_flag, max_tokens=max_tokens, temperature=temperature, top_p=top_p)
else:
response = call_llm(llm_client, prompt, stream_flag, max_tokens=max_tokens, temperature=temperature, top_p=top_p, model_name=model_name)
return results, response
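# call_ollama is used above but was not defined in this module. Below is a
# minimal sketch, assuming a local Ollama server on its default port and the
# standard /api/generate endpoint; the URL and option names follow Ollama's
# public REST API, but adapt them to your setup.
def call_ollama(prompt, model_name, stream_flag=False, max_tokens=500, temperature=0.05, top_p=0.9):
    """Call a locally running Ollama model via its REST API (sketch)."""
    url = "http://localhost:11434/api/generate"
    payload = {
        "model": model_name,
        "prompt": prompt,
        "stream": stream_flag,
        # Ollama exposes max_tokens as the num_predict option
        "options": {"num_predict": max_tokens, "temperature": temperature, "top_p": top_p},
    }
    if stream_flag:
        def stream_generator():
            with requests.post(url, json=payload, stream=True) as r:
                r.raise_for_status()
                # Ollama streams newline-delimited JSON objects
                for line in r.iter_lines():
                    if line:
                        piece = json.loads(line)
                        if piece.get("response"):
                            yield piece["response"]
        return stream_generator()
    r = requests.post(url, json=payload)
    r.raise_for_status()
    return r.json()["response"]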
def load_queries_and_answers(query_file, answers_file):
    """
    Load queries and answers from the provided files, one item per line.
    """
    with open(query_file, 'r') as f:
        queries = [line.rstrip("\n") for line in f]
    with open(answers_file, 'r') as f:
        answers = [line.rstrip("\n") for line in f]
    return queries, answers
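# Minimal usage sketch. It assumes the data files referenced above exist and a
# TOGETHER_API_KEY is set in .env; the embedder name is illustrative only.
if __name__ == "__main__":
    launch_depression_assistant("sentence-transformers/all-MiniLM-L6-v2")
    chunks, answer = depression_assistant("What are first-line treatments for depression?")
    print(answer)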