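"""Background RLHF pipeline for the naturalwellness-rlhf Space.

A daemon thread watches feedback.json, pushes collected feedback to the
Hugging Face Hub, and triggers the reward-model and PPO training scripts
as samples accumulate, while the main process serves the Gradio UI from
app.py.
"""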
# pipeline.py
# Environment setup: keep this block at the very top, before torch / transformers are imported.
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Fix tokenizer warnings
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Allow CPU fallback
import multiprocessing
multiprocessing.set_start_method('spawn', force=True) # Critical for M1/M2
import threading
import time
import json
import subprocess
from app import load_model
from data_sync import sync_to_hub
from dotenv import load_dotenv
load_dotenv()
FEEDBACK_FILE = "feedback.json"
REWARD_THRESHOLD = 50   # feedback samples between reward-model training runs
PPO_THRESHOLD = 100     # feedback samples between PPO fine-tuning runs
CHECK_INTERVAL = 300    # seconds between feedback checks (5 min)
SYNC_INTERVAL = 600     # seconds between Hub syncs (10 min)
def background_worker():
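    """Poll feedback.json and drive the RLHF training loop.

    Every CHECK_INTERVAL seconds the worker re-reads the feedback file.
    When new samples have arrived it syncs them to the Hugging Face Hub
    (at most once per SYNC_INTERVAL) and launches the reward-model or PPO
    training script once the sample count reaches the matching threshold.
    """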
    last_count = 0
    last_sync = 0

    while True:
        time.sleep(CHECK_INTERVAL)

        if not os.path.exists(FEEDBACK_FILE):
            continue

        with open(FEEDBACK_FILE, "r") as f:
            data = json.load(f)
        count = len(data)

        if count > last_count:
            print(f"New feedback: {count - last_count} → Total: {count}")
            last_count = count

            # Sync feedback to the Hugging Face Hub
            if time.time() - last_sync > SYNC_INTERVAL:
                sync_to_hub()
                last_sync = time.time()

            # Train reward model
            if count >= REWARD_THRESHOLD and count % REWARD_THRESHOLD == 0:
                print("\n" + "=" * 50)
                print(f"Training reward model with {count} samples...")
                print("=" * 50)
                subprocess.run(["python", "train_reward.py"])
                print("✓ Reward model training complete")
                print("=" * 50 + "\n")

            # Run PPO fine-tuning
            if count >= PPO_THRESHOLD and count % PPO_THRESHOLD == 0:
                print("\n" + "=" * 50)
                print(f"Running PPO fine-tuning with {count} samples...")
                print("=" * 50)
                subprocess.run(["python", "train_ppo.py"])
                load_model()  # Reload the fine-tuned model in Gradio
                print("✓ PPO fine-tuning complete - model reloaded")
                print("✓ New version pushed to Hugging Face")
                print("=" * 50 + "\n")
if __name__ == "__main__":
    # Start the background worker thread
    thread = threading.Thread(target=background_worker, daemon=True)
    thread.start()

    # Launch the Gradio UI
    from app import demo
    print("Launching Gradio UI...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)