# Safe_Phi4_Full2 / handler.py
# Import Unsloth before transformers so its performance patches are applied.
from unsloth import FastLanguageModel  # use FastVisionModel for vision/multimodal models
from transformers import pipeline
MODEL_NAME = "unsloth/Phi-4-unsloth-bnb-4bit"  # Base model the adapter was trained on (kept for reference)
model_id = "Machlovi/Safe_Phi4"  # LoRA fine-tuned adapter; Unsloth resolves the base model from its config
max_seq_length = 2048  # Unsloth handles RoPE scaling internally, so longer contexts also work
load_in_4bit = True  # Load weights in 4-bit to keep memory usage low
def load_model():
    """Load the LoRA adapter (and its 4-bit base model) using Unsloth."""
    print("Loading base model with Unsloth...")
    # Unsloth loads the quantized base model and applies the adapter weights.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
    )
    # Switch the model into Unsloth's optimized inference mode.
    FastLanguageModel.for_inference(model)

    print("Creating text generation pipeline...")
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
# Load model globally so it doesn't reload on every request
pipe = load_model()
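
# Phi-4 is an instruction-tuned chat model, so raw strings may generate better
# results when wrapped in the model's chat template. This optional helper is
# not part of the original handler; it is a minimal sketch that reuses the
# tokenizer bundled with the pipeline, and the single-turn message shape is
# an assumption about how callers format requests.
def format_chat_prompt(user_message: str) -> str:
    """Wrap a plain user message in the model's chat template."""
    messages = [{"role": "user", "content": user_message}]
    return pipe.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )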
def infer(prompt: str, max_new_tokens: int = 128) -> str:
    """Generate text using the Unsloth LoRA-adapted model."""
    return pipe(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
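
# If this file is deployed as a custom handler on Hugging Face Inference
# Endpoints, the service looks for an EndpointHandler class whose __call__
# receives the request payload as a dict. The sketch below is a minimal,
# hedged adaptation of the functions above to that contract; the payload
# shape ({"inputs": ..., "parameters": ...}) follows the Endpoints
# convention, and the default max_new_tokens value is an assumption.
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Reuse the globally loaded pipeline rather than loading the model twice.
        self.pipe = pipe

    def __call__(self, data: dict) -> list:
        prompt = data["inputs"]
        params = data.get("parameters", {})  # e.g. {"max_new_tokens": 256}
        max_new_tokens = params.get("max_new_tokens", 128)
        return [{"generated_text": infer(prompt, max_new_tokens=max_new_tokens)}]


if __name__ == "__main__":
    # Quick local smoke test; the prompt here is only an illustrative example.
    print(infer("Hello, how are you?"))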