from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class BaseModel(ABC):
    """Abstract base class for all vision-language models."""

    def __init__(self, model_name: str, model_config: Dict[str, Any]):
        """
        Initialize the base model.

        Args:
            model_name: Name of the model
            model_config: Configuration dictionary for the model
        """
        self.model_name = model_name
        self.model_config = model_config
        self.model_id = model_config['model_id']
        self.model: Optional[Any] = None
        self.tokenizer: Optional[Any] = None
        self.current_quantization: Optional[str] = None
        self.is_loaded: bool = False

    @abstractmethod
    def load_model(self, quantization_type: str, **kwargs) -> bool:
        """
        Load the model with the specified quantization.

        Args:
            quantization_type: Type of quantization to use
            **kwargs: Additional arguments for model loading

        Returns:
            True if successful, False otherwise
        """
        pass

    @abstractmethod
    def unload_model(self) -> None:
        """Unload the model from memory."""
        pass

    @abstractmethod
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file
            prompt: Text prompt for the model
            **kwargs: Additional inference parameters

        Returns:
            The model's text response
        """
        pass

    def is_model_loaded(self) -> bool:
        """Check whether the model is currently loaded."""
        return self.is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the model."""
        return {
            'name': self.model_name,
            'model_id': self.model_id,
            'description': self.model_config.get('description', ''),
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
            'supported_quantizations': self.model_config.get('supported_quantizations', []),
            'default_quantization': self.model_config.get('default_quantization', ''),
            'is_loaded': self.is_loaded,
            'current_quantization': self.current_quantization,
        }

    def get_supported_quantizations(self) -> List[str]:
        """Get the list of supported quantization methods."""
        return self.model_config.get('supported_quantizations', [])

    def get_memory_requirements(self) -> Dict[str, int]:
        """Get the GPU memory requirements for the model."""
        return {
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
        }

    def validate_quantization(self, quantization_type: str) -> bool:
        """
        Validate whether the quantization type is supported.

        Args:
            quantization_type: Quantization type to validate

        Returns:
            True if supported, False otherwise
        """
        return quantization_type in self.get_supported_quantizations()

    def __str__(self) -> str:
        """String representation of the model."""
        status = "loaded" if self.is_loaded else "not loaded"
        quant = f" ({self.current_quantization})" if self.current_quantization else ""
        return f"{self.model_name}{quant} - {status}"

    def __repr__(self) -> str:
        """Detailed string representation."""
        return (
            f"BaseModel(name={self.model_name}, loaded={self.is_loaded}, "
            f"quantization={self.current_quantization})"
        )
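

if __name__ == "__main__":
    # Illustrative sketch, not part of the public API: a minimal concrete
    # subclass showing how the BaseModel contract fits together. The model
    # name, config values, and the stand-in "model handle" below are made up
    # for demonstration; a real subclass would load an actual checkpoint
    # in load_model and run the model against the image in inference.
    class EchoModel(BaseModel):
        def load_model(self, quantization_type: str, **kwargs) -> bool:
            if not self.validate_quantization(quantization_type):
                return False
            self.model = object()  # stand-in for a real model handle
            self.current_quantization = quantization_type
            self.is_loaded = True
            return True

        def unload_model(self) -> None:
            self.model = None
            self.current_quantization = None
            self.is_loaded = False

        def inference(self, image_path: str, prompt: str, **kwargs) -> str:
            if not self.is_loaded:
                raise RuntimeError(f"{self.model_name} is not loaded")
            # A real implementation would open image_path and run the model.
            return f"[{self.model_name}] image={image_path} prompt={prompt}"

    demo = EchoModel("echo", {
        "model_id": "demo/echo",  # hypothetical model ID
        "supported_quantizations": ["none", "int8"],
        "default_quantization": "int8",
    })
    assert demo.load_model("int8")
    print(demo)  # -> echo (int8) - loaded
    print(demo.inference("cat.jpg", "What animal is this?"))
    demo.unload_model()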