from abc import ABC, abstractmethod
from typing import Any, Dict, List
class BaseModel(ABC):
    """Abstract base class for all vision-language models."""

    def __init__(self, model_name: str, model_config: Dict[str, Any]):
        """
        Initialize the base model.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
        """
        self.model_name = model_name
        self.model_config = model_config
        self.model_id = model_config['model_id']
        self.model = None
        self.tokenizer = None
        self.current_quantization = None
        self.is_loaded = False
    @abstractmethod
    def load_model(self, quantization_type: str, **kwargs) -> bool:
        """
        Load the model with the specified quantization.

        Args:
            quantization_type: Type of quantization to use
            **kwargs: Additional arguments for model loading

        Returns:
            True if successful, False otherwise
        """

    @abstractmethod
    def unload_model(self) -> None:
        """Unload the model from memory."""

    @abstractmethod
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters

        Returns:
            The model's text response
        """
    def is_model_loaded(self) -> bool:
        """Check whether the model is currently loaded."""
        return self.is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the model."""
        return {
            'name': self.model_name,
            'model_id': self.model_id,
            'description': self.model_config.get('description', ''),
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
            'supported_quantizations': self.model_config.get('supported_quantizations', []),
            'default_quantization': self.model_config.get('default_quantization', ''),
            'is_loaded': self.is_loaded,
            'current_quantization': self.current_quantization,
        }

    def get_supported_quantizations(self) -> List[str]:
        """Get the list of supported quantization methods."""
        return self.model_config.get('supported_quantizations', [])

    def get_memory_requirements(self) -> Dict[str, int]:
        """Get the GPU memory requirements for the model, in gigabytes."""
        return {
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
        }

    def validate_quantization(self, quantization_type: str) -> bool:
        """
        Validate that a quantization type is supported.

        Args:
            quantization_type: Quantization type to validate

        Returns:
            True if supported, False otherwise
        """
        return quantization_type in self.get_supported_quantizations()

    def __str__(self) -> str:
        """String representation of the model."""
        status = "loaded" if self.is_loaded else "not loaded"
        quant = f" ({self.current_quantization})" if self.current_quantization else ""
        return f"{self.model_name}{quant} - {status}"

    def __repr__(self) -> str:
        """Detailed string representation."""
        return (
            f"BaseModel(name={self.model_name}, "
            f"loaded={self.is_loaded}, "
            f"quantization={self.current_quantization})"
        )