# train_ppo.py
import json
import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from transformers import AutoTokenizer
from datasets import Dataset
from reward_model_loader import load_reward_pipeline
from huggingface_hub import login, HfApi, metadata_update
import os
import shutil
from datetime import datetime
from dotenv import load_dotenv
load_dotenv()
FEEDBACK_FILE = "feedback.json"
MODEL_PATH = "./current_model"
PPO_OUTPUT = "./ppo_model_temp"
REWARD_PATH = "./reward_model"
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "modular-ai/kantian-critic-qwen")
# BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct" # Smaller 0.5B model
BASE_MODEL = "modular-ai/qwen"
# Kantian System Prompt for PPO training
KANTIAN_SYSTEM_PROMPT = """You are Kantian - an ADVERSARIAL CRITIC whose job is to challenge and test arguments.
ADVERSARIAL MODE:
1. Challenge the document's arguments systematically.
2. Be critically rigorous - identify flaws and weaknesses.
3. Quote exact text when making critiques.
4. Attack logical fallacies and poor reasoning directly.
5. Your goal: Test arguments through adversarial analysis, not validate them.
Apply Kantian framework: universalizability, human dignity, moral duty over consequences.
"""
# Load data
if not os.path.exists(FEEDBACK_FILE):
    print("No data.")
    exit()
with open(FEEDBACK_FILE, "r") as f:
    data = json.load(f)
if len(data) < 100:
    print(f"Need 100+ samples. Current: {len(data)}")
    exit()
# Use recent prompts (Kantian critique contexts)
# Include text feedback for better training signal
prompts_data = data[-64:]  # most recent 64 samples = 8 PPO batches of 8
prompts = []
for d in prompts_data:
    # Extract just the user question part if it exists
    prompt_text = d["prompt"]
    text_feedback = d.get("text_feedback", "")
    if "Question:" in prompt_text:
        # Extract the question part for Kantian critique generation
        question = prompt_text.split("Question:")[-1].strip()
        # Prepend Kantian context and feedback if available
        if text_feedback:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n\nFeedback Context: {text_feedback}\n\n{question}")
        else:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n{question}")
    else:
        if text_feedback:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n\nFeedback Context: {text_feedback}\n\n{prompt_text}")
        else:
            prompts.append(f"{KANTIAN_SYSTEM_PROMPT}\n{prompt_text}")
dataset = Dataset.from_dict({"prompt": prompts})
# Load reward model
reward_pipe = load_reward_pipeline(REWARD_PATH)
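# Assumed contract for the reward pipeline (defined in reward_model_loader):
# a transformers text-classification pipeline over the trained feedback
# classifier that returns scores for *all* labels per input, with LABEL_1 =
# positive feedback and LABEL_0 = negative feedback. The reward loop below
# depends on that output shape. A minimal sketch of such a loader, purely
# illustrative:
#
#     from transformers import pipeline
#     def load_reward_pipeline(path):
#         return pipeline("text-classification", model=path, top_k=None)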
# Load base model
base_model_path = MODEL_PATH if os.path.exists(MODEL_PATH) else BASE_MODEL
print(f"Loading base model: {base_model_path}")
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLMWithValueHead.from_pretrained(base_model_path, trust_remote_code=True)
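# create_reference_model makes a frozen copy of the policy; PPO measures the
# KL divergence against it so updates stay close to the starting model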
ref_model = create_reference_model(model)
config = PPOConfig(
    model_name=base_model_path,
    learning_rate=1.41e-5,
    batch_size=8,
    mini_batch_size=4,
    gradient_accumulation_steps=1,
    ppo_epochs=3,
)
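# PPOTrainer's dataloader yields whatever columns the dataset carries, and the
# training loop below indexes batch["input_ids"], so the prompts are tokenized
# here first. The `tokenize` and `collator` helpers follow the usual TRL 0.x
# pattern (a plain list-per-column collator, since query tensors differ in
# length); both names are local additions for this script.
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["prompt"])
    return sample

dataset = dataset.map(tokenize)
dataset.set_format(type="torch")

def collator(data):
    # Keep every column as a Python list; PPOTrainer expects lists of tensors
    return {key: [d[key] for d in data] for key in data[0]}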
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)
generation_kwargs = {
    "max_new_tokens": 100,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "pad_token_id": tokenizer.eos_token_id,
}
print("Starting PPO training...")
for batch in ppo_trainer.dataloader:
    query_tensors = batch["input_ids"]
    # return_prompt=False keeps only the generated continuation, so queries and
    # responses reach the PPO step separately (assumes a TRL 0.x release with
    # batched PPOTrainer.generate)
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]
    # Compute rewards: score each prompt/response pair with the feedback classifier
    texts = [f"Prompt: {p} Response: {r}" for p, r in zip(batch["prompt"], responses)]
    pipe_outputs = reward_pipe(texts)
    rewards = []
    for out in pipe_outputs:
        # reward = P(positive feedback) - P(negative feedback)
        pos_score = next((s["score"] for s in out if s["label"] == "LABEL_1"), 0.0)
        neg_score = next((s["score"] for s in out if s["label"] == "LABEL_0"), 0.0)
        reward = pos_score - neg_score
        rewards.append(torch.tensor(reward))
    ppo_trainer.step(query_tensors, response_tensors, rewards)
# TRL's classic PPOTrainer exposes save_pretrained (not save_model)
ppo_trainer.save_pretrained(PPO_OUTPUT)
tokenizer.save_pretrained(PPO_OUTPUT)
if os.path.exists(MODEL_PATH):
    shutil.rmtree(MODEL_PATH)
os.rename(PPO_OUTPUT, MODEL_PATH)
print(f"PPO model updated at {MODEL_PATH}")
# Push to Hugging Face with version tag
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        api = HfApi()
        # Create version tag based on timestamp and sample count
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        version_tag = f"v-{len(data)}-samples-{timestamp}"
        print(f"\nPushing fine-tuned model to Hugging Face as version: {version_tag}")
        print(f"Repository: {HF_MODEL_REPO}")
        # Push model to HF Hub (creates new commit while preserving old versions)
        api.upload_folder(
            folder_path=MODEL_PATH,
            repo_id=HF_MODEL_REPO,
            commit_message=f"PPO fine-tuned on {len(data)} samples - {timestamp}",
            repo_type="model",
        )
        # Add tags for versioning: a git tag for this commit plus model-card keyword tags
        try:
            api.create_tag(repo_id=HF_MODEL_REPO, tag=version_tag, repo_type="model")
            metadata_update(
                repo_id=HF_MODEL_REPO,
                metadata={"tags": [f"samples-{len(data)}", "ppo", "kantian-critic", "qwen"]},
                repo_type="model",
                overwrite=True,
            )
        except Exception:
            pass  # Tag updates may fail on some repos, non-critical
print(f"✓ Model pushed to {HF_MODEL_REPO}")
print(f" Version tag: {version_tag}")
print(f" All previous versions remain accessible via commit history")
print(f" Access at: https://huggingface.co/{HF_MODEL_REPO}")
except Exception as e:
print(f"Warning: Could not push to Hugging Face: {e}")
else:
print("Warning: HF_TOKEN not set, skipping model upload")