naturalwellness-rlhf / reward_model_loader.py
tarnava's picture
Upload folder using huggingface_hub
6e07610 verified
raw
history blame contribute delete
450 Bytes
# reward_model_loader.py
from transformers import pipeline
import torch
def _select_device():
    """Return the preferred pipeline device: CUDA 0, then "mps", then CPU (-1)."""
    if torch.cuda.is_available():
        return 0
    # ``torch.backends.mps`` only exists in torch >= 1.12; look it up
    # defensively so older installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return -1  # CPU


def load_reward_pipeline(model_path="./reward_model", *, device=None):
    """Build a text-classification pipeline for the reward model.

    Args:
        model_path: Model directory or identifier, passed to
            ``transformers.pipeline`` as ``model``.
        device: Optional explicit device for the pipeline. When ``None``
            (the default) it is chosen automatically: CUDA device 0 if
            available, otherwise ``"mps"`` if available, otherwise CPU (-1).

    Returns:
        A ``transformers`` text-classification pipeline created with
        ``return_all_scores=True``, so predictions carry a score for every
        label rather than only the top one.
    """
    if device is None:
        device = _select_device()
    # NOTE(review): ``return_all_scores`` is deprecated in recent transformers
    # releases in favour of ``top_k=None``. It is kept because the two options
    # nest single-input results differently; confirm what callers expect
    # before migrating.
    return pipeline(
        "text-classification",
        model=model_path,
        return_all_scores=True,
        device=device,
    )