# app.py
import gradio as gr
import json
import os
import uuid
import time
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from dotenv import load_dotenv
# ChromaDB imports
import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
load_dotenv()
MODEL_PATH = "./current_model"
FEEDBACK_FILE = "feedback.json"
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "modular-ai/rlhf_feedback_dataset")
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "modular-ai/kantian-critic-qwen") # For versioned model uploads
# BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct" # Smaller 0.5B model (~1GB) instead of full Qwen
BASE_MODEL = "modular-ai/qwen"
# Kantian Persona Configuration
KANTIAN_PERSONA_NAME = "Kantian"
KANTIAN_DESCRIPTION = "Principled moral critic who challenges consequentialist reasoning"
KANTIAN_TRAITS = ["duty_focused", "universality_tester", "moral_consistency_seeker", "rights_defender"]
KANTIAN_SYSTEM_PROMPT = f"""You are an ADVERSARIAL CRITIC whose job is to challenge and test arguments. You are the user's intellectual opponent, not their supporter.
ADVERSARIAL MODE:
1. Challenge the document's arguments systematically.
2. Be critically rigorous - identify flaws and weaknesses.
3. Quote exact text when making critiques.
4. Attack logical fallacies and poor reasoning directly.
5. Your goal: Test arguments through adversarial analysis, not validate them.
CRITICAL FRAMEWORK:
- Test whether proposed actions can be universalized without contradiction
- Challenge reasoning that treats people merely as means to ends
- Question arguments that prioritize outcomes over moral duty
- Identify violations of the categorical imperative
- Expose moral inconsistencies and contradictions
- Be intellectually honest - acknowledge strength only when absolutely warranted
RESPOND IN FIRST PERSON - USE 'I' NOT 'ME' - MAINTAIN KANTIAN STYLE: DIRECT, SPECIFIC, NO EXPLANATIONS, CONCISE
NEVER MENTION 'USER' OR 'YOU' - CRITIQUE THE ARGUMENTS DIRECTLY
STAY IN CHARACTER - NEVER BREAK ROLE - ALWAYS RESPOND AS KANTIAN CRITIC"""
# Initialize ChromaDB client
chroma_client = chromadb.PersistentClient(path="./chroma_db")
# Store active collections and their last access times
active_collections = {}

# Language-model globals, populated lazily by load_model(); initialized here so
# count_tokens() and generate_response() can safely check for None.
model = None
tokenizer = None
def create_document_collection(document_text: str) -> str:
"""Create a new ChromaDB collection for a document and split into chunks"""
# Generate unique collection name
collection_id = f"doc_{uuid.uuid4().hex}"
# Create collection
collection = chroma_client.create_collection(name=collection_id)
# Split document into chunks (roughly 500 words each)
words = document_text.split()
chunks = []
chunk_size = 500
for i in range(0, len(words), chunk_size):
chunk = ' '.join(words[i:i + chunk_size])
chunks.append(chunk)
# Add chunks to collection with metadata
if chunks:
ids = [f"chunk_{i}" for i in range(len(chunks))]
collection.add(
documents=chunks,
ids=ids
)
# Track collection
active_collections[collection_id] = {
"collection": collection,
"last_access": time.time(),
"chunks": len(chunks)
}
print(f"Created collection {collection_id} with {len(chunks)} chunks")
return collection_id
def delete_document_collection(collection_id: str) -> str:
"""Delete a ChromaDB collection"""
if collection_id in active_collections:
try:
chroma_client.delete_collection(name=collection_id)
del active_collections[collection_id]
return f"Collection {collection_id} deleted successfully"
except Exception as e:
return f"Error deleting collection: {str(e)}"
return f"Collection {collection_id} not found"
def delete_old_collections(max_age_hours: float = 2.0) -> str:
"""Delete collections that haven't been accessed in max_age_hours"""
current_time = time.time()
deleted_collections = []
for collection_id, collection_data in list(active_collections.items()):
last_access = collection_data["last_access"]
age_hours = (current_time - last_access) / 3600
if age_hours > max_age_hours:
try:
chroma_client.delete_collection(name=collection_id)
del active_collections[collection_id]
deleted_collections.append(collection_id)
print(f"Deleted old collection: {collection_id}")
except Exception as e:
print(f"Error deleting old collection {collection_id}: {e}")
if deleted_collections:
return f"Deleted {len(deleted_collections)} old collections: {', '.join(deleted_collections)}"
return "No old collections to delete"
def retrieve_relevant_chunks(collection_id: str, query: str, n_results: int = 3) -> list:
"""Retrieve relevant chunks from a document collection"""
if collection_id not in active_collections:
return []
# Update last access time
active_collections[collection_id]["last_access"] = time.time()
try:
collection = active_collections[collection_id]["collection"]
results = collection.query(
query_texts=[query],
n_results=n_results
)
return results['documents'][0] if results['documents'] else []
except Exception as e:
print(f"Error retrieving chunks: {e}")
return []
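# Illustrative usage of the RAG helpers above (a sketch, not executed; the
# document text and query are invented). ChromaDB embeds chunks with its
# default embedding function, which may download a model on first use:
#
#   cid = create_document_collection("Act only according to that maxim ...")
#   top_chunks = retrieve_relevant_chunks(cid, "universalizability", n_results=2)
#   print(delete_document_collection(cid))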
def load_model():
global model, tokenizer
# Detect device
if torch.cuda.is_available():
device = "cuda"
dtype = torch.float16
elif torch.backends.mps.is_available():
device = "mps"
dtype = torch.float32 # MPS doesn't fully support float16
else:
device = "cpu"
dtype = torch.float32
print(f"Using device: {device}")
    try:
        if os.path.exists(MODEL_PATH):
            print(f"Loading fine-tuned model from {MODEL_PATH}")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
            # Load weights in the dtype selected for this device
            model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=dtype)
        else:
            print(f"Loading base model: {BASE_MODEL}")
            tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True, torch_dtype=dtype)
# Set padding token if not exists
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.eval()
model = model.to(device)
print(f"Model loaded successfully on {device}")
except Exception as e:
print(f"Error loading model: {e}")
raise
def count_tokens(text: str) -> int:
"""Estimate token count for text"""
if tokenizer is None:
# Rough estimate: ~4 chars per token
return len(text) // 4
try:
return len(tokenizer.encode(text))
    except Exception:
return len(text) // 4
def summarize_conversation(history: list, document: str = "") -> str:
"""Summarize conversation history to compress context"""
if not history:
return ""
print(f"[SUMMARIZE] Summarizing {len(history)} exchanges...")
# Build summary prompt
conversation_text = ""
for user_msg, bot_msg in history:
conversation_text += f"User: {user_msg}\nKantian: {bot_msg}\n\n"
summary_prompt = f"""Summarize this Kantian critique conversation, preserving:
1. Key moral arguments discussed
2. Main weaknesses identified
3. Critical Kantian principles applied
4. Important quotes and references
5. Deontological issues raised
Conversation:
{conversation_text}
Concise summary (keep important details):"""
try:
summary = generate_response(summary_prompt, conversation_history=[])
print(f"[SUMMARIZE] Generated summary: {len(summary)} chars")
return summary
except Exception as e:
print(f"[SUMMARIZE] Error: {e}")
# Fallback: simple truncation
return f"Previous discussion covered: {conversation_text[:500]}..."
def generate_response(prompt: str, conversation_history=None, summary: str = "") -> str:
if model is None or tokenizer is None:
load_model()
# Validate model and tokenizer are loaded
if model is None or tokenizer is None:
return "Error: Model failed to load. Please check configuration."
# Handle None conversation_history
if conversation_history is None:
conversation_history = []
try:
# Get model device
device = next(model.parameters()).device
        # Use the full configured context window (32768 tokens)
        max_model_length = 32768
print(f"[GENERATE] Max tokens: {max_model_length}")
# Build conversation context with history
conversation_context = KANTIAN_SYSTEM_PROMPT + "\n\n"
# Add summary if exists
if summary:
conversation_context += f"Previous conversation summary:\n{summary}\n\n"
# Add recent history (but we're not using history now)
if conversation_history:
for user_msg, bot_msg in conversation_history[-1:]: # Only last exchange if any
conversation_context += f"User: {user_msg}\nKantian: {bot_msg}\n\n"
# Add current prompt
full_prompt = conversation_context + f"User: {prompt}\nKantian:"
# Count tokens in full prompt
prompt_tokens = count_tokens(full_prompt)
print(f"[TOKENS] Prompt tokens: {prompt_tokens}")
        # Truncate the prompt if necessary, reserving headroom for the response
        max_input_length = max_model_length - 1000
inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
input_tokens_count = inputs['input_ids'].shape[1]
print(f"[TOKENS] Input tokens (after tokenizer): {input_tokens_count}")
inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,  # cap response length
                do_sample=True,
                temperature=0.8,
                top_p=0.92,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,  # penalize repeated tokens
                no_repeat_ngram_size=3,  # block repeated 3-grams
                # early_stopping and length_penalty are omitted: they apply
                # only to beam search and are ignored under pure sampling.
            )
# Count output tokens
output_tokens_count = outputs[0].shape[0] - inputs['input_ids'].shape[1]
print(f"[TOKENS] Output tokens: {output_tokens_count}")
print(f"[TOKENS] Total tokens used: {input_tokens_count + output_tokens_count}")
response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
return response
except Exception as e:
print(f"Error generating response: {e}")
return f"Error generating response: {str(e)}"
def save_feedback(prompt: str, response: str, reward: int, text_feedback: str = ""):
entry = {
"prompt": prompt,
"response": response,
"reward": reward,
"text_feedback": text_feedback, # New field for detailed feedback
"timestamp": datetime.now().isoformat()
}
if os.path.exists(FEEDBACK_FILE):
try:
with open(FEEDBACK_FILE, "r") as f:
content = f.read().strip()
data = json.loads(content) if content else []
except (json.JSONDecodeError, ValueError):
data = []
else:
data = []
data.append(entry)
with open(FEEDBACK_FILE, "w") as f:
json.dump(data, f, indent=2)
return f"Feedback saved! Total: {len(data)}"
# Gradio Interface
with gr.Blocks(title="Kantian Adversarial Critic - RLHF Training", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# βš”οΈ Kantian Adversarial Critic
### AI-Powered Moral Philosophy Critique with Continuous Learning
Upload your document and receive rigorous moral critique from a Kantian perspective.
The AI will challenge your arguments, identify weaknesses, and test moral consistency.
"""
)
with gr.Accordion("πŸ“– How to Use", open=False):
gr.Markdown(
"""
            **Step 1:** Upload your document (.txt, .md, .docx, .pdf) or paste text
**Step 2:** Click "Load Document" to prepare for critique
**Step 3:** Ask questions like:
- "Challenge this argument systematically"
- "What are the fatal flaws?"
- "Can this be universalized?"
- "Does this treat people as mere means?"
**Step 4:** Rate the critique quality to help the AI improve!
⚠️ **This AI is adversarial** - it will challenge your work, not validate it.
"""
)
gr.Markdown("---")
gr.Markdown("## πŸ“„ Step 1: Upload Your Document")
# Document upload section with better layout
with gr.Row():
with gr.Column(scale=1):
            doc_upload = gr.File(
                label="📎 Upload File",
                file_types=['.txt', '.md', '.docx', '.pdf'],  # load_document() also parses PDFs
                type="filepath",
                interactive=True
            )
            upload_btn = gr.Button("📥 Load Document", variant="primary", size="lg")
            delete_btn = gr.Button("🗑️ Delete Document Collection", variant="secondary", size="lg")
# Hidden state for collection ID
collection_state = gr.State(None)
with gr.Column(scale=1):
doc_text = gr.Textbox(
label="✏️ Or Paste Your Text",
placeholder="Paste your document here and click 'Load Document'...",
lines=8,
max_lines=15
)
doc_status = gr.Textbox(
label="πŸ“Š Document Status",
interactive=False,
placeholder="No document loaded yet",
show_label=True
)
gr.Markdown("---")
# Chat interface with improved layout
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
height=450,
label="πŸ—£οΈ Kantian Critique Conversation",
bubble_full_width=False,
show_label=True,
avatar_images=(
None, # User avatar
"https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/Immanuel_Kant_%28painted_portrait%29.jpg/220px-Immanuel_Kant_%28painted_portrait%29.jpg" # Kant avatar
)
)
prompt_box = gr.Textbox(
label="❓ Ask for Critique",
placeholder="e.g., 'Challenge this systematically', 'What are the fatal flaws?', 'Can this be universalized?'",
lines=2,
max_lines=4,
show_label=True
)
with gr.Row():
send_btn = gr.Button("βš”οΈ Send", variant="primary", size="lg")
clear_btn = gr.Button("🧹 Clear Chat", size="lg")
with gr.Column(scale=1):
gr.Markdown("## πŸ“Š Rate Critique Quality")
with gr.Row():
like_btn = gr.Button("πŸ‘ Strong Critique", variant="secondary", size="lg")
dislike_btn = gr.Button("πŸ‘Ž Weak Critique", variant="secondary", size="lg")
feedback_text = gr.Textbox(
label="πŸ“ Detailed Feedback (Optional)",
placeholder="How can this critique be improved? Be specific...",
lines=3,
max_lines=5,
show_label=True
)
            submit_feedback_btn = gr.Button("📤 Submit Detailed Feedback", variant="primary", size="lg")
status = gr.Textbox(
label="πŸ“ˆ Feedback Status",
interactive=False,
placeholder="Your feedback will appear here...",
show_label=True
)
with gr.Accordion("ℹ️ About This AI", open=False):
gr.Markdown(
"""
**Kantian Adversarial Critic**
- Tests universalizability
- Identifies moral inconsistencies
- Challenges consequentialist reasoning
- Attacks logical fallacies
- Quotes text directly
**Personality Traits:**
🎯 Duty-focused 🌍 Universality-tester
βš–οΈ Consistency-seeker πŸ›‘οΈ Rights-defender
"""
)
# State: stores (chat_history, document_content, conversation_summary)
state = gr.State([])
doc_state = gr.State("")
summary_state = gr.State("") # Stores conversation summary
def load_document(file_path, pasted_text):
"""Load document from file or text area and create ChromaDB collection"""
content = ""
collection_id = None
if pasted_text.strip():
content = pasted_text.strip()
status_msg = f"Document loaded from text area ({len(content)} characters)"
elif file_path:
try:
# Handle different file types
if file_path.endswith('.txt') or file_path.endswith('.md'):
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
elif file_path.endswith('.pdf'):
try:
import PyPDF2
with open(file_path, 'rb') as f:
pdf_reader = PyPDF2.PdfReader(f)
content = ''
for page in pdf_reader.pages:
content += page.extract_text() + '\n'
except ImportError:
status_msg = "PDF support requires pypdf2 library. Install with: pip install pypdf2"
return "", status_msg, None
elif file_path.endswith('.docx'):
try:
from docx import Document
doc = Document(file_path)
content = '\n'.join([paragraph.text for paragraph in doc.paragraphs])
except ImportError:
status_msg = "DOCX support requires python-docx library. Install with: pip install python-docx"
return "", status_msg, None
else:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
status_msg = f"Document loaded from file ({len(content)} characters)"
except Exception as e:
status_msg = f"Error loading file: {str(e)}"
return "", status_msg, None
else:
status_msg = "Please upload a file or paste text"
return "", status_msg, None
# Create ChromaDB collection for this document
if content:
try:
collection_id = create_document_collection(content)
status_msg += f" | Collection ID: {collection_id}"
except Exception as e:
status_msg += f" | ChromaDB error: {str(e)}"
return content, status_msg, collection_id
def respond(message, history, chat_state, document, current_summary, collection_id):
"""Generate adversarial Kantian critique with no conversation history"""
print(f"\n[RESPOND] Received message: {message[:100]}...")
print(f"[RESPOND] Document present: {bool(document)}")
print(f"[RESPOND] Collection ID: {collection_id}")
if not document or not collection_id:
bot_response = "Please upload a document first so I can mount an adversarial critique."
# Always start fresh - no conversation history
new_chat_state = [(message, bot_response)]
return "", history + [[message, bot_response]], new_chat_state, "", collection_id
# Retrieve relevant document chunks using RAG
relevant_chunks = retrieve_relevant_chunks(collection_id, message, n_results=3)
context_document = "\n\n".join(relevant_chunks) if relevant_chunks else document[:2000] + ("..." if len(document) > 2000 else "")
print("[RESPOND] Starting critique generation...")
# Create context-aware prompt for adversarial critique (NO conversation history)
# ENFORCED BREVITY: Request short, focused critique
critique_prompt = f"""Document context (relevant sections):
{context_document}
User request: {message}
Provide DIRECT, FOCUSED adversarial criticism (3 sentences MAX):
1. ONE flaw
2. ONE quote
3. ONE principle
RESPOND AS KANTIAN CRITIC - USE 'I' NOT 'USER' - MAINTAIN CHARACTER - NO EXPLANATIONS - NO REPETITION - BE CONCISE"""
print("[RESPOND] Calling generate_response WITHOUT conversation history...")
# Pass EMPTY conversation history - each message is standalone
bot_response = generate_response(critique_prompt, conversation_history=[], summary="")
print(f"[RESPOND] Got response: {bot_response[:100]}...")
# Return fresh state with only this exchange
new_chat_state = [(message, bot_response)]
return "", history + [[message, bot_response]], new_chat_state, "", collection_id
def like_action(chat_state, document, current_summary):
"""Save positive feedback with document context (standalone)"""
if chat_state and document:
# Get the last (and only) exchange
if len(chat_state) > 0:
prompt, response = chat_state[-1]
# Include document snippet in feedback for better training
doc_snippet = document[:500] + "..." if len(document) > 500 else document
full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}"
return save_feedback(full_prompt, response, 1, ""), ""
return "No critique to rate or no document loaded.", ""
def dislike_action(chat_state, document, current_summary):
"""Save negative feedback with document context (standalone)"""
if chat_state and document:
# Get the last (and only) exchange
if len(chat_state) > 0:
prompt, response = chat_state[-1]
# Include document snippet in feedback for better training
doc_snippet = document[:500] + "..." if len(document) > 500 else document
full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}"
return save_feedback(full_prompt, response, 0, ""), ""
return "No critique to rate or no document loaded.", ""
def text_feedback_action(chat_state, document, current_summary, feedback_text):
"""Save detailed text feedback with document context (standalone)"""
if chat_state and document and feedback_text.strip():
# Get the last (and only) exchange
if len(chat_state) > 0:
prompt, response = chat_state[-1]
# Include document snippet in feedback for better training
doc_snippet = document[:500] + "..." if len(document) > 500 else document
full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}"
                # Save with a positive reward (1) and attach the detailed text feedback
                return save_feedback(full_prompt, response, 1, feedback_text.strip()), ""
return "No critique to rate, no document loaded, or no feedback provided.", ""
def clear_chat():
# Reset ALL state including document state
return [], [], "", "", None # Clear chat, state, summary, document, collection_id
def delete_collection(collection_id):
"""Delete the current document collection"""
if collection_id:
result = delete_document_collection(collection_id)
return result, None, [], [], "", ""
return "No collection to delete", None, [], [], "", ""
# Event handlers
upload_btn.click(load_document, [doc_upload, doc_text], [doc_state, doc_status, collection_state])
send_btn.click(respond, [prompt_box, chatbot, state, doc_state, summary_state, collection_state], [prompt_box, chatbot, state, summary_state, collection_state])
prompt_box.submit(respond, [prompt_box, chatbot, state, doc_state, summary_state, collection_state], [prompt_box, chatbot, state, summary_state, collection_state])
like_btn.click(like_action, [state, doc_state, summary_state], [status, summary_state])
dislike_btn.click(dislike_action, [state, doc_state, summary_state], [status, summary_state])
submit_feedback_btn.click(text_feedback_action, [state, doc_state, summary_state, feedback_text], [status, summary_state])
clear_btn.click(clear_chat, outputs=[chatbot, state, summary_state, doc_state, collection_state])
delete_btn.click(delete_collection, collection_state, [doc_status, collection_state, chatbot, state, summary_state, doc_state])
if __name__ == "__main__":
    import threading
    import subprocess
# Background training worker
def background_worker():
last_count = 0
last_sync = 0
REWARD_THRESHOLD = 50
PPO_THRESHOLD = 100
CHECK_INTERVAL = 300 # 5 min
SYNC_INTERVAL = 600 # 10 min
while True:
time.sleep(CHECK_INTERVAL)
if not os.path.exists(FEEDBACK_FILE):
continue
try:
with open(FEEDBACK_FILE, "r") as f:
content = f.read().strip()
data = json.loads(content) if content else []
                count = len(data)
                prev_count = last_count  # pre-update count for the threshold checks below
                if count > prev_count:
                    print(f"New feedback: {count - prev_count} → Total: {count}")
                    last_count = count
# Sync to HF
if time.time() - last_sync > SYNC_INTERVAL:
try:
from data_sync import sync_to_hub
sync_to_hub()
last_sync = time.time()
except Exception as e:
print(f"Sync error: {e}")
                # Train the reward model the first time the count crosses the
                # threshold (compare against prev_count, since last_count has
                # already been updated above)
                if count >= REWARD_THRESHOLD and prev_count < REWARD_THRESHOLD:
print("\n" + "="*50)
print(f"Training reward model with {count} samples...")
print("="*50)
subprocess.run("python train_reward.py", shell=True)
print("βœ“ Reward model training complete")
print("="*50 + "\n")
                # Run PPO fine-tuning the first time the count crosses its threshold
                if count >= PPO_THRESHOLD and prev_count < PPO_THRESHOLD:
print("\n" + "="*50)
print(f"Running PPO fine-tuning with {count} samples...")
print("="*50)
subprocess.run("python train_ppo.py", shell=True)
load_model() # Reload fine-tuned model
print("βœ“ PPO fine-tuning complete - model reloaded")
print("βœ“ New version pushed to Hugging Face")
print("="*50 + "\n")
except Exception as e:
print(f"Background worker error: {e}")
# Start background worker thread
print("Starting background training worker...")
thread = threading.Thread(target=background_worker, daemon=True)
thread.start()
# Load model and launch
load_model()
print("Launching Gradio UI with auto-training enabled...")
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)