Spaces:
Paused
Paused
| import torch | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| import logging | |
# Configure root logging once at import time and grab this module's logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
class Model:
    """
    Handles loading and managing the Kimi-VL-A3B-Thinking model and processor.

    The model and processor are loaded lazily on first access and cached at
    class level, so every accessor returns the same shared instances.
    """

    # Class-level singleton cache, populated on the first call to load().
    _model = None
    _processor = None

    # ALL MODEL CONFIGURATIONS
    MODEL_PATH = "moonshotai/Kimi-VL-A3B-Thinking-2506"
    model_class = AutoModelForCausalLM
    processor_class = AutoProcessor
    model_kwargs = {
        "device_map": "auto",
        "torch_dtype": "auto",
        "trust_remote_code": True,
    }
    processor_kwargs = {
        # NOTE(review): "device_map" is not a documented kwarg for
        # AutoProcessor.from_pretrained — presumably ignored; verify.
        "device_map": "auto",
        "trust_remote_code": True,
    }

    @classmethod
    def load(cls):
        """Load (once) and return the ``(model, processor)`` pair.

        Returns:
            tuple: the cached model and processor instances.

        Raises:
            Exception: re-raises any failure from ``from_pretrained``
                after logging it.
        """
        # BUG FIX: the original defined this (and the accessors below) with a
        # `cls` parameter but no @classmethod decorator, so `Model.load()`
        # raised TypeError when called on the class.
        if cls._model is None:
            try:
                cls._model = cls.model_class.from_pretrained(
                    cls.MODEL_PATH, **cls.model_kwargs
                )
                cls._processor = cls.processor_class.from_pretrained(
                    cls.MODEL_PATH, **cls.processor_kwargs
                )
            except Exception as e:
                logger.error(f"Failed to load model or processor: {str(e)}")
                raise
        return cls._model, cls._processor

    @classmethod
    def load_model(cls):
        """Return the cached model, loading it on first use."""
        model, _ = cls.load()
        return model

    @classmethod
    def load_processor(cls):
        """Return the cached processor, loading it on first use."""
        _, processor = cls.load()
        return processor

    @classmethod
    def load_raw_model(cls):
        """Get the raw transformers model (not wrapped by outlines)."""
        cls.load()  # Ensure models are loaded
        # BUG FIX: the original returned cls._raw_model, an attribute that is
        # never assigned anywhere in this class and would always raise
        # AttributeError. The cached transformers model is cls._model.
        return cls._model