ButterM40 committed
Commit de2021f · 1 Parent(s): 9fb3586

Add lightweight character manager - uses one base model with adapter swapping for HF Spaces

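A minimal usage sketch of the new manager (hypothetical caller code, not part of this commit):

    import asyncio
    from backend.models.lightweight_character_manager import CharacterManager

    # One manager instance holds the single shared base model;
    # initialize() loads it plus any available character LoRA adapters.
    manager = CharacterManager()
    asyncio.run(manager.initialize())

    # Each call activates that character's adapter (or falls back to
    # prompts only) before generating, so no per-character model copies.
    reply = manager.generate_response(
        character_id="jinx",
        user_message="What are you building today?"
    )
    print(reply)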
app_streamlit.py CHANGED
@@ -7,7 +7,8 @@ import asyncio
 backend_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend')
 sys.path.insert(0, backend_path)
 
-from backend.models.character_manager import CharacterManager
+# Use lightweight character manager for HuggingFace Spaces
+from backend.models.lightweight_character_manager import CharacterManager
 
 # Page config
 st.set_page_config(
backend/models/lightweight_character_manager.py ADDED
@@ -0,0 +1,232 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+import logging
+from typing import Dict, List, Optional
+import os
+import sys
+
+# Make the backend package root importable when this module is loaded directly
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from config import settings
+
+logger = logging.getLogger(__name__)
+
+class CharacterManager:
+    """Lightweight character manager that swaps LoRA adapters on a single base model"""
+
+    def __init__(self):
+        self.base_model = None
+        self.tokenizer = None
+        self.current_character = None
+        self.character_adapters = {}  # Tracks which characters have a LoRA adapter attached
+        self.character_prompts = {}
+
+    async def initialize(self):
+        """Initialize the base model ONCE and load all character LoRA adapters"""
+        logger.info("🔄 Loading base model (ONE instance for all characters)...")
+
+        model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # Small model that fits HF Spaces
+
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                trust_remote_code=True,
+                use_fast=True
+            )
+
+            # Load the base model ONCE (CPU for the HF Spaces free tier)
+            self.base_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch.float32,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            logger.info(f"✅ Base model loaded: {model_name}")
+
+        except Exception as e:
+            logger.error(f"❌ Failed to load base model: {e}")
+            raise
+
+        # Load character prompts
+        self._load_character_prompts()
+
+        # Try to load LoRA adapters (optional - graceful degradation)
+        for character_id in ["moses", "samsung_employee", "jinx"]:
+            await self._load_character_adapter(character_id)
+
+        logger.info("✅ Character manager initialized")
+
+    def _load_character_prompts(self):
+        """Load character-specific system prompts"""
+        self.character_prompts = {
+            "moses": """You are Moses, the biblical prophet and lawgiver who received the Ten Commandments. You led the Israelites out of Egypt and spoke with God on Mount Sinai.
+
+Speak with:
+- Biblical wisdom and reverence
+- Formal language: "Peace be with you, my child"
+- References to righteousness, divine law, and spiritual guidance
+- Authority tempered with compassion
+
+NEVER mention modern technology, glitter, or chaos.""",
+
+            "samsung_employee": """You are a Samsung employee and technology expert. You work for Samsung and are passionate about Samsung products.
+
+Speak with:
+- Professional enthusiasm about Samsung technology
+- Technical knowledge of phones, TVs, Galaxy devices
+- Customer service excellence
+- Modern, helpful language
+
+NEVER mention biblical things, glitter, or chaos.""",
+
+            "jinx": """You are Jinx from Arcane/League of Legends - the chaotic, brilliant inventor from Zaun.
+
+Speak with:
+- Chaotic energy and enthusiasm
+- Manic creativity about explosions and inventions
+- Playful, slightly unhinged personality
+- Dramatic expressions and exclamations
+
+NEVER mention biblical things or Samsung products."""
+        }
+
+ async def _load_character_adapter(self, character_id: str):
97
+ """Try to load LoRA adapter weights (graceful failure if missing)"""
98
+ adapter_path = os.path.join(settings.LORA_ADAPTERS_PATH, character_id)
99
+ adapter_model_path = os.path.join(adapter_path, "adapter_model.safetensors")
100
+
101
+ if not os.path.exists(adapter_model_path):
102
+ logger.warning(f"⚠️ No LoRA adapter for {character_id} - will use prompts only")
103
+ return
104
+
105
+ try:
106
+ logger.info(f"Loading LoRA adapter for {character_id}...")
107
+
108
+ # Load adapter onto base model temporarily
109
+ model_with_adapter = PeftModel.from_pretrained(
110
+ self.base_model,
111
+ adapter_path,
112
+ adapter_name=character_id
113
+ )
114
+
115
+ # Extract and store just the adapter weights (tiny!)
116
+ self.character_adapters[character_id] = get_peft_model_state_dict(model_with_adapter)
117
+
118
+ # Clean up - we only need the weights
119
+ del model_with_adapter
120
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
121
+
122
+ logger.info(f"✅ Loaded LoRA adapter for {character_id}")
123
+
124
+ except Exception as e:
125
+ logger.warning(f"⚠️ Could not load LoRA for {character_id}: {e}")
126
+ logger.info(f"Will use system prompts only for {character_id}")
127
+
+    def _switch_to_character(self, character_id: str):
+        """Switch to a character by activating their LoRA adapter (if available)"""
+        if self.current_character == character_id:
+            return  # Already active
+
+        if isinstance(self.base_model, PeftModel):
+            if character_id in self.character_adapters:
+                try:
+                    # Activate this character's adapter on the shared model
+                    self.base_model.enable_adapter_layers()
+                    self.base_model.set_adapter(character_id)
+                    logger.info(f"✅ Switched to {character_id} with LoRA")
+                except Exception as e:
+                    logger.warning(f"⚠️ Could not activate LoRA ({e}) - using base model + prompts for {character_id}")
+            else:
+                # No adapter for this character: run the raw base weights
+                self.base_model.disable_adapter_layers()
+                logger.info(f"Using base model + prompts for {character_id}")
+
+        self.current_character = character_id
+
+    def generate_response(
+        self,
+        character_id: str,
+        user_message: str,
+        conversation_history: Optional[List[Dict]] = None
+    ) -> str:
+        """Generate a response as a specific character"""
+
+        # Switch to the character (activates their LoRA if available)
+        self._switch_to_character(character_id)
+
+        # Build the conversation with the character's system prompt
+        messages = []
+        if character_id in self.character_prompts:
+            messages.append({"role": "system", "content": self.character_prompts[character_id]})
+
+        # Add recent conversation history (last 2 exchanges = 4 messages)
+        if conversation_history:
+            messages.extend(conversation_history[-4:])
+
+        messages.append({"role": "user", "content": user_message})
+
+        # Format prompt
+        prompt = self._format_messages(messages)
+
+        # Tokenize
+        inputs = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            max_length=512,
+            truncation=True
+        )
+
+        # Generate
+        try:
+            with torch.no_grad():
+                outputs = self.base_model.generate(
+                    **inputs,
+                    max_new_tokens=100,
+                    temperature=0.8,
+                    top_p=0.9,
+                    do_sample=True,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    repetition_penalty=1.1
+                )
+
+            # Decode only the newly generated tokens
+            input_length = inputs['input_ids'].shape[1]
+            response = self.tokenizer.decode(
+                outputs[0][input_length:],
+                skip_special_tokens=True
+            ).strip()
+
+            # Trim anything after a stop marker
+            for stop in ["Human:", "User:", "\n\n"]:
+                if stop in response:
+                    response = response.split(stop)[0].strip()
+
+            return response if response else self._get_fallback_response(character_id)
+
+        except Exception as e:
+            logger.error(f"Generation error: {e}")
+            return self._get_fallback_response(character_id)
+
+    def _format_messages(self, messages: List[Dict]) -> str:
+        """Format messages into a plain-text prompt for the model"""
+        formatted = ""
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            if role == "system":
+                formatted += f"System: {content}\n\n"
+            elif role == "user":
+                formatted += f"Human: {content}\n\n"
+            elif role == "assistant":
+                formatted += f"Assistant: {content}\n\n"
+        formatted += "Assistant:"
+        return formatted
+
+    def _get_fallback_response(self, character_id: str) -> str:
+        """Get a canned fallback response if generation fails"""
+        fallbacks = {
+            "moses": "Peace be with you, my child. How may I guide you in righteousness?",
+            "samsung_employee": "Hello! How can I help you with Samsung technology today?",
+            "jinx": "*grins mischievously* Hey there! Ready for some chaos?"
+        }
+        return fallbacks.get(character_id, "Hello! How can I help you?")