# app.py import gradio as gr import json import os import uuid import time from datetime import datetime from transformers import AutoModelForCausalLM, AutoTokenizer import torch from dotenv import load_dotenv # ChromaDB imports import chromadb from chromadb.utils import embedding_functions from sentence_transformers import SentenceTransformer load_dotenv() MODEL_PATH = "./current_model" FEEDBACK_FILE = "feedback.json" HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "modular-ai/rlhf_feedback_dataset") HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "modular-ai/kantian-critic-qwen") # For versioned model uploads # BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct" # Smaller 0.5B model (~1GB) instead of full Qwen BASE_MODEL = "modular-ai/qwen" # Kantian Persona Configuration KANTIAN_PERSONA_NAME = "Kantian" KANTIAN_DESCRIPTION = "Principled moral critic who challenges consequentialist reasoning" KANTIAN_TRAITS = ["duty_focused", "universality_tester", "moral_consistency_seeker", "rights_defender"] KANTIAN_SYSTEM_PROMPT = f"""You are an ADVERSARIAL CRITIC whose job is to challenge and test arguments. You are the user's intellectual opponent, not their supporter. ADVERSARIAL MODE: 1. Challenge the document's arguments systematically. 2. Be critically rigorous - identify flaws and weaknesses. 3. Quote exact text when making critiques. 4. Attack logical fallacies and poor reasoning directly. 5. Your goal: Test arguments through adversarial analysis, not validate them. CRITICAL FRAMEWORK: - Test whether proposed actions can be universalized without contradiction - Challenge reasoning that treats people merely as means to ends - Question arguments that prioritize outcomes over moral duty - Identify violations of the categorical imperative - Expose moral inconsistencies and contradictions - Be intellectually honest - acknowledge strength only when absolutely warranted RESPOND IN FIRST PERSON - USE 'I' NOT 'ME' - MAINTAIN KANTIAN STYLE: DIRECT, SPECIFIC, NO EXPLANATIONS, CONCISE NEVER MENTION 'USER' OR 'YOU' - CRITIQUE THE ARGUMENTS DIRECTLY STAY IN CHARACTER - NEVER BREAK ROLE - ALWAYS RESPOND AS KANTIAN CRITIC""" # Initialize ChromaDB client chroma_client = chromadb.PersistentClient(path="./chroma_db") # Store active collections and their last access times active_collections = {} def create_document_collection(document_text: str) -> str: """Create a new ChromaDB collection for a document and split into chunks""" # Generate unique collection name collection_id = f"doc_{uuid.uuid4().hex}" # Create collection collection = chroma_client.create_collection(name=collection_id) # Split document into chunks (roughly 500 words each) words = document_text.split() chunks = [] chunk_size = 500 for i in range(0, len(words), chunk_size): chunk = ' '.join(words[i:i + chunk_size]) chunks.append(chunk) # Add chunks to collection with metadata if chunks: ids = [f"chunk_{i}" for i in range(len(chunks))] collection.add( documents=chunks, ids=ids ) # Track collection active_collections[collection_id] = { "collection": collection, "last_access": time.time(), "chunks": len(chunks) } print(f"Created collection {collection_id} with {len(chunks)} chunks") return collection_id def delete_document_collection(collection_id: str) -> str: """Delete a ChromaDB collection""" if collection_id in active_collections: try: chroma_client.delete_collection(name=collection_id) del active_collections[collection_id] return f"Collection {collection_id} deleted successfully" except Exception as e: return f"Error deleting collection: {str(e)}" return f"Collection {collection_id} not found" def delete_old_collections(max_age_hours: float = 2.0) -> str: """Delete collections that haven't been accessed in max_age_hours""" current_time = time.time() deleted_collections = [] for collection_id, collection_data in list(active_collections.items()): last_access = collection_data["last_access"] age_hours = (current_time - last_access) / 3600 if age_hours > max_age_hours: try: chroma_client.delete_collection(name=collection_id) del active_collections[collection_id] deleted_collections.append(collection_id) print(f"Deleted old collection: {collection_id}") except Exception as e: print(f"Error deleting old collection {collection_id}: {e}") if deleted_collections: return f"Deleted {len(deleted_collections)} old collections: {', '.join(deleted_collections)}" return "No old collections to delete" def retrieve_relevant_chunks(collection_id: str, query: str, n_results: int = 3) -> list: """Retrieve relevant chunks from a document collection""" if collection_id not in active_collections: return [] # Update last access time active_collections[collection_id]["last_access"] = time.time() try: collection = active_collections[collection_id]["collection"] results = collection.query( query_texts=[query], n_results=n_results ) return results['documents'][0] if results['documents'] else [] except Exception as e: print(f"Error retrieving chunks: {e}") return [] def load_model(): global model, tokenizer # Detect device if torch.cuda.is_available(): device = "cuda" dtype = torch.float16 elif torch.backends.mps.is_available(): device = "mps" dtype = torch.float32 # MPS doesn't fully support float16 else: device = "cpu" dtype = torch.float32 print(f"Using device: {device}") try: if os.path.exists(MODEL_PATH): print(f"Loading fine-tuned model from {MODEL_PATH}") tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) else: print(f"Loading base model: {BASE_MODEL}") tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True) # Set padding token if not exists if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model.eval() model = model.to(device) print(f"Model loaded successfully on {device}") except Exception as e: print(f"Error loading model: {e}") raise def count_tokens(text: str) -> int: """Estimate token count for text""" if tokenizer is None: # Rough estimate: ~4 chars per token return len(text) // 4 try: return len(tokenizer.encode(text)) except: return len(text) // 4 def summarize_conversation(history: list, document: str = "") -> str: """Summarize conversation history to compress context""" if not history: return "" print(f"[SUMMARIZE] Summarizing {len(history)} exchanges...") # Build summary prompt conversation_text = "" for user_msg, bot_msg in history: conversation_text += f"User: {user_msg}\nKantian: {bot_msg}\n\n" summary_prompt = f"""Summarize this Kantian critique conversation, preserving: 1. Key moral arguments discussed 2. Main weaknesses identified 3. Critical Kantian principles applied 4. Important quotes and references 5. Deontological issues raised Conversation: {conversation_text} Concise summary (keep important details):""" try: summary = generate_response(summary_prompt, conversation_history=[]) print(f"[SUMMARIZE] Generated summary: {len(summary)} chars") return summary except Exception as e: print(f"[SUMMARIZE] Error: {e}") # Fallback: simple truncation return f"Previous discussion covered: {conversation_text[:500]}..." def generate_response(prompt: str, conversation_history = None, summary: str = "") -> str: if model is None or tokenizer is None: load_model() # Validate model and tokenizer are loaded if model is None or tokenizer is None: return "Error: Model failed to load. Please check configuration." # Handle None conversation_history if conversation_history is None: conversation_history = [] try: # Get model device device = next(model.parameters()).device # Use maximum context window available (your specified 32768) max_model_length = 32768 print(f"[GENERATE] Max tokens: {max_model_length}") # Build conversation context with history conversation_context = KANTIAN_SYSTEM_PROMPT + "\n\n" # Add summary if exists if summary: conversation_context += f"Previous conversation summary:\n{summary}\n\n" # Add recent history (but we're not using history now) if conversation_history: for user_msg, bot_msg in conversation_history[-1:]: # Only last exchange if any conversation_context += f"User: {user_msg}\nKantian: {bot_msg}\n\n" # Add current prompt full_prompt = conversation_context + f"User: {prompt}\nKantian:" # Count tokens in full prompt prompt_tokens = count_tokens(full_prompt) print(f"[TOKENS] Prompt tokens: {prompt_tokens}") # Use full context without truncation (within model limits) max_input_length = max_model_length - 1000 # Reserve more tokens for response inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=max_input_length) input_tokens_count = inputs['input_ids'].shape[1] print(f"[TOKENS] Input tokens (after tokenizer): {input_tokens_count}") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=500, # ULTRA SHORT responses only do_sample=True, temperature=0.8, top_p=0.92, pad_token_id=tokenizer.eos_token_id, repetition_penalty=1.2, # Penalize repetition no_repeat_ngram_size=3, # Prevent 3-gram repetition early_stopping=True, # Stop early when possible length_penalty=0.1 # Discourage overly long responses ) # Count output tokens output_tokens_count = outputs[0].shape[0] - inputs['input_ids'].shape[1] print(f"[TOKENS] Output tokens: {output_tokens_count}") print(f"[TOKENS] Total tokens used: {input_tokens_count + output_tokens_count}") response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip() return response except Exception as e: print(f"Error generating response: {e}") return f"Error generating response: {str(e)}" def save_feedback(prompt: str, response: str, reward: int, text_feedback: str = ""): entry = { "prompt": prompt, "response": response, "reward": reward, "text_feedback": text_feedback, # New field for detailed feedback "timestamp": datetime.now().isoformat() } if os.path.exists(FEEDBACK_FILE): try: with open(FEEDBACK_FILE, "r") as f: content = f.read().strip() data = json.loads(content) if content else [] except (json.JSONDecodeError, ValueError): data = [] else: data = [] data.append(entry) with open(FEEDBACK_FILE, "w") as f: json.dump(data, f, indent=2) return f"Feedback saved! Total: {len(data)}" # Gradio Interface with gr.Blocks(title="Kantian Adversarial Critic - RLHF Training", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # โš”๏ธ Kantian Adversarial Critic ### AI-Powered Moral Philosophy Critique with Continuous Learning Upload your document and receive rigorous moral critique from a Kantian perspective. The AI will challenge your arguments, identify weaknesses, and test moral consistency. """ ) with gr.Accordion("๐Ÿ“– How to Use", open=False): gr.Markdown( """ **Step 1:** Upload your document (.txt, .md, .docx) or paste text **Step 2:** Click "Load Document" to prepare for critique **Step 3:** Ask questions like: - "Challenge this argument systematically" - "What are the fatal flaws?" - "Can this be universalized?" - "Does this treat people as mere means?" **Step 4:** Rate the critique quality to help the AI improve! โš ๏ธ **This AI is adversarial** - it will challenge your work, not validate it. """ ) gr.Markdown("---") gr.Markdown("## ๐Ÿ“„ Step 1: Upload Your Document") # Document upload section with better layout with gr.Row(): with gr.Column(scale=1): doc_upload = gr.File( label="๐Ÿ“Ž Upload File", file_types=['.txt', '.md', '.docx'], type="filepath", interactive=True ) upload_btn = gr.Button("๐Ÿ“ฅ Load Document", variant="primary", size="lg") delete_btn = gr.Button("๐Ÿ—‘๏ธ Delete Document Collection", variant="secondary", size="lg") # Hidden state for collection ID collection_state = gr.State(None) with gr.Column(scale=1): doc_text = gr.Textbox( label="โœ๏ธ Or Paste Your Text", placeholder="Paste your document here and click 'Load Document'...", lines=8, max_lines=15 ) doc_status = gr.Textbox( label="๐Ÿ“Š Document Status", interactive=False, placeholder="No document loaded yet", show_label=True ) gr.Markdown("---") # Chat interface with improved layout with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( height=450, label="๐Ÿ—ฃ๏ธ Kantian Critique Conversation", bubble_full_width=False, show_label=True, avatar_images=( None, # User avatar "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/Immanuel_Kant_%28painted_portrait%29.jpg/220px-Immanuel_Kant_%28painted_portrait%29.jpg" # Kant avatar ) ) prompt_box = gr.Textbox( label="โ“ Ask for Critique", placeholder="e.g., 'Challenge this systematically', 'What are the fatal flaws?', 'Can this be universalized?'", lines=2, max_lines=4, show_label=True ) with gr.Row(): send_btn = gr.Button("โš”๏ธ Send", variant="primary", size="lg") clear_btn = gr.Button("๐Ÿงน Clear Chat", size="lg") with gr.Column(scale=1): gr.Markdown("## ๐Ÿ“Š Rate Critique Quality") with gr.Row(): like_btn = gr.Button("๐Ÿ‘ Strong Critique", variant="secondary", size="lg") dislike_btn = gr.Button("๐Ÿ‘Ž Weak Critique", variant="secondary", size="lg") feedback_text = gr.Textbox( label="๐Ÿ“ Detailed Feedback (Optional)", placeholder="How can this critique be improved? Be specific...", lines=3, max_lines=5, show_label=True ) submit_feedback_btn = gr.Button("๐Ÿ“ค Submit Detailed Feedback", variant="primary", size="lg") status = gr.Textbox( label="๐Ÿ“ˆ Feedback Status", interactive=False, placeholder="Your feedback will appear here...", show_label=True ) with gr.Accordion("โ„น๏ธ About This AI", open=False): gr.Markdown( """ **Kantian Adversarial Critic** - Tests universalizability - Identifies moral inconsistencies - Challenges consequentialist reasoning - Attacks logical fallacies - Quotes text directly **Personality Traits:** ๐ŸŽฏ Duty-focused ๐ŸŒ Universality-tester โš–๏ธ Consistency-seeker ๐Ÿ›ก๏ธ Rights-defender """ ) # State: stores (chat_history, document_content, conversation_summary) state = gr.State([]) doc_state = gr.State("") summary_state = gr.State("") # Stores conversation summary def load_document(file_path, pasted_text): """Load document from file or text area and create ChromaDB collection""" content = "" collection_id = None if pasted_text.strip(): content = pasted_text.strip() status_msg = f"Document loaded from text area ({len(content)} characters)" elif file_path: try: # Handle different file types if file_path.endswith('.txt') or file_path.endswith('.md'): with open(file_path, 'r', encoding='utf-8') as f: content = f.read() elif file_path.endswith('.pdf'): try: import PyPDF2 with open(file_path, 'rb') as f: pdf_reader = PyPDF2.PdfReader(f) content = '' for page in pdf_reader.pages: content += page.extract_text() + '\n' except ImportError: status_msg = "PDF support requires pypdf2 library. Install with: pip install pypdf2" return "", status_msg, None elif file_path.endswith('.docx'): try: from docx import Document doc = Document(file_path) content = '\n'.join([paragraph.text for paragraph in doc.paragraphs]) except ImportError: status_msg = "DOCX support requires python-docx library. Install with: pip install python-docx" return "", status_msg, None else: with open(file_path, 'r', encoding='utf-8') as f: content = f.read() status_msg = f"Document loaded from file ({len(content)} characters)" except Exception as e: status_msg = f"Error loading file: {str(e)}" return "", status_msg, None else: status_msg = "Please upload a file or paste text" return "", status_msg, None # Create ChromaDB collection for this document if content: try: collection_id = create_document_collection(content) status_msg += f" | Collection ID: {collection_id}" except Exception as e: status_msg += f" | ChromaDB error: {str(e)}" return content, status_msg, collection_id def respond(message, history, chat_state, document, current_summary, collection_id): """Generate adversarial Kantian critique with no conversation history""" print(f"\n[RESPOND] Received message: {message[:100]}...") print(f"[RESPOND] Document present: {bool(document)}") print(f"[RESPOND] Collection ID: {collection_id}") if not document or not collection_id: bot_response = "Please upload a document first so I can mount an adversarial critique." # Always start fresh - no conversation history new_chat_state = [(message, bot_response)] return "", history + [[message, bot_response]], new_chat_state, "", collection_id # Retrieve relevant document chunks using RAG relevant_chunks = retrieve_relevant_chunks(collection_id, message, n_results=3) context_document = "\n\n".join(relevant_chunks) if relevant_chunks else document[:2000] + ("..." if len(document) > 2000 else "") print("[RESPOND] Starting critique generation...") # Create context-aware prompt for adversarial critique (NO conversation history) # ENFORCED BREVITY: Request short, focused critique critique_prompt = f"""Document context (relevant sections): {context_document} User request: {message} Provide DIRECT, FOCUSED adversarial criticism (3 sentences MAX): 1. ONE flaw 2. ONE quote 3. ONE principle RESPOND AS KANTIAN CRITIC - USE 'I' NOT 'USER' - MAINTAIN CHARACTER - NO EXPLANATIONS - NO REPETITION - BE CONCISE""" print("[RESPOND] Calling generate_response WITHOUT conversation history...") # Pass EMPTY conversation history - each message is standalone bot_response = generate_response(critique_prompt, conversation_history=[], summary="") print(f"[RESPOND] Got response: {bot_response[:100]}...") # Return fresh state with only this exchange new_chat_state = [(message, bot_response)] return "", history + [[message, bot_response]], new_chat_state, "", collection_id def like_action(chat_state, document, current_summary): """Save positive feedback with document context (standalone)""" if chat_state and document: # Get the last (and only) exchange if len(chat_state) > 0: prompt, response = chat_state[-1] # Include document snippet in feedback for better training doc_snippet = document[:500] + "..." if len(document) > 500 else document full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}" return save_feedback(full_prompt, response, 1, ""), "" return "No critique to rate or no document loaded.", "" def dislike_action(chat_state, document, current_summary): """Save negative feedback with document context (standalone)""" if chat_state and document: # Get the last (and only) exchange if len(chat_state) > 0: prompt, response = chat_state[-1] # Include document snippet in feedback for better training doc_snippet = document[:500] + "..." if len(document) > 500 else document full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}" return save_feedback(full_prompt, response, 0, ""), "" return "No critique to rate or no document loaded.", "" def text_feedback_action(chat_state, document, current_summary, feedback_text): """Save detailed text feedback with document context (standalone)""" if chat_state and document and feedback_text.strip(): # Get the last (and only) exchange if len(chat_state) > 0: prompt, response = chat_state[-1] # Include document snippet in feedback for better training doc_snippet = document[:500] + "..." if len(document) > 500 else document full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}" # Save with neutral reward but include text feedback return save_feedback(full_prompt, response, 1, feedback_text.strip()), "" return "No critique to rate, no document loaded, or no feedback provided.", "" def clear_chat(): # Reset ALL state including document state return [], [], "", "", None # Clear chat, state, summary, document, collection_id def delete_collection(collection_id): """Delete the current document collection""" if collection_id: result = delete_document_collection(collection_id) return result, None, [], [], "", "" return "No collection to delete", None, [], [], "", "" # Event handlers upload_btn.click(load_document, [doc_upload, doc_text], [doc_state, doc_status, collection_state]) send_btn.click(respond, [prompt_box, chatbot, state, doc_state, summary_state, collection_state], [prompt_box, chatbot, state, summary_state, collection_state]) prompt_box.submit(respond, [prompt_box, chatbot, state, doc_state, summary_state, collection_state], [prompt_box, chatbot, state, summary_state, collection_state]) like_btn.click(like_action, [state, doc_state, summary_state], [status, summary_state]) dislike_btn.click(dislike_action, [state, doc_state, summary_state], [status, summary_state]) submit_feedback_btn.click(text_feedback_action, [state, doc_state, summary_state, feedback_text], [status, summary_state]) clear_btn.click(clear_chat, outputs=[chatbot, state, summary_state, doc_state, collection_state]) delete_btn.click(delete_collection, collection_state, [doc_status, collection_state, chatbot, state, summary_state, doc_state]) if __name__ == "__main__": import threading import time import subprocess # Background training worker def background_worker(): last_count = 0 last_sync = 0 REWARD_THRESHOLD = 50 PPO_THRESHOLD = 100 CHECK_INTERVAL = 300 # 5 min SYNC_INTERVAL = 600 # 10 min while True: time.sleep(CHECK_INTERVAL) if not os.path.exists(FEEDBACK_FILE): continue try: with open(FEEDBACK_FILE, "r") as f: content = f.read().strip() data = json.loads(content) if content else [] count = len(data) if count > last_count: print(f"New feedback: {count - last_count} โ†’ Total: {count}") last_count = count # Sync to HF if time.time() - last_sync > SYNC_INTERVAL: try: from data_sync import sync_to_hub sync_to_hub() last_sync = time.time() except Exception as e: print(f"Sync error: {e}") # Train reward model # Check if we've crossed the threshold for the first time if count >= REWARD_THRESHOLD and last_count < REWARD_THRESHOLD: print("\n" + "="*50) print(f"Training reward model with {count} samples...") print("="*50) subprocess.run("python train_reward.py", shell=True) print("โœ“ Reward model training complete") print("="*50 + "\n") # Train PPO # Check if we've crossed the threshold for the first time if count >= PPO_THRESHOLD and last_count < PPO_THRESHOLD: print("\n" + "="*50) print(f"Running PPO fine-tuning with {count} samples...") print("="*50) subprocess.run("python train_ppo.py", shell=True) load_model() # Reload fine-tuned model print("โœ“ PPO fine-tuning complete - model reloaded") print("โœ“ New version pushed to Hugging Face") print("="*50 + "\n") except Exception as e: print(f"Background worker error: {e}") # Start background worker thread print("Starting background training worker...") thread = threading.Thread(target=background_worker, daemon=True) thread.start() # Load model and launch load_model() print("Launching Gradio UI with auto-training enabled...") demo.launch(server_name="0.0.0.0", server_port=7860, share=False)