|
|
|
|
|
import gradio as gr |
|
|
import json |
|
|
import os |
|
|
import uuid |
|
|
import time |
|
|
from datetime import datetime |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
import chromadb |
|
|
from chromadb.utils import embedding_functions |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
# Populate os.environ from a local .env file before the settings below read it.
load_dotenv()

# --- Paths and Hugging Face repositories -------------------------------------
# Fine-tuned checkpoint; load_model() prefers it over BASE_MODEL when it exists.
MODEL_PATH = "./current_model"
# JSON list of rated critiques written by save_feedback().
FEEDBACK_FILE = "feedback.json"
# Overridable via environment. NOTE(review): not referenced elsewhere in this
# file; presumably consumed by the sync/training scripts — confirm.
HF_DATASET_REPO = os.environ.get("HF_DATASET_REPO", "modular-ai/rlhf_feedback_dataset")
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "modular-ai/kantian-critic-qwen")
# Fallback checkpoint used when MODEL_PATH does not exist.
BASE_MODEL = "modular-ai/qwen"

# --- Persona metadata ---------------------------------------------------------
KANTIAN_PERSONA_NAME = "Kantian"
KANTIAN_DESCRIPTION = "Principled moral critic who challenges consequentialist reasoning"
KANTIAN_TRAITS = [
    "duty_focused",
    "universality_tester",
    "moral_consistency_seeker",
    "rights_defender",
]
|
|
|
|
|
# System prompt prepended to every model input built by generate_response().
# It has no placeholders, so it is a plain string rather than an f-string.
KANTIAN_SYSTEM_PROMPT = """You are an ADVERSARIAL CRITIC whose job is to challenge and test arguments. You are the user's intellectual opponent, not their supporter.

ADVERSARIAL MODE:
1. Challenge the document's arguments systematically.
2. Be critically rigorous - identify flaws and weaknesses.
3. Quote exact text when making critiques.
4. Attack logical fallacies and poor reasoning directly.
5. Your goal: Test arguments through adversarial analysis, not validate them.

CRITICAL FRAMEWORK:
- Test whether proposed actions can be universalized without contradiction
- Challenge reasoning that treats people merely as means to ends
- Question arguments that prioritize outcomes over moral duty
- Identify violations of the categorical imperative
- Expose moral inconsistencies and contradictions
- Be intellectually honest - acknowledge strength only when absolutely warranted

RESPOND IN FIRST PERSON - USE 'I' NOT 'ME' - MAINTAIN KANTIAN STYLE: DIRECT, SPECIFIC, NO EXPLANATIONS, CONCISE
NEVER MENTION 'USER' OR 'YOU' - CRITIQUE THE ARGUMENTS DIRECTLY
STAY IN CHARACTER - NEVER BREAK ROLE - ALWAYS RESPOND AS KANTIAN CRITIC"""
|
|
|
|
|
|
|
|
# On-disk ChromaDB store; create_document_collection() adds one collection per
# loaded document.
chroma_client = chromadb.PersistentClient(path="./chroma_db")

# Registry of collections created by this process, keyed by collection id.
# Each value holds the Collection object ("collection"), a time.time() stamp of
# the last use ("last_access") and the number of stored chunks ("chunks").
active_collections = {}

# Model and tokenizer are assigned by load_model(). Initialising them to None
# lets generate_response()/count_tokens() use their `is None` checks instead of
# raising NameError when called before the first load.
model = None
tokenizer = None
|
|
|
|
|
def create_document_collection(document_text: str, chunk_size: int = 500) -> str:
    """Create a ChromaDB collection for a document and store it in word chunks.

    The text is split on whitespace into consecutive chunks of ``chunk_size``
    words; chunk ``i`` is stored under the id ``chunk_<i>``. The collection is
    registered in ``active_collections`` so it can be queried and deleted later.

    Args:
        document_text: Raw document text. Empty or whitespace-only text yields
            an empty (but still registered) collection.
        chunk_size: Number of words per chunk; must be positive.

    Returns:
        The generated collection id, ``doc_`` followed by a uuid4 hex string.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        Exception: Whatever ChromaDB raises while creating or populating the
            collection. A collection that was created but could not be
            populated is deleted again before re-raising, because it is not yet
            registered and nothing else would ever clean it up.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    collection_id = f"doc_{uuid.uuid4().hex}"
    collection = chroma_client.create_collection(name=collection_id)

    words = document_text.split()
    chunks = [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

    if chunks:
        try:
            collection.add(
                documents=chunks,
                ids=[f"chunk_{i}" for i in range(len(chunks))],
            )
        except Exception:
            # Not yet in active_collections, so delete_document_collection() and
            # delete_old_collections() would never see it: remove it here.
            try:
                chroma_client.delete_collection(name=collection_id)
            except Exception as cleanup_error:
                print(f"Error cleaning up collection {collection_id}: {cleanup_error}")
            raise

    active_collections[collection_id] = {
        "collection": collection,
        "last_access": time.time(),
        "chunks": len(chunks),
    }

    print(f"Created collection {collection_id} with {len(chunks)} chunks")
    return collection_id
|
|
|
|
|
def delete_document_collection(collection_id: str) -> str:
    """Drop a tracked ChromaDB collection and forget it locally.

    Only ids registered in ``active_collections`` are touched. A status
    string is returned in every case; ChromaDB errors are reported in that
    string rather than raised.
    """
    if collection_id not in active_collections:
        return f"Collection {collection_id} not found"

    try:
        chroma_client.delete_collection(name=collection_id)
        del active_collections[collection_id]
    except Exception as err:
        return f"Error deleting collection: {str(err)}"
    return f"Collection {collection_id} deleted successfully"
|
|
|
|
|
def delete_old_collections(max_age_hours: float = 2.0) -> str:
    """Remove tracked collections idle for longer than ``max_age_hours``.

    Idle time is measured from each entry's ``last_access`` stamp in
    ``active_collections``. A failure on one collection is printed and the
    sweep continues with the next.

    Returns:
        A summary naming the deleted ids, or a message saying none qualified.
    """
    now = time.time()
    removed = []

    # Iterate over a snapshot because entries are deleted during the sweep.
    for cid, info in list(active_collections.items()):
        idle_hours = (now - info["last_access"]) / 3600
        if not idle_hours > max_age_hours:
            continue
        try:
            chroma_client.delete_collection(name=cid)
            del active_collections[cid]
            removed.append(cid)
            print(f"Deleted old collection: {cid}")
        except Exception as err:
            print(f"Error deleting old collection {cid}: {err}")

    if not removed:
        return "No old collections to delete"
    return f"Deleted {len(removed)} old collections: {', '.join(removed)}"
|
|
|
|
|
def retrieve_relevant_chunks(collection_id: str, query: str, n_results: int = 3) -> list:
    """Return up to ``n_results`` stored chunks most similar to ``query``.

    Args:
        collection_id: Id returned by create_document_collection().
        query: Text used as the similarity query.
        n_results: Maximum number of chunks to return. It is capped at the
            number of chunks stored for the collection. Depending on the
            ChromaDB version, requesting more results than there are entries
            logs a warning or raises. For example, a short pasted document is
            stored as a single chunk while callers request 3.

    Returns:
        A list of chunk strings, possibly empty. The following all yield ``[]``:
        an unknown id, an empty collection, a non-positive ``n_results``, and
        a query error (which is printed).
    """
    entry = active_collections.get(collection_id)
    if entry is None:
        return []

    # Refresh the stamp so delete_old_collections() treats it as in use.
    entry["last_access"] = time.time()

    n_results = min(n_results, entry.get("chunks", n_results))
    if n_results <= 0:
        return []

    try:
        results = entry["collection"].query(query_texts=[query], n_results=n_results)
        return results['documents'][0] if results['documents'] else []
    except Exception as e:
        print(f"Error retrieving chunks: {e}")
        return []
|
|
|
|
|
def load_model():
    """Load the tokenizer and causal LM into the module globals ``model``/``tokenizer``.

    Uses the fine-tuned checkpoint at MODEL_PATH when that path exists,
    otherwise BASE_MODEL (with ``trust_remote_code=True``). The device is CUDA
    (float16) if available, else Apple MPS, else CPU (float32 for both). The
    model is loaded directly in that dtype, put in eval mode and moved to the
    device.

    The globals are only replaced after every step succeeded. A failed
    (re)load, such as the one triggered after PPO training, therefore leaves
    the previously loaded pair untouched instead of half-updated.

    Raises:
        Exception: Re-raises whatever transformers/torch raise while loading.
    """
    global model, tokenizer

    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    elif torch.backends.mps.is_available():
        device, dtype = "mps", torch.float32
    else:
        device, dtype = "cpu", torch.float32

    print(f"Using device: {device}")

    try:
        if os.path.exists(MODEL_PATH):
            print(f"Loading fine-tuned model from {MODEL_PATH}")
            new_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
            new_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=dtype)
        else:
            print(f"Loading base model: {BASE_MODEL}")
            new_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
            new_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, trust_remote_code=True, torch_dtype=dtype
            )

        # Generation/padding needs a pad token; reuse EOS when none is defined.
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token

        new_model.eval()
        new_model = new_model.to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    model, tokenizer = new_model, new_tokenizer
    print(f"Model loaded successfully on {device}")
|
|
|
|
|
def count_tokens(text: str) -> int:
    """Return the number of tokens ``text`` encodes to.

    Uses the loaded tokenizer when one is available. Before the model is
    loaded, or if encoding fails, it falls back to the rough
    ~4-characters-per-token estimate ``len(text) // 4``.
    """
    # Look the global up defensively: it only exists once it has been
    # assigned (by load_model() or a module-level initialisation).
    active_tokenizer = globals().get("tokenizer")
    if active_tokenizer is None:
        return len(text) // 4
    try:
        return len(active_tokenizer.encode(text))
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        return len(text) // 4
|
|
|
|
|
def summarize_conversation(history: list, document: str = "") -> str:
    """Compress a conversation into a short summary generated by the model.

    Args:
        history: Sequence of ``(user_message, bot_message)`` pairs. Two-item
            lists, as used by the Gradio chatbot history, work too.
        document: Currently unused; kept for call-site compatibility.

    Returns:
        ``""`` for an empty history; otherwise the generated summary.
        Generation can fail by raising, by returning one of
        generate_response()'s error strings, or by returning nothing. In those
        cases a transcript excerpt (first 500 characters) is returned instead,
        so an error message is never passed off as a summary.
    """
    if not history:
        return ""

    print(f"[SUMMARIZE] Summarizing {len(history)} exchanges...")

    conversation_text = "".join(
        f"User: {user_msg}\nKantian: {bot_msg}\n\n" for user_msg, bot_msg in history
    )

    summary_prompt = f"""Summarize this Kantian critique conversation, preserving:
1. Key moral arguments discussed
2. Main weaknesses identified
3. Critical Kantian principles applied
4. Important quotes and references
5. Deontological issues raised

Conversation:
{conversation_text}

Concise summary (keep important details):"""

    # generate_response() reports failures as strings with these prefixes.
    error_prefixes = ("Error: Model failed to load", "Error generating response:")

    try:
        summary = generate_response(summary_prompt, conversation_history=[])
    except Exception as e:
        print(f"[SUMMARIZE] Error: {e}")
        summary = ""
    else:
        if summary.startswith(error_prefixes):
            print(f"[SUMMARIZE] Error: {summary}")
            summary = ""

    if summary:
        print(f"[SUMMARIZE] Generated summary: {len(summary)} chars")
        return summary

    # Fallback: a plain excerpt of the transcript, marked only when truncated.
    excerpt = conversation_text[:500]
    if len(conversation_text) > 500:
        excerpt += "..."
    return f"Previous discussion covered: {excerpt}"
|
|
|
|
|
def generate_response(prompt: str, conversation_history=None, summary: str = "") -> str:
    """Generate the Kantian critic's reply to ``prompt``.

    The model input is built from these parts, in order:
    - KANTIAN_SYSTEM_PROMPT;
    - the optional ``summary`` of earlier conversation;
    - at most the last exchange of ``conversation_history``;
    - a final ``User: <prompt>`` line followed by ``Kantian:``, which the
      model continues.

    Args:
        prompt: The request to answer.
        conversation_history: Sequence of ``(user_message, bot_message)``
            pairs; only the most recent one is replayed.
        summary: Optional summary text of earlier conversation.

    Returns:
        The decoded continuation, stripped. Failures are returned as strings
        starting with ``"Error: Model failed to load"`` or
        ``"Error generating response:"`` rather than raised.
    """
    if globals().get("model") is None or globals().get("tokenizer") is None:
        try:
            load_model()
        except Exception:
            # load_model() prints the underlying cause before re-raising.
            return "Error: Model failed to load. Please check configuration."

    if conversation_history is None:
        conversation_history = []

    try:
        device = next(model.parameters()).device

        # Assumed context window. NOTE(review): hard-coded; confirm against
        # the model's config (e.g. max_position_embeddings).
        max_model_length = 32768
        print(f"[GENERATE] Max tokens: {max_model_length}")

        parts = [KANTIAN_SYSTEM_PROMPT + "\n\n"]
        if summary:
            parts.append(f"Previous conversation summary:\n{summary}\n\n")
        # Only the most recent exchange is replayed verbatim.
        for user_msg, bot_msg in conversation_history[-1:]:
            parts.append(f"User: {user_msg}\nKantian: {bot_msg}\n\n")
        parts.append(f"User: {prompt}\nKantian:")
        full_prompt = "".join(parts)

        # Tokenize once; the untruncated length is the prompt size.
        inputs = dict(tokenizer(full_prompt, return_tensors="pt"))
        prompt_tokens = inputs["input_ids"].shape[1]
        print(f"[TOKENS] Prompt tokens: {prompt_tokens}")

        # 1000 tokens are held back from the window (500 are generated below).
        max_input_length = max_model_length - 1000
        if prompt_tokens > max_input_length:
            # Keep the END of the prompt. The tokenizer's default right-side
            # truncation would cut off the trailing "User: ...\nKantian:" cue
            # that the model has to continue.
            inputs = {k: v[:, -max_input_length:] for k, v in inputs.items()}
        input_tokens_count = inputs["input_ids"].shape[1]
        print(f"[TOKENS] Input tokens (after tokenizer): {input_tokens_count}")

        inputs = {k: v.to(device) for k, v in inputs.items()}

        # early_stopping / length_penalty were removed: they only apply to
        # beam search, not to sampling with a single beam (as here).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,
                temperature=0.8,
                top_p=0.92,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3,
            )

        # generate() returns prompt + continuation; keep only the continuation.
        generated = outputs[0][input_tokens_count:]
        output_tokens_count = generated.shape[0]
        print(f"[TOKENS] Output tokens: {output_tokens_count}")
        print(f"[TOKENS] Total tokens used: {input_tokens_count + output_tokens_count}")

        return tokenizer.decode(generated, skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error generating response: {e}")
        return f"Error generating response: {str(e)}"
|
|
|
|
|
def save_feedback(prompt: str, response: str, reward: int, text_feedback: str = "", feedback_file=None):
    """Append one rated critique to the JSON feedback file.

    Each entry has the keys ``prompt``, ``response``, ``reward``,
    ``text_feedback`` and ``timestamp`` (local ISO-8601).

    The file is rewritten atomically: the data goes to ``<path>.tmp``, which
    then replaces the real file. An interrupted write therefore cannot leave
    a truncated file.

    An existing file that cannot be parsed as a JSON list is renamed to
    ``<path>.corrupt-<timestamp>`` before a new list is started. Previously
    such a file was overwritten and all earlier feedback was lost.

    Args:
        prompt: The prompt (including document context) that was answered.
        response: The model's critique.
        reward: Rating stored as-is (the UI passes 1 or 0).
        text_feedback: Optional free-text feedback.
        feedback_file: Path to write to; defaults to FEEDBACK_FILE.

    Returns:
        ``"Feedback saved! Total: <n>"`` with the number of stored entries.

    Raises:
        OSError: If the file cannot be read, backed up or written.
    """
    path = feedback_file if feedback_file is not None else FEEDBACK_FILE

    entry = {
        "prompt": prompt,
        "response": response,
        "reward": reward,
        "text_feedback": text_feedback,
        "timestamp": datetime.now().isoformat(),
    }

    data = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                content = f.read().strip()
            data = json.loads(content) if content else []
            if not isinstance(data, list):
                raise ValueError("feedback file does not contain a JSON list")
        except ValueError as err:
            # Covers JSONDecodeError and UnicodeDecodeError (both ValueError).
            backup = f"{path}.corrupt-{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
            os.replace(path, backup)
            print(f"[FEEDBACK] Unreadable feedback file ({err}); moved it to {backup}")
            data = []

    data.append(entry)

    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, path)

    return f"Feedback saved! Total: {len(data)}"
|
|
|
|
|
|
|
|
with gr.Blocks(title="Kantian Adversarial Critic - RLHF Training", theme=gr.themes.Soft()) as demo:
    # ------------------------------------------------------------------
    # Layout
    # ------------------------------------------------------------------
    gr.Markdown(
        """
# βοΈ Kantian Adversarial Critic
### AI-Powered Moral Philosophy Critique with Continuous Learning

Upload your document and receive rigorous moral critique from a Kantian perspective.
The AI will challenge your arguments, identify weaknesses, and test moral consistency.
"""
    )

    with gr.Accordion("π How to Use", open=False):
        gr.Markdown(
            """
**Step 1:** Upload your document (.txt, .md, .docx) or paste text
**Step 2:** Click "Load Document" to prepare for critique
**Step 3:** Ask questions like:
- "Challenge this argument systematically"
- "What are the fatal flaws?"
- "Can this be universalized?"
- "Does this treat people as mere means?"

**Step 4:** Rate the critique quality to help the AI improve!

β οΈ **This AI is adversarial** - it will challenge your work, not validate it.
"""
        )

    gr.Markdown("---")
    gr.Markdown("## π Step 1: Upload Your Document")

    with gr.Row():
        with gr.Column(scale=1):
            doc_upload = gr.File(
                label="π Upload File",
                file_types=['.txt', '.md', '.docx'],
                type="filepath",
                interactive=True,
            )
            upload_btn = gr.Button("π₯ Load Document", variant="primary", size="lg")
            delete_btn = gr.Button("ποΈ Delete Document Collection", variant="secondary", size="lg")

            # ChromaDB collection id of the currently loaded document (or None).
            collection_state = gr.State(None)

        with gr.Column(scale=1):
            doc_text = gr.Textbox(
                label="βοΈ Or Paste Your Text",
                placeholder="Paste your document here and click 'Load Document'...",
                lines=8,
                max_lines=15,
            )
            doc_status = gr.Textbox(
                label="π Document Status",
                interactive=False,
                placeholder="No document loaded yet",
                show_label=True,
            )

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=450,
                label="π£οΈ Kantian Critique Conversation",
                bubble_full_width=False,
                show_label=True,
                avatar_images=(
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/Immanuel_Kant_%28painted_portrait%29.jpg/220px-Immanuel_Kant_%28painted_portrait%29.jpg",
                ),
            )
            prompt_box = gr.Textbox(
                label="β Ask for Critique",
                placeholder="e.g., 'Challenge this systematically', 'What are the fatal flaws?', 'Can this be universalized?'",
                lines=2,
                max_lines=4,
                show_label=True,
            )
            with gr.Row():
                send_btn = gr.Button("βοΈ Send", variant="primary", size="lg")
                clear_btn = gr.Button("π§Ή Clear Chat", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("## π Rate Critique Quality")
            with gr.Row():
                like_btn = gr.Button("π Strong Critique", variant="secondary", size="lg")
                dislike_btn = gr.Button("π Weak Critique", variant="secondary", size="lg")

            feedback_text = gr.Textbox(
                label="π Detailed Feedback (Optional)",
                placeholder="How can this critique be improved? Be specific...",
                lines=3,
                max_lines=5,
                show_label=True,
            )
            submit_feedback_btn = gr.Button("π€ Submit Detailed Feedback", variant="primary", size="lg")
            status = gr.Textbox(
                label="π Feedback Status",
                interactive=False,
                placeholder="Your feedback will appear here...",
                show_label=True,
            )

            with gr.Accordion("βΉοΈ About This AI", open=False):
                gr.Markdown(
                    """
**Kantian Adversarial Critic**
- Tests universalizability
- Identifies moral inconsistencies
- Challenges consequentialist reasoning
- Attacks logical fallacies
- Quotes text directly

**Personality Traits:**
π― Duty-focused π Universality-tester
βοΈ Consistency-seeker π‘οΈ Rights-defender
"""
                )

    # Per-session state: last (message, reply) pair, document text, summary.
    state = gr.State([])
    doc_state = gr.State("")
    summary_state = gr.State("")

    # ------------------------------------------------------------------
    # Handlers
    # ------------------------------------------------------------------
    def _read_document_file(file_path):
        """Read text from an uploaded file, dispatching on its extension (case-insensitive).

        Returns:
            ``(content, None)`` on success or ``("", error_message)`` on failure.
        """
        extension = os.path.splitext(file_path)[1].lower()
        try:
            if extension == '.pdf':
                try:
                    import PyPDF2
                except ImportError:
                    return "", "PDF support requires pypdf2 library. Install with: pip install pypdf2"
                with open(file_path, 'rb') as f:
                    pdf_reader = PyPDF2.PdfReader(f)
                    # extract_text() can yield None for pages without text.
                    content = ''.join((page.extract_text() or '') + '\n' for page in pdf_reader.pages)
                return content, None

            if extension == '.docx':
                try:
                    from docx import Document
                except ImportError:
                    return "", "DOCX support requires python-docx library. Install with: pip install python-docx"
                doc = Document(file_path)
                return '\n'.join(paragraph.text for paragraph in doc.paragraphs), None

            # .txt, .md and anything else are read as UTF-8 text.
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read(), None
        except Exception as e:
            return "", f"Error loading file: {str(e)}"

    def load_document(file_path, pasted_text, previous_collection_id=None):
        """Load a document from the text area (preferred) or an uploaded file and index it.

        Any collection created for a previously loaded document is deleted
        first. The document/collection state is about to be overwritten, so
        nothing would reference or clean up that collection otherwise.

        Returns:
            ``(document_text, status_message, collection_id_or_None)`` for the
            doc_state, doc_status and collection_state outputs.
        """
        if previous_collection_id:
            delete_document_collection(previous_collection_id)

        pasted_text = pasted_text or ""
        if pasted_text.strip():
            content = pasted_text.strip()
            status_msg = f"Document loaded from text area ({len(content)} characters)"
        elif file_path:
            content, error = _read_document_file(file_path)
            if error is not None:
                return "", error, None
            status_msg = f"Document loaded from file ({len(content)} characters)"
        else:
            return "", "Please upload a file or paste text", None

        collection_id = None
        if content:
            try:
                collection_id = create_document_collection(content)
                status_msg += f" | Collection ID: {collection_id}"
            except Exception as e:
                status_msg += f" | ChromaDB error: {str(e)}"

        return content, status_msg, collection_id

    def respond(message, history, chat_state, document, current_summary, collection_id):
        """Generate a single-turn adversarial Kantian critique of the loaded document.

        Earlier turns are not replayed. The prompt contains only the chunks
        retrieved for ``message``, or the first 2000 characters of the
        document when retrieval returns nothing.

        Returns:
            ``("", chatbot_history, [(message, reply)], "", collection_id)``
            for the prompt_box, chatbot, state, summary_state and
            collection_state outputs. Blank messages leave everything unchanged.
        """
        history = history or []
        if not message or not message.strip():
            return "", history, chat_state, current_summary, collection_id

        print(f"\n[RESPOND] Received message: {message[:100]}...")
        print(f"[RESPOND] Document present: {bool(document)}")
        print(f"[RESPOND] Collection ID: {collection_id}")

        if not document or not collection_id:
            bot_response = "Please upload a document first so I can mount an adversarial critique."
            return "", history + [[message, bot_response]], [(message, bot_response)], "", collection_id

        relevant_chunks = retrieve_relevant_chunks(collection_id, message, n_results=3)
        if relevant_chunks:
            context_document = "\n\n".join(relevant_chunks)
        else:
            context_document = document[:2000] + ("..." if len(document) > 2000 else "")

        print("[RESPOND] Starting critique generation...")

        critique_prompt = f"""Document context (relevant sections):
{context_document}

User request: {message}

Provide DIRECT, FOCUSED adversarial criticism (3 sentences MAX):
1. ONE flaw
2. ONE quote
3. ONE principle

RESPOND AS KANTIAN CRITIC - USE 'I' NOT 'USER' - MAINTAIN CHARACTER - NO EXPLANATIONS - NO REPETITION - BE CONCISE"""

        print("[RESPOND] Calling generate_response WITHOUT conversation history...")
        bot_response = generate_response(critique_prompt, conversation_history=[], summary="")
        print(f"[RESPOND] Got response: {bot_response[:100]}...")

        return "", history + [[message, bot_response]], [(message, bot_response)], "", collection_id

    def _record_feedback(chat_state, document, reward, text_feedback=""):
        """Save the latest exchange with a 500-character document excerpt.

        Returns:
            The save_feedback() status message, or None when there is no
            exchange to rate or no document loaded.
        """
        if not chat_state or not document:
            return None
        prompt, response = chat_state[-1]
        doc_snippet = document[:500] + "..." if len(document) > 500 else document
        full_prompt = f"Document: {doc_snippet}\n\nQuestion: {prompt}"
        return save_feedback(full_prompt, response, reward, text_feedback)

    def like_action(chat_state, document, current_summary):
        """Record reward 1 for the latest critique; also resets summary_state to ""."""
        result = _record_feedback(chat_state, document, 1)
        if result is None:
            return "No critique to rate or no document loaded.", ""
        return result, ""

    def dislike_action(chat_state, document, current_summary):
        """Record reward 0 for the latest critique; also resets summary_state to ""."""
        result = _record_feedback(chat_state, document, 0)
        if result is None:
            return "No critique to rate or no document loaded.", ""
        return result, ""

    def text_feedback_action(chat_state, document, current_summary, feedback_text):
        """Record free-text feedback for the latest critique.

        NOTE(review): the entry is stored with reward 1 whatever the text
        says — confirm this is intended for reward-model training.
        """
        feedback = (feedback_text or "").strip()
        result = _record_feedback(chat_state, document, 1, feedback) if feedback else None
        if result is None:
            return "No critique to rate, no document loaded, or no feedback provided.", ""
        return result, ""

    def clear_chat():
        """Clear the conversation only.

        The loaded document and its collection stay available; the "Delete
        Document Collection" button unloads them. Previously this handler also
        reset the document state without deleting its ChromaDB collection,
        while the status box still reported the document as loaded.
        """
        return [], [], ""

    def delete_collection(collection_id):
        """Delete the current document's collection and reset chat and document state."""
        if not collection_id:
            return "No collection to delete", None, [], [], "", ""
        return delete_document_collection(collection_id), None, [], [], "", ""

    # ------------------------------------------------------------------
    # Event wiring
    # ------------------------------------------------------------------
    respond_inputs = [prompt_box, chatbot, state, doc_state, summary_state, collection_state]
    respond_outputs = [prompt_box, chatbot, state, summary_state, collection_state]

    upload_btn.click(load_document, [doc_upload, doc_text, collection_state], [doc_state, doc_status, collection_state])
    send_btn.click(respond, respond_inputs, respond_outputs)
    prompt_box.submit(respond, respond_inputs, respond_outputs)
    like_btn.click(like_action, [state, doc_state, summary_state], [status, summary_state])
    dislike_btn.click(dislike_action, [state, doc_state, summary_state], [status, summary_state])
    submit_feedback_btn.click(text_feedback_action, [state, doc_state, summary_state, feedback_text], [status, summary_state])
    clear_btn.click(clear_chat, outputs=[chatbot, state, summary_state])
    delete_btn.click(delete_collection, collection_state, [doc_status, collection_state, chatbot, state, summary_state, doc_state])
|
|
|
|
|
if __name__ == "__main__":
    import subprocess
    import sys
    import threading

    def _read_feedback_count():
        """Return the number of entries in FEEDBACK_FILE (0 if missing or empty).

        Raises json.JSONDecodeError/OSError if the file cannot be read or parsed.
        """
        if not os.path.exists(FEEDBACK_FILE):
            return 0
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            content = f.read().strip()
        return len(json.loads(content)) if content else 0

    def _run_training_script(script):
        """Run ``script`` with the current interpreter; return True on exit status 0."""
        # List argv + sys.executable: no shell, and the same Python/venv as the app.
        result = subprocess.run([sys.executable, script])
        if result.returncode != 0:
            print(f"{script} FAILED with exit code {result.returncode}")
            return False
        return True

    def background_worker():
        """Poll the feedback file every CHECK_INTERVAL seconds and act on it.

        On each cycle it does the following:
        - deletes document collections idle for more than two hours;
        - reports newly arrived feedback;
        - calls data_sync.sync_to_hub() when the count changed since the last
          successful sync and SYNC_INTERVAL seconds have passed;
        - runs train_reward.py when the count crosses REWARD_THRESHOLD;
        - runs train_ppo.py when the count crosses PPO_THRESHOLD, then reloads
          the model.

        The count is seeded from disk at start-up. Thresholds that were already
        crossed before this process started therefore do not retrain on every
        restart.
        """
        REWARD_THRESHOLD = 50
        PPO_THRESHOLD = 100
        CHECK_INTERVAL = 300
        SYNC_INTERVAL = 600

        try:
            last_count = _read_feedback_count()
        except Exception as e:
            print(f"Background worker error: {e}")
            last_count = 0
        synced_count = None  # None => nothing synced yet in this process
        last_sync = 0

        while True:
            time.sleep(CHECK_INTERVAL)

            # Without this sweep, per-document collections accumulate on disk.
            try:
                delete_old_collections()
            except Exception as e:
                print(f"Background worker error: {e}")

            if not os.path.exists(FEEDBACK_FILE):
                continue

            try:
                count = _read_feedback_count()

                # Capture the previous count BEFORE updating it. The threshold
                # checks below compare against it. Updating first made
                # `last_count < THRESHOLD` false whenever `count >= THRESHOLD`,
                # so training could never trigger.
                previous_count = last_count
                if count > last_count:
                    print(f"New feedback: {count - last_count} β Total: {count}")
                    last_count = count

                if count != synced_count and time.time() - last_sync > SYNC_INTERVAL:
                    try:
                        from data_sync import sync_to_hub
                        sync_to_hub()
                        last_sync = time.time()
                        synced_count = count
                    except Exception as e:
                        print(f"Sync error: {e}")

                if count >= REWARD_THRESHOLD and previous_count < REWARD_THRESHOLD:
                    print("\n" + "=" * 50)
                    print(f"Training reward model with {count} samples...")
                    print("=" * 50)
                    if _run_training_script("train_reward.py"):
                        print("β Reward model training complete")
                    print("=" * 50 + "\n")

                if count >= PPO_THRESHOLD and previous_count < PPO_THRESHOLD:
                    print("\n" + "=" * 50)
                    print(f"Running PPO fine-tuning with {count} samples...")
                    print("=" * 50)
                    if _run_training_script("train_ppo.py"):
                        load_model()
                        print("β PPO fine-tuning complete - model reloaded")
                        # NOTE(review): presumably train_ppo.py performs the push;
                        # this line only reports it — confirm.
                        print("β New version pushed to Hugging Face")
                    print("=" * 50 + "\n")
            except Exception as e:
                print(f"Background worker error: {e}")

    print("Starting background training worker...")
    thread = threading.Thread(target=background_worker, daemon=True)
    thread.start()

    load_model()
    print("Launching Gradio UI with auto-training enabled...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)