|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

# These variables are exported before the project modules (`app`, `data_sync`)
# are imported further down, so any library those modules pull in sees them.
# TOKENIZERS_PARALLELISM=false presumably quiets/disables HF tokenizers' own
# thread pool; PYTORCH_ENABLE_MPS_FALLBACK=1 presumably lets PyTorch fall back
# to CPU for ops unsupported on Apple MPS -- confirm against library docs.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
    }
)

import multiprocessing

# Child processes are started with "spawn"; force=True overrides any start
# method that was already configured for this interpreter.
multiprocessing.set_start_method('spawn', force=True)
|
|
|
|
|
|
|
|
import threading |
|
|
import time |
|
|
import json |
|
|
import os |
|
|
import subprocess |
|
|
from app import load_model |
|
|
from data_sync import sync_to_hub |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()

# JSON file whose top-level entries are counted as feedback samples.
FEEDBACK_FILE = "feedback.json"

# Sample-count step at which reward-model training is triggered.
REWARD_THRESHOLD = 50
# Sample-count step at which PPO fine-tuning (plus model reload) is triggered.
PPO_THRESHOLD = 100

# Seconds the background worker sleeps between polls of FEEDBACK_FILE.
CHECK_INTERVAL = 5 * 60
# Minimum number of seconds between two sync_to_hub() calls.
SYNC_INTERVAL = 10 * 60
|
|
|
|
|
def _crossed_multiple(previous, current, threshold):
    """Return True if a positive multiple of *threshold* lies in ``(previous, current]``.

    Used instead of ``current % threshold == 0`` so that a milestone is not
    skipped when several samples arrive between two polls, and is not fired
    again on every poll while the count stays on a multiple.
    *previous* is expected to be >= 0.
    """
    return current // threshold > previous // threshold


def _read_feedback_count():
    """Return the number of top-level entries in ``FEEDBACK_FILE``, or None.

    None means "nothing usable this tick": the file does not exist yet, or it
    could not be read/parsed (for example a partially written file). The
    caller simply tries again on the next poll instead of crashing.
    """
    try:
        with open(FEEDBACK_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"Could not read {FEEDBACK_FILE}: {exc}")
        return None
    return len(data)


def _run_training_script(script):
    """Run *script* with the current Python interpreter.

    Returns True if the script exited with status 0, False otherwise.
    """
    import sys

    # Argument list instead of a shell string, and sys.executable instead of
    # whatever `python` is on PATH, so the script runs in this environment.
    result = subprocess.run([sys.executable, script], check=False)
    if result.returncode != 0:
        print(f"{script} exited with code {result.returncode}")
        return False
    return True


def background_worker():
    """Poll the feedback file forever, syncing data and retraining models.

    Every ``CHECK_INTERVAL`` seconds:

    * count the entries in ``FEEDBACK_FILE`` (skip the tick if it is missing
      or unreadable) and report newly arrived feedback;
    * call ``sync_to_hub()`` if more than ``SYNC_INTERVAL`` seconds have
      passed since the last successful sync;
    * run ``train_reward.py`` whenever the count crosses a multiple of
      ``REWARD_THRESHOLD``;
    * run ``train_ppo.py`` and then ``load_model()`` whenever the count
      crosses a multiple of ``PPO_THRESHOLD``.

    Never returns. Failures in syncing, training, or reloading are reported
    and the loop continues, so the daemon thread running this stays alive.
    """
    import traceback

    last_count = 0
    last_sync = 0
    first_poll = True

    while True:
        time.sleep(CHECK_INTERVAL)

        count = _read_feedback_count()
        if count is None:
            continue

        # Count at the previous poll, used to detect threshold crossings.
        # On the first poll pretend it was count - 1, so startup keeps the
        # original rule: train only if the count sits exactly on a multiple.
        previous = max(count - 1, 0) if first_poll else last_count
        first_poll = False

        if count > last_count:
            print(f"New feedback: {count - last_count} β Total: {count}")
        # Track the count even if it shrank (e.g. the file was reset), so
        # later growth is measured from the current size.
        last_count = count

        if time.time() - last_sync > SYNC_INTERVAL:
            try:
                sync_to_hub()
            except Exception:
                # A transient sync failure must not kill this thread; leave
                # last_sync unchanged so the sync is retried next poll.
                traceback.print_exc()
            else:
                last_sync = time.time()

        if _crossed_multiple(previous, count, REWARD_THRESHOLD):
            print("\n" + "=" * 50)
            print(f"Training reward model with {count} samples...")
            print("=" * 50)
            if _run_training_script("train_reward.py"):
                print("β Reward model training complete")
            else:
                print("Reward model training failed")
            print("=" * 50 + "\n")

        if _crossed_multiple(previous, count, PPO_THRESHOLD):
            print("\n" + "=" * 50)
            print(f"Running PPO fine-tuning with {count} samples...")
            print("=" * 50)
            if _run_training_script("train_ppo.py"):
                try:
                    load_model()
                except Exception:
                    traceback.print_exc()
                    print("PPO fine-tuning finished but model reload failed")
                else:
                    print("β PPO fine-tuning complete - model reloaded")
                    # NOTE(review): nothing here pushes to the Hub; presumably
                    # train_ppo.py does -- confirm this message is accurate.
                    print("β New version pushed to Hugging Face")
            else:
                print("PPO fine-tuning failed - keeping current model")
            print("=" * 50 + "\n")
|
|
|
|
|
if __name__ == "__main__":
    # Start the feedback poller next to the UI. As a daemon thread it does
    # not keep the process alive once the main thread (the UI) exits.
    poller = threading.Thread(target=background_worker, daemon=True)
    poller.start()

    # `app` is already imported at the top of the file (for load_model);
    # this only fetches the Gradio interface object from it.
    from app import demo

    print("Launching Gradio UI...")
    # Bind on all interfaces, port 7860, with share=True.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)