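"""Gradio chat app: Azure OpenAI chat completions behind a language-aware
semantic cache in Redis Cloud.

Each user gets one Redis hash of previous Q/A pairs, keyed by
(language tag, question). Incoming questions are embedded with a
multilingual sentence-transformer; a sufficiently similar cached question
is answered from Redis instead of calling the model.

Approximate dependencies (inferred from the imports, not pinned by the source):
    pip install gradio redis numpy openai sentence-transformers
"""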

import os
import re
import json
from datetime import timedelta

import gradio as gr
import numpy as np
import redis
from openai import AzureOpenAI
from sentence_transformers import SentenceTransformer

# -----------------------
# Configuration
# -----------------------
REDIS_HOST = "redis-14417.c13.us-east-1-3.ec2.cloud.redislabs.com"
REDIS_PORT = 14417
REDIS_USER = "default"
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")

AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "").strip()
AZURE_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "").strip()
AZURE_API_VERSION = "2025-01-01-preview"
CHAT_DEPLOYMENT = "gpt-4.1"

# Cache TTL (2 days)
CACHE_TTL = int(timedelta(days=2).total_seconds())

# Matching thresholds
PRIMARY_THRESHOLD = 0.90   # for same-language matches
FALLBACK_THRESHOLD = 0.95  # for language-agnostic fallback (very strict)

# -----------------------
# Clients / Models
# -----------------------
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    decode_responses=True,
    username=REDIS_USER,
    password=REDIS_PASSWORD,
)

client = AzureOpenAI(
    api_key=AZURE_API_KEY,
    api_version=AZURE_API_VERSION,
    azure_endpoint=AZURE_ENDPOINT,
)

# Embedding model (multilingual, small & strong)
embedder = SentenceTransformer("intfloat/multilingual-e5-small")
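
# Note: E5-family models are typically used with "query: " / "passage: "
# prefixes; this app instead prepends an upper-cased language tag (see
# build_embedding_input below), the intent being that "JAVA: ..." and
# "PYTHON: ..." versions of the same question embed further apart. Both
# stored and query texts get the same treatment, so comparisons stay
# consistent.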

# -----------------------
# Helpers
# -----------------------
def detect_language_tag(text: str):
    """Return a programming-language tag string (lowercase) or None."""
    t = text.lower()
    patterns = [
        (r'\bjava\b', "java"),
        (r'\bpython\b', "python"),
        (r'\b(c\+\+|cpp)\b', "cpp"),
        # No trailing \b after '#': a word boundary never matches between
        # '#' and whitespace/end-of-string, so r'\bc#\b' would miss "c#".
        (r'\bc#|\bcsharp\b', "csharp"),
        (r'\bjavascript\b|\bjs\b', "javascript"),
        (r'\b(go|golang)\b', "go"),
        (r'\bruby\b', "ruby"),
        (r'\bphp\b', "php"),
        (r'\bscala\b', "scala"),
        (r'\br\b', "r"),
        # C detection is tricky; look for "in c", "c language", or standalone " c "
        (r'\bin c\b|\bc language\b|\b c \b', "c"),
    ]
    for pat, tag in patterns:
        if re.search(pat, t):
            return tag
    return None
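
# Quick sanity examples (illustrative):
#   detect_language_tag("bubble sort in java")    -> "java"
#   detect_language_tag("reverse a string in c#") -> "csharp"
#   detect_language_tag("what is a monad?")       -> None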

def build_embedding_input(text: str, lang_tag: str | None):
    """Create the text to embed: include a language-tag prefix if present."""
    if lang_tag:
        return f"{lang_tag.upper()}: {text}"
    return text

def get_embedding(text: str) -> np.ndarray:
    vec = embedder.encode(text, convert_to_numpy=True)
    return vec.astype(np.float32)

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    # Safeguard against zero vectors
    n1 = np.linalg.norm(vec1)
    n2 = np.linalg.norm(vec2)
    if n1 == 0 or n2 == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / (n1 * n2))

# -----------------------
# Cache functions
# -----------------------
def store_cache(user_id: str, user_input: str, output: str):
    """Embed the (language-tagged) input and store it alongside the answer."""
    lang = detect_language_tag(user_input)
    embed_text = build_embedding_input(user_input, lang)
    vec = get_embedding(embed_text).tolist()

    cache_key = f"cache:{user_id}"
    store_key = f"{lang}:{user_input}" if lang else user_input
    payload = {
        "orig": user_input,
        "embedding": vec,
        "output": output,
        "lang": lang,
    }
    redis_client.hset(cache_key, store_key, json.dumps(payload))
    redis_client.expire(cache_key, CACHE_TTL)
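
# Resulting Redis layout (illustrative; "alice" is a made-up username).
# One hash per user, TTL refreshed to 2 days on every write:
#
#   cache:alice
#     field "java:bubble sort in java"
#     value '{"orig": "...", "embedding": [...], "output": "...", "lang": "java"}'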

def search_cache(user_id: str, user_input: str,
                 primary_threshold=PRIMARY_THRESHOLD,
                 fallback_threshold=FALLBACK_THRESHOLD):
    """Return a cached answer for a semantically similar question, or None."""
    cache_key = f"cache:{user_id}"
    entries = redis_client.hgetall(cache_key)
    if not entries:
        return None

    # Detect language and build the embedding with the same prefix logic
    detected_lang = detect_language_tag(user_input)
    query_embed_text = build_embedding_input(user_input, detected_lang)
    query_vec = get_embedding(query_embed_text)

    def best_match(predicate):
        """Best (score, output) over cached entries satisfying `predicate`."""
        best_score, best_output = -1.0, None
        for val in entries.values():
            entry = json.loads(val)
            if not predicate(entry):
                continue
            vec = np.array(entry["embedding"], dtype=np.float32)
            score = cosine_similarity(query_vec, vec)
            if score > best_score:
                best_score, best_output = score, entry["output"]
        return best_score, best_output

    # 1) Same-language matches (if a language was detected)
    if detected_lang:
        score, output = best_match(lambda e: e.get("lang") == detected_lang)
        if score >= primary_threshold:
            return output

    # 2) Language-agnostic entries (lang is None)
    score, output = best_match(lambda e: e.get("lang") is None)
    if score >= fallback_threshold:
        return output

    # 3) Final fallback: any language, but require very high similarity
    score, output = best_match(lambda e: True)
    if score >= fallback_threshold:
        return output

    return None
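
# Intended effect of the cascade above: a repeat of "bubble sort in java"
# is served from the same-language tier at the 0.90 bar, while a cached
# java answer can only be reused for a python-tagged question if it clears
# the stricter 0.95 cross-language bar.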

def clear_user_cache(user_id: str):
    redis_client.delete(f"cache:{user_id}")

def view_user_cache(user_id: str):
    cache_key = f"cache:{user_id}"
    entries = redis_client.hgetall(cache_key)
    if not entries:
        return "⚠️ No cache stored."
    lines = []
    for k, v in entries.items():
        entry = json.loads(v)
        lang = entry.get("lang") or "general"
        q = entry.get("orig", k)
        a = entry.get("output", "")
        lines.append(f"**Lang:** {lang}\n**Q:** {q}\n**A:** {a}")
    return "\n\n---\n\n".join(lines)

# -----------------------
# Chat logic
# -----------------------
def chat_with_ai(user_id: str, user_input: str):
    if not user_input or not user_id:
        return "Please set a username and type something."

    # 1) Semantic cache search (language-aware)
    cached = search_cache(user_id, user_input)
    if cached:
        return f"[From Redis] {cached}"

    # 2) Fall back to Azure OpenAI
    response = client.chat.completions.create(
        model=CHAT_DEPLOYMENT,
        messages=[{"role": "user", "content": user_input}],
        temperature=0.8,
        max_tokens=700,
    )
    output = response.choices[0].message.content.strip()

    # Store with the language-aware embedding
    store_cache(user_id, user_input, output)
    return f"[From OpenAI] {output}"

# -----------------------
# Gradio UI
# -----------------------
with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat (Lang-aware)") as demo:
    gr.Markdown("# 💬 Azure OpenAI + Redis Cloud (Language-aware Semantic Cache)")

    user_id_state = gr.State("")

    with gr.Row():
        user_id_input = gr.Textbox(label="Enter Username (only once)", placeholder="Your username")
        save_user = gr.Button("✅ Save Username")
    user_status = gr.Markdown("")

    with gr.Row():
        chatbot = gr.Chatbot(type="messages")
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...")
        send = gr.Button("Send")
    with gr.Row():
        clear = gr.Button("🧹 Clear My Cache")
        view = gr.Button("👀 View My Cache")
    cache_output = gr.Markdown("")

    def set_user_id(uid: str):
        uid = uid.strip()
        if not uid:
            return "", "⚠️ Please enter a non-empty username."
        return uid, f"✅ Username set as **{uid}**"

    def respond(message, history, user_id):
        if not user_id:
            return history, "⚠️ Please set username first!"
        bot_reply = chat_with_ai(user_id, message)
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": bot_reply})
        # The second output clears the input textbox
        return history, ""

    def clear_cache_ui(user_id, history):
        if not user_id:
            return history, "⚠️ Please set username first!"
        clear_user_cache(user_id)
        return [], f"✅ Cache cleared for {user_id}"

    def view_cache_ui(user_id):
        if not user_id:
            return "⚠️ Please set username first!"
        return view_user_cache(user_id)

    save_user.click(set_user_id, user_id_input, [user_id_state, user_status])
    send.click(respond, [msg, chatbot, user_id_state], [chatbot, msg])
    msg.submit(respond, [msg, chatbot, user_id_state], [chatbot, msg])
    clear.click(clear_cache_ui, [user_id_state, chatbot], [chatbot, cache_output])
    view.click(view_cache_ui, user_id_state, cache_output)
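
# Run locally with `python app.py`; the UI is served on port 7860, the
# default port Hugging Face Spaces expects.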

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)