File size: 2,358 Bytes
6e07610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# pipeline.py


# pipeline.py β€” ADD THIS AT THE VERY TOP
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Fix tokenizer warnings
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"   # Allow CPU fallback
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)  # Critical for M1/M2


import threading
import time
import json
import os
import subprocess
from app import load_model
from data_sync import sync_to_hub
from dotenv import load_dotenv

# Pull environment variables (e.g. Hugging Face credentials) from a local .env file.
load_dotenv()

# presumably a JSON array of feedback records produced by the Gradio UI — verify against app.py
FEEDBACK_FILE = "feedback.json"
REWARD_THRESHOLD = 50  # feedback samples per reward-model training run
PPO_THRESHOLD = 100  # feedback samples per PPO fine-tuning run
CHECK_INTERVAL = 300  # 5 min — how often the worker polls the feedback file
SYNC_INTERVAL = 600  # 10 min — minimum gap between Hub syncs

def background_worker():
    """Poll the feedback file forever, syncing to the Hub and launching
    reward-model / PPO training as feedback accumulates.

    Intended to run as a daemon thread (see the ``__main__`` guard); it
    never returns.
    """
    last_count = 0        # feedback entries seen at the previous poll
    last_sync = 0.0       # timestamp of last Hub sync; 0 forces a sync on the first pass
    last_reward_train = 0  # sample count at the last reward-model training run
    last_ppo_train = 0     # sample count at the last PPO run

    while True:
        time.sleep(CHECK_INTERVAL)

        if not os.path.exists(FEEDBACK_FILE):
            continue

        # The UI process may be mid-write; skip this cycle instead of
        # letting a JSON/IO error kill the daemon thread.
        try:
            with open(FEEDBACK_FILE, "r") as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError):
            continue
        count = len(data)

        if count > last_count:
            print(f"New feedback: {count - last_count} → Total: {count}")
            last_count = count

        # Sync feedback to the Hub at most once per SYNC_INTERVAL.
        if time.time() - last_sync > SYNC_INTERVAL:
            sync_to_hub()
            last_sync = time.time()

        # Train the reward model once REWARD_THRESHOLD new samples have
        # arrived since the last run. (A `count % REWARD_THRESHOLD == 0`
        # check would silently skip training whenever the count jumped
        # past an exact multiple between polls.)
        if count >= REWARD_THRESHOLD and count - last_reward_train >= REWARD_THRESHOLD:
            print("\n" + "=" * 50)
            print(f"Training reward model with {count} samples...")
            print("=" * 50)
            # Argument list instead of a shell string: no shell involved.
            subprocess.run(["python", "train_reward.py"])
            last_reward_train = count
            print("✓ Reward model training complete")
            print("=" * 50 + "\n")

        # Run PPO fine-tuning once PPO_THRESHOLD new samples have arrived
        # since the last run, then hot-reload the model in the Gradio app.
        if count >= PPO_THRESHOLD and count - last_ppo_train >= PPO_THRESHOLD:
            print("\n" + "=" * 50)
            print(f"Running PPO fine-tuning with {count} samples...")
            print("=" * 50)
            subprocess.run(["python", "train_ppo.py"])
            last_ppo_train = count
            load_model()  # Reload fine-tuned model in Gradio
            print("✓ PPO fine-tuning complete - model reloaded")
            print("✓ New version pushed to Hugging Face")
            print("=" * 50 + "\n")

if __name__ == "__main__":
    # Run the feedback/training loop in the background. daemon=True means
    # the worker dies together with the main (Gradio) process.
    worker = threading.Thread(target=background_worker, daemon=True)
    worker.start()

    # Imported here so the Gradio app is constructed in the main thread.
    from app import demo

    print("Launching Gradio UI...")
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)