# backend/models/base_model.py
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseModel(ABC):
    """Abstract base class for all vision-language models."""

    def __init__(self, model_name: str, model_config: Dict[str, Any]):
        """
        Initialize the base model.

        Args:
            model_name: Human-readable name of the model.
            model_config: Configuration dictionary for the model; must
                contain a 'model_id' key (a missing key raises KeyError).
        """
        self.model_name = model_name
        self.model_config = model_config
        self.model_id = model_config['model_id']
        self.model = None
        self.tokenizer = None
        self.current_quantization = None
        self.is_loaded = False

    @abstractmethod
    def load_model(self, quantization_type: str, **kwargs) -> bool:
        """
        Load the model with the specified quantization.

        Args:
            quantization_type: Quantization scheme to use; should be one of
                the values returned by get_supported_quantizations().
            **kwargs: Additional arguments for model loading.

        Returns:
            True if loading succeeded, False otherwise.
        """
        pass

    @abstractmethod
    def unload_model(self) -> None:
        """Unload the model from memory."""
        pass

    @abstractmethod
    def inference(self, image_path: str, prompt: str, **kwargs) -> str:
        """
        Perform inference on an image with a text prompt.

        Args:
            image_path: Path to the image file.
            prompt: Text prompt for the model.
            **kwargs: Additional inference parameters.

        Returns:
            The model's text response.
        """
        pass

    def is_model_loaded(self) -> bool:
        """Check whether the model is currently loaded."""
        return self.is_loaded

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the model."""
        return {
            'name': self.model_name,
            'model_id': self.model_id,
            'description': self.model_config.get('description', ''),
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
            'supported_quantizations': self.model_config.get('supported_quantizations', []),
            'default_quantization': self.model_config.get('default_quantization', ''),
            'is_loaded': self.is_loaded,
            'current_quantization': self.current_quantization,
        }

    def get_supported_quantizations(self) -> List[str]:
        """Get the list of supported quantization methods."""
        return self.model_config.get('supported_quantizations', [])

    def get_memory_requirements(self) -> Dict[str, int]:
        """Get GPU memory requirements for the model, in gigabytes."""
        return {
            'min_gpu_memory_gb': self.model_config.get('min_gpu_memory_gb', 0),
            'recommended_gpu_memory_gb': self.model_config.get('recommended_gpu_memory_gb', 0),
        }

    def validate_quantization(self, quantization_type: str) -> bool:
        """
        Validate whether the quantization type is supported.

        Args:
            quantization_type: Quantization type to validate.

        Returns:
            True if supported, False otherwise.
        """
        supported = self.get_supported_quantizations()
        return quantization_type in supported

    def __str__(self) -> str:
        """String representation of the model."""
        status = "loaded" if self.is_loaded else "not loaded"
        quant = f" ({self.current_quantization})" if self.current_quantization else ""
        return f"{self.model_name}{quant} - {status}"

    def __repr__(self) -> str:
        """Detailed string representation."""
        return (f"BaseModel(name={self.model_name}, loaded={self.is_loaded}, "
                f"quantization={self.current_quantization})")