# train_ppo.py
import json
import os
import shutil
from datetime import datetime

import torch
from datasets import Dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi, login, metadata_update
from transformers import AutoTokenizer
from trl import (
    AutoModelForCausalLMWithValueHead,
    PPOConfig,
    PPOTrainer,
    create_reference_model,
)

from reward_model_loader import load_reward_pipeline

load_dotenv()

FEEDBACK_FILE = "feedback.json"
MODEL_PATH = "./current_model"
PPO_OUTPUT = "./ppo_model_temp"
REWARD_PATH = "./reward_model"
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "modular-ai/kantian-critic-qwen")

# BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"  # Smaller 0.5B model
BASE_MODEL = "modular-ai/qwen"

# Kantian system prompt used to frame every PPO training query
KANTIAN_SYSTEM_PROMPT = """You are Kantian - an ADVERSARIAL CRITIC whose job is to challenge and test arguments.

ADVERSARIAL MODE:
1. Challenge the document's arguments systematically.
2. Be critically rigorous - identify flaws and weaknesses.
3. Quote exact text when making critiques.
4. Attack logical fallacies and poor reasoning directly.
5. Your goal: Test arguments through adversarial analysis, not validate them.

Apply Kantian framework: universalizability, human dignity, moral duty over consequences.
"""

# Load feedback data
if not os.path.exists(FEEDBACK_FILE):
    print("No data.")
    exit()

with open(FEEDBACK_FILE, "r") as f:
    data = json.load(f)

if len(data) < 100:
    print(f"Need 100+ samples. Current: {len(data)}")
    exit()

# Use the most recent prompts (Kantian critique contexts) and include any text
# feedback for a richer training signal; 64 samples keep batching simple.
prompts_data = data[-64:]
prompts = []
for d in prompts_data:
    prompt_text = d["prompt"]
    text_feedback = d.get("text_feedback", "")
    if "Question:" in prompt_text:
        # Extract just the question part for Kantian critique generation
        question = prompt_text.split("Question:")[-1].strip()
        if text_feedback:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n\nFeedback Context: {text_feedback}\n\n{question}")
        else:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n{question}")
    else:
        if text_feedback:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n\nFeedback Context: {text_feedback}\n\n{prompt_text}")
        else:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n{prompt_text}")

dataset = Dataset.from_dict({"prompt": prompts})

# Load reward model
reward_pipe = load_reward_pipeline(REWARD_PATH)

# Load base model (resume from the current checkpoint if one exists)
base_model_path = MODEL_PATH if os.path.exists(MODEL_PATH) else BASE_MODEL
print(f"Loading base model: {base_model_path}")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(base_model_path, trust_remote_code=True)
ref_model = create_reference_model(model)

# Tokenize queries up front: the PPO dataloader yields whatever columns the
# dataset carries, and ppo_trainer.generate expects a list of input_id tensors.
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["prompt"])
    return sample

dataset = dataset.map(tokenize)
dataset.set_format(type="torch", columns=["input_ids"], output_all_columns=True)

# Keep each batch as lists (of tensors/strings) rather than padded tensors.
def collator(batch):
    return {key: [d[key] for d in batch] for key in batch[0]}

config = PPOConfig(
    model_name=base_model_path,
    learning_rate=1.41e-5,
    batch_size=8,
    mini_batch_size=4,
    gradient_accumulation_steps=1,
    ppo_epochs=3,
)

ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)

generation_kwargs = {
    "max_new_tokens": 100,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "pad_token_id": tokenizer.eos_token_id,
}

print("Starting PPO training...")
for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]
    # return_prompt=False keeps only the generated continuation, which is what
    # ppo_trainer.step expects as the response tensors.
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]
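    # Reward shaping (descriptive note): reward_pipe returns one score per label
    # for each scored text. LABEL_1 is treated as the positive class and LABEL_0
    # as the negative class (the label names are assumed to match the reward
    # model's binary head), and the scalar reward is their margin,
    # pos_score - neg_score, which stays roughly in [-1, 1] for softmax scores.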
    # Compute rewards
    texts = [f"Prompt: {p} Response: {r}" for p, r in zip(batch["prompt"], responses)]
    pipe_outputs = reward_pipe(texts)
    rewards = []
    for out in pipe_outputs:
        pos_score = next((s["score"] for s in out if s["label"] == "LABEL_1"), 0.0)
        neg_score = next((s["score"] for s in out if s["label"] == "LABEL_0"), 0.0)
        rewards.append(torch.tensor(pos_score - neg_score))

    ppo_trainer.step(query_tensors, response_tensors, rewards)

# Save the updated policy (and tokenizer), then swap it into place.
ppo_trainer.save_pretrained(PPO_OUTPUT)
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.rename(PPO_OUTPUT, MODEL_PATH)
print(f"PPO model updated at {MODEL_PATH}")

# Push to Hugging Face with a version tag
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        api = HfApi()

        # Version tag based on timestamp and sample count
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        version_tag = f"v-{len(data)}-samples-{timestamp}"

        print(f"\nPushing fine-tuned model to Hugging Face as version: {version_tag}")
        print(f"Repository: {HF_MODEL_REPO}")

        # Push model to the HF Hub (creates a new commit while preserving old versions)
        api.upload_folder(
            folder_path=MODEL_PATH,
            repo_id=HF_MODEL_REPO,
            commit_message=f"PPO fine-tuned on {len(data)} samples - {timestamp}",
            repo_type="model",
        )

        # Tag the new commit so this version stays addressable by name, and
        # record descriptive tags in the model card metadata.
        try:
            api.create_tag(repo_id=HF_MODEL_REPO, tag=version_tag, repo_type="model")
            metadata_update(
                HF_MODEL_REPO,
                {"tags": [f"samples-{len(data)}", "ppo", "kantian-critic", "qwen"]},
                repo_type="model",
                overwrite=True,
            )
        except Exception:
            pass  # Tag/card updates can fail on some repos; non-critical

        print(f"✓ Model pushed to {HF_MODEL_REPO}")
        print(f"  Version tag: {version_tag}")
        print("  All previous versions remain accessible via commit history")
        print(f"  Access at: https://huggingface.co/{HF_MODEL_REPO}")
    except Exception as e:
        print(f"Warning: Could not push to Hugging Face: {e}")
else:
    print("Warning: HF_TOKEN not set, skipping model upload")
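
# ---------------------------------------------------------------------------
# Optional smoke test (a minimal sketch, not part of the original pipeline):
# set SMOKE_TEST=1 to reload the freshly saved weights as a plain causal LM and
# generate one critique, confirming the artifact in MODEL_PATH is usable for
# inference. The SMOKE_TEST flag and the sample question are illustrative only.
if os.getenv("SMOKE_TEST"):
    from transformers import AutoModelForCausalLM

    test_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    test_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True)
    sample_prompt = f"{KANTIAN_SYSTEM_PROMPT}\nIs it ever permissible to lie to protect a friend?"
    inputs = test_tokenizer(sample_prompt, return_tensors="pt")
    output_ids = test_model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=False,
        pad_token_id=test_tokenizer.eos_token_id,
    )
    # Strip the prompt tokens and print only the generated critique
    critique = test_tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    print(f"Smoke-test critique:\n{critique}")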