# Scropo's picture
# Update app.py
# e374a91 verified
import gradio as gr
import subprocess
import os
import threading
import time
# --- Configuration ---
# IMPORTANT: Change this to your Hugging Face username
# (it is the namespace prefix of the Hub repository id built below).
YOUR_HF_USERNAME = "Scropo"
# --------------------
# Hub repository id handed to the training script through --hub_model_id.
HUB_MODEL_ID = f"{YOUR_HF_USERNAME}/gpt-oss-20b-mentalchat-finetuned"
# Local directory handed to the training script through --output_dir.
OUTPUT_DIR = "./gpt-oss-20b-mentalchat-finetuned"
def get_training_command():
    """Assemble the argv list that launches fine-tuning via ``accelerate launch``.

    The Hub token is read from the ``HF_TOKEN`` environment variable and is
    embedded in the returned list as the value of ``--hub_token``.

    Returns:
        list[str]: ``["accelerate", "launch", "sft.py", <flag>, <value>, ...]``
        with every flag immediately followed by its value.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("HF_TOKEN secret is not set in the Space settings!")

    # Same invocation you would type in a terminal: launcher + script first,
    # then flag/value pairs in a fixed order (dicts preserve insertion order).
    launcher = ["accelerate", "launch", "sft.py"]
    options = {
        "--model_name_or_path": "openai/gpt-oss-20b",
        "--dataset_name": "ShenLab/MentalChat16K",
        "--output_dir": OUTPUT_DIR,
        "--max_seq_length": "2048",
        "--num_train_epochs": "1",
        "--per_device_train_batch_size": "1",
        "--gradient_accumulation_steps": "8",  # Adjust based on GPU memory
        "--learning_rate": "2e-5",
        "--logging_steps": "5",
        "--push_to_hub": "true",
        "--hub_model_id": HUB_MODEL_ID,
        # NOTE: the secret ends up in the returned argv; avoid echoing the
        # list verbatim anywhere user-visible.
        "--hub_token": token,
        "--peft_lora_r": "64",
        "--peft_lora_alpha": "16",
        "--bf16": "true",  # Use bf16 for A100/H100
        "--gradient_checkpointing": "true",
    }

    argv = list(launcher)
    for flag, value in options.items():
        argv.extend((flag, value))
    return argv
# Flags whose following argument is a secret and must never be displayed.
_SECRET_FLAGS = frozenset({"--hub_token"})

# Upper bound on the number of log chunks kept for display, so a long run does
# not push an ever-growing string to the browser on every update.
_MAX_LOG_LINES = 2000


def _redact_command(command):
    """Return a copy of *command* with the value after each secret flag masked."""
    shown = []
    mask_next = False
    for arg in command:
        shown.append("****" if mask_next else arg)
        mask_next = arg in _SECRET_FLAGS
    return shown


def run_training():
    """Run the training command and stream its log to the UI.

    Generator intended as a Gradio event handler. Every yielded value
    *replaces* the content of the output component, so each yield carries the
    accumulated log (trimmed to the most recent ``_MAX_LOG_LINES`` chunks)
    rather than only the newest line.

    Yields:
        str: The log text to display so far.

    Raises:
        ValueError: Propagated from ``get_training_command`` when ``HF_TOKEN``
            is not configured.
    """
    command = get_training_command()

    log_lines = ["🚀 Starting training job...\n"]
    yield "".join(log_lines)

    # Show the command with the Hub token masked; the real argv (including the
    # token) is still what gets executed below.
    log_lines.append(f"Running command: {' '.join(_redact_command(command))}\n\n")
    yield "".join(log_lines)

    # Use Popen to start the process and capture output in real time.
    # stderr is merged into stdout so errors appear in order in the same log;
    # bufsize=1 requests line buffering on the text-mode pipe.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )

    # Stream the output as it is produced.
    for line in process.stdout:
        log_lines.append(line)
        if len(log_lines) > _MAX_LOG_LINES:
            # Keep only the tail; the oldest entries are dropped.
            del log_lines[:-_MAX_LOG_LINES]
        yield "".join(log_lines)

    process.wait()
    if process.returncode == 0:
        log_lines.append(
            f"\n🎉 Training finished successfully! Model pushed to {HUB_MODEL_ID}"
        )
    else:
        log_lines.append(
            f"\n❌ Training failed with exit code {process.returncode}."
        )
    yield "".join(log_lines)
# --- Gradio Interface ---
# One page: a description, a start button, and a read-only log box that the
# run_training generator fills while the job runs.
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 GPT-20B Supervised Fine-Tuning")
    gr.Markdown(
        "**Model:** `openai/gpt-oss-20b`\n\n"
        "**Dataset:** `ShenLab/MentalChat16K`\n\n"
        "Click the button below to launch the fine-tuning process. "
        "This will take a very long time. Monitor the logs for progress."
    )
    start_button = gr.Button("Start Fine-Tuning Job")
    log_output = gr.Textbox(
        label="Training Logs",
        lines=30,
        interactive=False,  # display only; users cannot edit the log
        autoscroll=True,
    )
    # The handler takes no inputs; whatever it yields is written to the log box.
    start_button.click(
        fn=run_training,
        inputs=[],
        outputs=[log_output],
    )

# Enable the event queue before launching so the generator handler's
# intermediate yields can be streamed to the page.
# NOTE(review): explicit queue() is needed on older Gradio releases — confirm
# for the version pinned by this Space.
demo.queue().launch()