Chris4K committed · Commit c60cf73 · verified · 1 Parent(s): a7a26f6

Update llm_engine.py

Files changed (1): llm_engine.py (+488, -473)
llm_engine.py CHANGED
@@ -1,474 +1,489 @@
-# llmEngine.py
-# IMPROVED: Multi-provider LLM engine with CACHING to prevent reloading
-# This version fixes the critical issue where LocalLLM was reloading on every call
-# Features:
-# - Provider caching (models stay in memory)
-# - Unified OpenAI-style chat() API
-# - Providers: OpenAI, Anthropic, HuggingFace, Nebius, SambaNova, Local (transformers)
-# - Automatic fallback to local model on errors
-# - JSON-based credit tracking
-
-import json
-import os
-import traceback
-from typing import List, Dict, Optional
-
-###########################################################
-# SIMPLE JSON CREDIT STORE
-###########################################################
-CREDITS_DB_PATH = "credits.json"
-
-DEFAULT_CREDITS = {
-    "openai": 25,
-    "anthropic": 25000,
-    "huggingface": 25,
-    "nebius": 50,
-    "modal": 250,
-    "blaxel": 250,
-    "elevenlabs": 44,
-    "sambanova": 25,
-    "local": 9999999
-}
-
-
-def load_credits():
-    if not os.path.exists(CREDITS_DB_PATH):
-        with open(CREDITS_DB_PATH, "w") as f:
-            json.dump(DEFAULT_CREDITS, f)
-        return DEFAULT_CREDITS.copy()
-    with open(CREDITS_DB_PATH, "r") as f:
-        return json.load(f)
-
-
-def save_credits(data):
-    with open(CREDITS_DB_PATH, "w") as f:
-        json.dump(data, f, indent=2)
-
-###########################################################
-# BASE PROVIDER INTERFACE
-###########################################################
-class BaseProvider:
-    def chat(self, model: str, messages: List[Dict], **kwargs) -> str:
-        raise NotImplementedError
-
-###########################################################
-# PROVIDER: OPENAI
-###########################################################
-try:
-    from openai import OpenAI
-except Exception:
-    OpenAI = None
-
-class OpenAIProvider(BaseProvider):
-    def __init__(self):
-        if OpenAI is None:
-            raise RuntimeError("openai library not installed or not importable")
-        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))
-
-    def chat(self, model, messages, **kwargs):
-        try:
-            from openai.types.chat import (
-                ChatCompletionUserMessageParam,
-                ChatCompletionAssistantMessageParam,
-                ChatCompletionSystemMessageParam,
-            )
-        except Exception:
-            ChatCompletionUserMessageParam = dict
-            ChatCompletionAssistantMessageParam = dict
-            ChatCompletionSystemMessageParam = dict
-
-        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
-            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
-
-        safe_messages = []
-        for m in messages:
-            role = str(m.get("role", "user"))
-            content = str(m.get("content", ""))
-            if role == "user":
-                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
-            elif role == "assistant":
-                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
-            elif role == "system":
-                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
-            else:
-                safe_messages.append({"role": role, "content": content})
-
-        response = self.client.chat.completions.create(model=model, messages=safe_messages)
-        try:
-            return response.choices[0].message.content
-        except Exception:
-            return str(response)
-
-###########################################################
-# PROVIDER: ANTHROPIC
-###########################################################
-try:
-    from anthropic import Anthropic
-except Exception:
-    Anthropic = None
-
-class AnthropicProvider(BaseProvider):
-    def __init__(self):
-        if Anthropic is None:
-            raise RuntimeError("anthropic library not installed or not importable")
-        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY", ""))
-
-    def chat(self, model, messages, **kwargs):
-        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
-            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
-
-        user_text = "\n".join([m.get("content", "") for m in messages if m.get("role") == "user"])
-        reply = self.client.messages.create(
-            model=model,
-            max_tokens=300,
-            messages=[{"role": "user", "content": user_text}]
-        )
-
-        if hasattr(reply, "content"):
-            content = reply.content
-            if isinstance(content, list) and content and len(content) > 0:
-                block = content[0]
-                if hasattr(block, "text"):
-                    return getattr(block, "text", str(block))
-                elif isinstance(block, dict) and "text" in block:
-                    return block["text"]
-                else:
-                    return str(block)
-            elif isinstance(content, str):
-                return content
-
-        if isinstance(reply, dict) and "completion" in reply:
-            return reply["completion"]
-        return str(reply)
-
-###########################################################
-# PROVIDER: HUGGINGFACE INFERENCE API
-###########################################################
-import requests
-
-class HuggingFaceProvider(BaseProvider):
-    def __init__(self):
-        self.key = os.getenv("HF_API_KEY", "")
-
-    def chat(self, model, messages, **kwargs):
-        if not messages:
-            raise ValueError("messages is empty")
-        text = messages[-1].get("content", "")
-        r = requests.post(
-            f"https://api-inference.huggingface.co/models/{model}",
-            headers={"Authorization": f"Bearer {self.key}"} if self.key else {},
-            json={"inputs": text},
-            timeout=60
-        )
-        r.raise_for_status()
-        out = r.json()
-        if isinstance(out, list) and out and isinstance(out[0], dict):
-            return out[0].get("generated_text") or str(out[0])
-        return str(out)
-
-###########################################################
-# PROVIDER: NEBIUS (OpenAI-compatible)
-###########################################################
-class NebiusProvider(BaseProvider):
-    def __init__(self):
-        if OpenAI is None:
-            raise RuntimeError("openai library not installed; Nebius wrapper expects OpenAI-compatible client")
-        self.client = OpenAI(
-            api_key=os.getenv("NEBIUS_API_KEY", ""),
-            base_url=os.getenv("NEBIUS_BASE_URL", "https://api.studio.nebius.ai/v1")
-        )
-
-    def chat(self, model, messages, **kwargs):
-        try:
-            from openai.types.chat import (
-                ChatCompletionUserMessageParam,
-                ChatCompletionAssistantMessageParam,
-                ChatCompletionSystemMessageParam,
-            )
-        except Exception:
-            ChatCompletionUserMessageParam = dict
-            ChatCompletionAssistantMessageParam = dict
-            ChatCompletionSystemMessageParam = dict
-
-        safe_messages = []
-        for m in messages:
-            role = str(m.get("role", "user"))
-            content = str(m.get("content", ""))
-            if role == "user":
-                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
-            elif role == "assistant":
-                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
-            elif role == "system":
-                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
-            else:
-                safe_messages.append({"role": role, "content": content})
-
-        r = self.client.chat.completions.create(model=model, messages=safe_messages)
-        try:
-            return r.choices[0].message.content
-        except Exception:
-            return str(r)
-
-###########################################################
-# PROVIDER: SAMBANOVA (OpenAI-compatible)
-###########################################################
-class SambaNovaProvider(BaseProvider):
-    def __init__(self):
-        if OpenAI is None:
-            raise RuntimeError("openai library not installed; SambaNova wrapper expects OpenAI-compatible client")
-        self.client = OpenAI(
-            api_key=os.getenv("SAMBANOVA_API_KEY", ""),
-            base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
-        )
-
-    def chat(self, model, messages, **kwargs):
-        try:
-            from openai.types.chat import (
-                ChatCompletionUserMessageParam,
-                ChatCompletionAssistantMessageParam,
-                ChatCompletionSystemMessageParam,
-            )
-        except Exception:
-            ChatCompletionUserMessageParam = dict
-            ChatCompletionAssistantMessageParam = dict
-            ChatCompletionSystemMessageParam = dict

-        safe_messages = []
-        for m in messages:
-            role = str(m.get("role", "user"))
-            content = str(m.get("content", ""))
-            if role == "user":
-                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
-            elif role == "assistant":
-                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
-            elif role == "system":
-                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
-            else:
-                safe_messages.append({"role": role, "content": content})
-
-        r = self.client.chat.completions.create(model=model, messages=safe_messages)
-        try:
-            return r.choices[0].message.content
-        except Exception:
-            return str(r)
-
-###########################################################
-# PROVIDER: LOCAL TRANSFORMERS (CACHED)
-###########################################################
-try:
-    from transformers import AutoTokenizer, AutoModelForCausalLM
-    import torch
-    TRANSFORMERS_AVAILABLE = True
-except Exception:
-    TRANSFORMERS_AVAILABLE = False
-
-class LocalLLMProvider(BaseProvider):
-    """
-    Local LLM provider with caching - MODEL LOADS ONCE
-    """
-    def __init__(self, model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
-        print(f"[LocalLLM] Initializing with model: {model_name}")
-        self.model_name = os.getenv("LOCAL_MODEL", model_name)
-        self.model = None
-        self.tokenizer = None
-        self.device = None
-        self._initialize_model()
-
-    def _initialize_model(self):
-        """Initialize model ONCE - this is called only during __init__"""
-        try:
-            from transformers import AutoTokenizer, AutoModelForCausalLM
-            import torch
-
-            print(f"[LocalLLM] Loading model {self.model_name}...")
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-            print(f"[LocalLLM] Using device: {self.device}")
-
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_name,
-                device_map="auto" if self.device == "cuda" else None,
-                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-                trust_remote_code=True
-            )
-
-            print(f"[LocalLLM] Model loaded successfully!")
-
-        except Exception as e:
-            print(f"[LocalLLM] ❌ Failed to load model: {e}")
-            self.model = None
-            traceback.print_exc()
-
-    def chat(self, model, messages, **kwargs):
-        """
-        Generate response - MODEL ALREADY LOADED
-        """
-        if self.model is None or self.tokenizer is None:
-            return "Error: Model or tokenizer not loaded."
-
-        # Extract text from messages
-        text = messages[-1]["content"] if isinstance(messages[-1], dict) and "content" in messages[-1] else str(messages[-1])
-
-        max_tokens = kwargs.get("max_tokens", 128)
-        temperature = kwargs.get("temperature", 0.7)
-
-        import torch
-
-        # Tokenize
-        inputs = self.tokenizer(
-            text,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=2048
-        ).to(self.device)
-
-        # Generate (model is already loaded, just inference)
-        with torch.no_grad():
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_p=0.9,
-                do_sample=temperature > 0,
-                pad_token_id=self.tokenizer.eos_token_id if self.tokenizer and hasattr(self.tokenizer, 'eos_token_id') else None,
-                eos_token_id=self.tokenizer.eos_token_id if self.tokenizer and hasattr(self.tokenizer, 'eos_token_id') else None
-            )
-
-        # Decode
-        response = self.tokenizer.decode(
-            outputs[0][inputs['input_ids'].shape[1]:],
-            skip_special_tokens=True
-        ).strip() if self.tokenizer else "Error: Tokenizer not loaded."
-
-        return response
-
-###########################################################
-# PROVIDER CACHE - CRITICAL FIX
-###########################################################
-class ProviderCache:
-    """
-    Cache provider instances to avoid reloading models
-    This is the KEY fix - providers are created ONCE and reused
-    """
-    _cache = {}
-
-    @classmethod
-    def get_provider(cls, provider_name: str) -> BaseProvider:
-        """Get or create cached provider instance"""
-        if provider_name not in cls._cache:
-            print(f"[ProviderCache] Creating new instance of {provider_name}")
-            provider_class = ProviderFactory.providers[provider_name]
-            cls._cache[provider_name] = provider_class()
-        else:
-            print(f"[ProviderCache] Using cached instance of {provider_name}")
-        return cls._cache[provider_name]
-
-    @classmethod
-    def clear_cache(cls):
-        """Clear all cached providers (useful for debugging)"""
-        cls._cache.clear()
-        print("[ProviderCache] Cache cleared")
-
-###########################################################
-# PROVIDER FACTORY (IMPROVED WITH CACHING)
-###########################################################
-class ProviderFactory:
-    providers = {
-        "openai": OpenAIProvider,
-        "anthropic": AnthropicProvider,
-        "huggingface": HuggingFaceProvider,
-        "nebius": NebiusProvider,
-        "sambanova": SambaNovaProvider,
-        "local": LocalLLMProvider,
-    }
-
-    @staticmethod
-    def get(provider_name: str) -> BaseProvider:
-        """
-        Get provider instance - NOW USES CACHING
-        This prevents reloading the model on every call
-        """
-        provider_name = provider_name.lower()
-        if provider_name not in ProviderFactory.providers:
-            raise ValueError(f"Unknown provider: {provider_name}")
-
-        # USE CACHE instead of creating new instance every time
-        return ProviderCache.get_provider(provider_name)
-
-###########################################################
-# MAIN ENGINE WITH FALLBACK + OPENAI-STYLE API
-###########################################################
-class LLMEngine:
-    def __init__(self):
-        self.credits = load_credits()
-
-    def deduct(self, provider, amount):
-        if provider not in self.credits:
-            self.credits[provider] = 0
-        self.credits[provider] = max(0, self.credits[provider] - amount)
-        save_credits(self.credits)
-
-    def chat(self, provider: str, model: str, messages: List[Dict], fallback: bool = True, **kwargs):
-        """
-        Main chat method - providers are now cached
-        """
-        try:
-            p = ProviderFactory.get(provider)  # This now returns cached instance
-            result = p.chat(model=model, messages=messages, **kwargs)
-            try:
-                self.deduct(provider, 0.001)
-            except Exception:
-                pass
-            return result
-        except Exception as exc:
-            print(f"⚠ Provider '{provider}' failed → fallback activated: {exc}")
-            traceback.print_exc()
-            if fallback:
-                try:
-                    lp = ProviderFactory.get("local")  # Gets cached local provider
-                    return lp.chat(model="local", messages=messages, **kwargs)
-                except Exception as le:
-                    print("Fallback to local provider failed:", le)
-                    traceback.print_exc()
-                    raise
-            raise
-
-###########################################################
-# EXAMPLES + SIMPLE TESTS
-###########################################################
-def main():
-    engine = LLMEngine()
-
-    print("=== Testing Provider Caching ===")
-    print("\nFirst call (should load model):")
-    result1 = engine.chat(
-        provider="local",
-        model="meta-llama/Llama-3.2-3B-Instruct",
-        messages=[{"role": "user", "content": "Say hello"}]
-    )
-    print(f"Response: {result1[:100]}")
-
-    print("\nSecond call (should use cached model - NO RELOAD):")
-    result2 = engine.chat(
-        provider="local",
-        model="meta-llama/Llama-3.2-3B-Instruct",
-        messages=[{"role": "user", "content": "Say goodbye"}]
-    )
-    print(f"Response: {result2[:100]}")
-
-    print("\n✅ If you didn't see 'Loading model' twice, caching works!")
-
-
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--test", action="store_true", help="run examples and simple tests")
-    args = parser.parse_args()
-    if args.test:
-        main()
-    else:
+# llmEngine.py
+# IMPROVED: Multi-provider LLM engine with CACHING to prevent reloading
+# This version fixes the critical issue where LocalLLM was reloading on every call
+# Features:
+# - Provider caching (models stay in memory)
+# - Unified OpenAI-style chat() API
+# - Providers: OpenAI, Anthropic, HuggingFace, Nebius, SambaNova, Local (transformers)
+# - Automatic fallback to local model on errors
+# - JSON-based credit tracking
+
+from dotenv import load_dotenv
+import json
+import os
+import traceback
+from typing import List, Dict, Optional
+
+load_dotenv()
+hf_token = os.getenv('HUGGINGFACE_TOKEN')
+if hf_token:
+    from huggingface_hub import login
+    try:
+        login(token=hf_token)
+        # logger.info("[HF] Logged in")
+    except Exception as e:
+        # logger.warning(f"[HF] Login failed: {e}")
+        pass
+
+###########################################################
+# SIMPLE JSON CREDIT STORE
+###########################################################
+CREDITS_DB_PATH = "credits.json"
+
+DEFAULT_CREDITS = {
+    "openai": 25,
+    "anthropic": 25000,
+    "huggingface": 25,
+    "nebius": 50,
+    "modal": 250,
+    "blaxel": 250,
+    "elevenlabs": 44,
+    "sambanova": 25,
+    "local": 9999999
+}
+
+
+def load_credits():
+    if not os.path.exists(CREDITS_DB_PATH):
+        with open(CREDITS_DB_PATH, "w") as f:
+            json.dump(DEFAULT_CREDITS, f)
+        return DEFAULT_CREDITS.copy()
+    with open(CREDITS_DB_PATH, "r") as f:
+        return json.load(f)
+
+
+def save_credits(data):
+    with open(CREDITS_DB_PATH, "w") as f:
+        json.dump(data, f, indent=2)
+
+###########################################################
+# BASE PROVIDER INTERFACE
+###########################################################
+class BaseProvider:
+    def chat(self, model: str, messages: List[Dict], **kwargs) -> str:
+        raise NotImplementedError
+
+###########################################################
+# PROVIDER: OPENAI
+###########################################################
+try:
+    from openai import OpenAI
+except Exception:
+    OpenAI = None
+
+class OpenAIProvider(BaseProvider):
+    def __init__(self):
+        if OpenAI is None:
+            raise RuntimeError("openai library not installed or not importable")
+        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY", ""))
+
+    def chat(self, model, messages, **kwargs):
+        try:
+            from openai.types.chat import (
+                ChatCompletionUserMessageParam,
+                ChatCompletionAssistantMessageParam,
+                ChatCompletionSystemMessageParam,
+            )
+        except Exception:
+            ChatCompletionUserMessageParam = dict
+            ChatCompletionAssistantMessageParam = dict
+            ChatCompletionSystemMessageParam = dict
+
+        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
+            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
+
+        safe_messages = []
+        for m in messages:
+            role = str(m.get("role", "user"))
+            content = str(m.get("content", ""))
+            if role == "user":
+                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
+            elif role == "assistant":
+                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
+            elif role == "system":
+                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
+            else:
+                safe_messages.append({"role": role, "content": content})
+
+        response = self.client.chat.completions.create(model=model, messages=safe_messages)
+        try:
+            return response.choices[0].message.content
+        except Exception:
+            return str(response)
+
+###########################################################
+# PROVIDER: ANTHROPIC
+###########################################################
+try:
+    from anthropic import Anthropic
+except Exception:
+    Anthropic = None
+
+class AnthropicProvider(BaseProvider):
+    def __init__(self):
+        if Anthropic is None:
+            raise RuntimeError("anthropic library not installed or not importable")
+        self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY", ""))
+
+    def chat(self, model, messages, **kwargs):
+        if not isinstance(messages, list) or not all(isinstance(m, dict) for m in messages):
+            raise TypeError("messages must be a list of dicts with 'role' and 'content'")
+
+        user_text = "\n".join([m.get("content", "") for m in messages if m.get("role") == "user"])
+        reply = self.client.messages.create(
+            model=model,
+            max_tokens=300,
+            messages=[{"role": "user", "content": user_text}]
+        )
+
+        if hasattr(reply, "content"):
+            content = reply.content
+            if isinstance(content, list) and content and len(content) > 0:
+                block = content[0]
+                if hasattr(block, "text"):
+                    return getattr(block, "text", str(block))
+                elif isinstance(block, dict) and "text" in block:
+                    return block["text"]
+                else:
+                    return str(block)
+            elif isinstance(content, str):
+                return content
+
+        if isinstance(reply, dict) and "completion" in reply:
+            return reply["completion"]
+        return str(reply)
+
+###########################################################
+# PROVIDER: HUGGINGFACE INFERENCE API
+###########################################################
+import requests
+
+class HuggingFaceProvider(BaseProvider):
+    def __init__(self):
+        self.key = os.getenv("HF_API_KEY", "")
+
+    def chat(self, model, messages, **kwargs):
+        if not messages:
+            raise ValueError("messages is empty")
+        text = messages[-1].get("content", "")
+        r = requests.post(
+            f"https://api-inference.huggingface.co/models/{model}",
+            headers={"Authorization": f"Bearer {self.key}"} if self.key else {},
+            json={"inputs": text},
+            timeout=60
+        )
+        r.raise_for_status()
+        out = r.json()
+        if isinstance(out, list) and out and isinstance(out[0], dict):
+            return out[0].get("generated_text") or str(out[0])
+        return str(out)
+
+###########################################################
+# PROVIDER: NEBIUS (OpenAI-compatible)
+###########################################################
+class NebiusProvider(BaseProvider):
+    def __init__(self):
+        if OpenAI is None:
+            raise RuntimeError("openai library not installed; Nebius wrapper expects OpenAI-compatible client")
+        self.client = OpenAI(
+            api_key=os.getenv("NEBIUS_API_KEY", ""),
+            base_url=os.getenv("NEBIUS_BASE_URL", "https://api.studio.nebius.ai/v1")
+        )
+
+    def chat(self, model, messages, **kwargs):
+        try:
+            from openai.types.chat import (
+                ChatCompletionUserMessageParam,
+                ChatCompletionAssistantMessageParam,
+                ChatCompletionSystemMessageParam,
+            )
+        except Exception:
+            ChatCompletionUserMessageParam = dict
+            ChatCompletionAssistantMessageParam = dict
+            ChatCompletionSystemMessageParam = dict
+
+        safe_messages = []
+        for m in messages:
+            role = str(m.get("role", "user"))
+            content = str(m.get("content", ""))
+            if role == "user":
+                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
+            elif role == "assistant":
+                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
+            elif role == "system":
+                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
+            else:
+                safe_messages.append({"role": role, "content": content})
+
+        r = self.client.chat.completions.create(model=model, messages=safe_messages)
+        try:
+            return r.choices[0].message.content
+        except Exception:
+            return str(r)
+
+###########################################################
+# PROVIDER: SAMBANOVA (OpenAI-compatible)
+###########################################################
+class SambaNovaProvider(BaseProvider):
+    def __init__(self):
+        if OpenAI is None:
+            raise RuntimeError("openai library not installed; SambaNova wrapper expects OpenAI-compatible client")
+        self.client = OpenAI(
+            api_key=os.getenv("SAMBANOVA_API_KEY", ""),
+            base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
+        )
+
+    def chat(self, model, messages, **kwargs):
+        try:
+            from openai.types.chat import (
+                ChatCompletionUserMessageParam,
+                ChatCompletionAssistantMessageParam,
+                ChatCompletionSystemMessageParam,
+            )
+        except Exception:
+            ChatCompletionUserMessageParam = dict
+            ChatCompletionAssistantMessageParam = dict
+            ChatCompletionSystemMessageParam = dict
+
+        safe_messages = []
+        for m in messages:
+            role = str(m.get("role", "user"))
+            content = str(m.get("content", ""))
+            if role == "user":
+                safe_messages.append(ChatCompletionUserMessageParam(role="user", content=content))
+            elif role == "assistant":
+                safe_messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=content))
+            elif role == "system":
+                safe_messages.append(ChatCompletionSystemMessageParam(role="system", content=content))
+            else:
+                safe_messages.append({"role": role, "content": content})
+
+        r = self.client.chat.completions.create(model=model, messages=safe_messages)
+        try:
+            return r.choices[0].message.content
+        except Exception:
+            return str(r)
+
+###########################################################
+# PROVIDER: LOCAL TRANSFORMERS (CACHED)
+###########################################################
+try:
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+    import torch
+    TRANSFORMERS_AVAILABLE = True
+except Exception:
+    TRANSFORMERS_AVAILABLE = False
+
+class LocalLLMProvider(BaseProvider):
+    """
+    Local LLM provider with caching - MODEL LOADS ONCE
+    """
+    def __init__(self, model_name: str = "meta-llama/Llama-3.2-3B-Instruct"):
+        print(f"[LocalLLM] Initializing with model: {model_name}")
+        self.model_name = os.getenv("LOCAL_MODEL", model_name)
+        self.model = None
+        self.tokenizer = None
+        self.device = None
+        self._initialize_model()
+
+    def _initialize_model(self):
+        """Initialize model ONCE - this is called only during __init__"""
+        try:
+            from transformers import AutoTokenizer, AutoModelForCausalLM
+            import torch
+
+
+
+
+            print(f"[LocalLLM] Loading model {self.model_name}...")
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            print(f"[LocalLLM] Using device: {self.device}")
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                device_map="auto" if self.device == "cuda" else None,
+                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                trust_remote_code=True
+            )
+
+            print(f"[LocalLLM] Model loaded successfully!")
+
+        except Exception as e:
+            print(f"[LocalLLM] Failed to load model: {e}")
+            self.model = None
+            traceback.print_exc()
+
+    def chat(self, model, messages, **kwargs):
+        """
+        Generate response - MODEL ALREADY LOADED
+        """
+        if self.model is None or self.tokenizer is None:
+            return "Error: Model or tokenizer not loaded."
+
+        # Extract text from messages
+        text = messages[-1]["content"] if isinstance(messages[-1], dict) and "content" in messages[-1] else str(messages[-1])
+
+        max_tokens = kwargs.get("max_tokens", 128)
+        temperature = kwargs.get("temperature", 0.7)
+
+        import torch
+
+        # Tokenize
+        inputs = self.tokenizer(
+            text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=2048
+        ).to(self.device)
+
+        # Generate (model is already loaded, just inference)
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=0.9,
+                do_sample=temperature > 0,
+                pad_token_id=self.tokenizer.eos_token_id if self.tokenizer and hasattr(self.tokenizer, 'eos_token_id') else None,
+                eos_token_id=self.tokenizer.eos_token_id if self.tokenizer and hasattr(self.tokenizer, 'eos_token_id') else None
+            )
+
+        # Decode
+        response = self.tokenizer.decode(
+            outputs[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        ).strip() if self.tokenizer else "Error: Tokenizer not loaded."
+
+        return response
+
+###########################################################
+# PROVIDER CACHE - CRITICAL FIX
+###########################################################
+class ProviderCache:
+    """
+    Cache provider instances to avoid reloading models
+    This is the KEY fix - providers are created ONCE and reused
+    """
+    _cache = {}
+
+    @classmethod
+    def get_provider(cls, provider_name: str) -> BaseProvider:
+        """Get or create cached provider instance"""
+        if provider_name not in cls._cache:
+            print(f"[ProviderCache] Creating new instance of {provider_name}")
+            provider_class = ProviderFactory.providers[provider_name]
+            cls._cache[provider_name] = provider_class()
+        else:
+            print(f"[ProviderCache] Using cached instance of {provider_name}")
+        return cls._cache[provider_name]
+
+    @classmethod
+    def clear_cache(cls):
+        """Clear all cached providers (useful for debugging)"""
+        cls._cache.clear()
+        print("[ProviderCache] Cache cleared")
+
+###########################################################
+# PROVIDER FACTORY (IMPROVED WITH CACHING)
+###########################################################
+class ProviderFactory:
+    providers = {
+        "openai": OpenAIProvider,
+        "anthropic": AnthropicProvider,
+        "huggingface": HuggingFaceProvider,
+        "nebius": NebiusProvider,
+        "sambanova": SambaNovaProvider,
+        "local": LocalLLMProvider,
+    }
+
+    @staticmethod
+    def get(provider_name: str) -> BaseProvider:
+        """
+        Get provider instance - NOW USES CACHING
+        This prevents reloading the model on every call
+        """
+        provider_name = provider_name.lower()
+        if provider_name not in ProviderFactory.providers:
+            raise ValueError(f"Unknown provider: {provider_name}")
+
+        # USE CACHE instead of creating new instance every time
+        return ProviderCache.get_provider(provider_name)
+
+###########################################################
+# MAIN ENGINE WITH FALLBACK + OPENAI-STYLE API
+###########################################################
+class LLMEngine:
+    def __init__(self):
+        self.credits = load_credits()
+
+    def deduct(self, provider, amount):
+        if provider not in self.credits:
+            self.credits[provider] = 0
+        self.credits[provider] = max(0, self.credits[provider] - amount)
+        save_credits(self.credits)
+
+    def chat(self, provider: str, model: str, messages: List[Dict], fallback: bool = True, **kwargs):
+        """
+        Main chat method - providers are now cached
+        """
+        try:
+            p = ProviderFactory.get(provider)  # This now returns cached instance
+            result = p.chat(model=model, messages=messages, **kwargs)
+            try:
+                self.deduct(provider, 0.001)
+            except Exception:
+                pass
+            return result
+        except Exception as exc:
+            print(f"⚠ Provider '{provider}' failed → fallback activated: {exc}")
+            traceback.print_exc()
+            if fallback:
+                try:
+                    lp = ProviderFactory.get("local")  # Gets cached local provider
+                    return lp.chat(model="local", messages=messages, **kwargs)
+                except Exception as le:
+                    print("Fallback to local provider failed:", le)
+                    traceback.print_exc()
+                    raise
+            raise
+
+###########################################################
+# EXAMPLES + SIMPLE TESTS
+###########################################################
+def main():
+    engine = LLMEngine()
+
+    print("=== Testing Provider Caching ===")
+    print("\nFirst call (should load model):")
+    result1 = engine.chat(
+        provider="local",
+        model="meta-llama/Llama-3.2-3B-Instruct",
+        messages=[{"role": "user", "content": "Say hello"}]
+    )
+    print(f"Response: {result1[:100]}")
+
+    print("\nSecond call (should use cached model - NO RELOAD):")
+    result2 = engine.chat(
+        provider="local",
+        model="meta-llama/Llama-3.2-3B-Instruct",
+        messages=[{"role": "user", "content": "Say goodbye"}]
+    )
+    print(f"Response: {result2[:100]}")
+
+    print("\n✅ If you didn't see 'Loading model' twice, caching works!")
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--test", action="store_true", help="run examples and simple tests")
+    args = parser.parse_args()
+    if args.test:
+        main()
+    else:
         main()
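
For context, a minimal usage sketch of the engine this commit updates. This is illustrative only: it assumes llm_engine.py is importable from the working directory, that a .env file (read by the new load_dotenv() call) supplies HUGGINGFACE_TOKEN and whichever provider API keys you intend to use, and that credits.json is writable.

    from llm_engine import LLMEngine

    engine = LLMEngine()

    # Unified OpenAI-style call; the cached provider instance is reused across calls,
    # and on provider errors the engine falls back to the local transformers model.
    reply = engine.chat(
        provider="huggingface",   # or "openai", "anthropic", "nebius", "sambanova", "local"
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[{"role": "user", "content": "Say hello"}],
        fallback=True,
    )
    print(reply)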