bogota_land_space / model.py
viarias's picture
Update model.py
25fb814 verified
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import logging
# Configure the root logger at import time: INFO level with a timestamped
# "time - logger - level - message" layout. Per the stdlib docs, basicConfig
# does nothing if the root logger already has handlers attached.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger used by the loader class below.
logger = logging.getLogger(name=__name__)
class Model:
    """
    Lazily load and cache the Kimi-VL-A3B-Thinking-2506 model and processor.

    The model and processor are stored as class attributes, so every caller in
    the process shares one copy. The first call to ``load`` (or any of the
    ``load_*`` accessors) pays the ``from_pretrained`` cost; later calls reuse
    the cached objects. A subclass can target a different checkpoint by
    overriding ``MODEL_PATH``, ``model_class``/``processor_class`` or the
    kwargs dicts, because ``load`` reads all of them through ``cls``.
    """

    # ALL MODEL CONFIGURATIONS
    _model = None  # cached model instance; populated by load()
    _processor = None  # cached processor instance; populated by load()
    MODEL_PATH = "moonshotai/Kimi-VL-A3B-Thinking-2506"
    model_class = AutoModelForCausalLM
    processor_class = AutoProcessor
    model_kwargs = {
        "device_map": "auto",
        "torch_dtype": "auto",
        # Allows executing modelling code hosted in the Hub repo
        # (presumably required by this checkpoint).
        "trust_remote_code": True,
    }
    processor_kwargs = {
        # NOTE(review): device_map is a model-placement option; confirm the
        # processor actually uses it and drop it if not.
        "device_map": "auto",
        "trust_remote_code": True,
    }

    @classmethod
    def load(cls):
        """
        Load (once) and return the model and processor.

        Each component is loaded only if it is not cached yet. If loading the
        processor fails after the model succeeded, the model stays cached and
        the next call retries just the processor. It never silently returns
        ``(model, None)``.

        Returns:
            tuple: ``(model, processor)``.

        Raises:
            Exception: whatever ``from_pretrained`` raises. It is logged with
            its traceback and re-raised unchanged.
        """
        try:
            if cls._model is None:
                cls._model = cls.model_class.from_pretrained(
                    cls.MODEL_PATH, **cls.model_kwargs
                )
            if cls._processor is None:
                cls._processor = cls.processor_class.from_pretrained(
                    cls.MODEL_PATH, **cls.processor_kwargs
                )
        except Exception as e:
            # logger.exception logs at ERROR level and appends the traceback.
            logger.exception("Failed to load model or processor: %s", e)
            raise
        return cls._model, cls._processor

    @classmethod
    def load_model(cls):
        """Return the cached model, loading it (and the processor) if needed."""
        model, _ = cls.load()
        return model

    @classmethod
    def load_processor(cls):
        """Return the cached processor, loading it (and the model) if needed."""
        _, processor = cls.load()
        return processor

    @classmethod
    def load_raw_model(cls):
        """
        Return the underlying transformers model.

        This class applies no wrapper to the model, so this returns the same
        object as ``load_model()``. The previous version returned the
        never-assigned ``_raw_model`` attribute and always raised
        AttributeError.
        """
        model, _ = cls.load()
        return model