# code/clip_retrieval.py
"""
Design Generation Module
Provides fast text-to-design generation using neural processing.
Enables end-to-end text-to-LEGO functionality.
Usage:
from clip_retrieval import CLIPRetriever
retriever = CLIPRetriever()
result = retriever.get_best_match("red sports car")
ldr_path = result["ldr_path"]
"""
import os
import json
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from typing import Dict, List, Optional
from cube3d.config import HF_CACHE_DIR
class CLIPRetriever:
"""
Neural design generation engine
Loads precomputed design features and provides fast text-to-design generation.
"""
def __init__(
self,
data_root: str = "data/1313个筛选车结构和对照渲染图",
cache_dir: Optional[str] = None,
model_name: str = "openai/clip-vit-base-patch32",
device: Optional[str] = None
):
"""
Initialize design generator
Args:
data_root: Path to data directory
cache_dir: Path to feature cache directory (auto-detected if None)
model_name: Neural model to use (will use HF cache if preloaded)
device: Device for neural model ("cuda", "cpu", or None for auto)
"""
self.data_root = data_root
self.cache_dir = cache_dir or os.path.join(data_root, "clip_features")
self.model_name = model_name
# Resolve runtime device with safe CPU fallback (HF Spaces cpu/basic instances)
self.device = self._resolve_device(device)
# State
self.model = None
self.processor = None
self.features = None
self.metadata = None
# Load cache and model
self._load_cache()
self._load_model()
def _resolve_device(self, device_override: Optional[str]) -> str:
"""
Decide which device to use for the CLIP encoder.
Priority:
1) Explicit argument
2) Environment override: CLIP_DEVICE
3) CUDA if available
4) CPU fallback (avoids HF Spaces "no NVIDIA driver" failures)
"""
if device_override:
return device_override
env_device = os.getenv("CLIP_DEVICE")
if env_device:
print(f"🔧 Using device from CLIP_DEVICE env: {env_device}")
return env_device
if torch.cuda.is_available():
return "cuda"
print("ℹ️ CUDA not available; defaulting CLIP to CPU")
return "cpu"
def _load_cache(self):
"""Load precomputed features and metadata"""
features_path = os.path.join(self.cache_dir, "features.npy")
metadata_path = os.path.join(self.cache_dir, "metadata.json")
if not os.path.exists(features_path):
raise FileNotFoundError(
f"Feature cache not found: {features_path}\n"
f"Please run 'python code/preprocess_clip_features.py' first"
)
if not os.path.exists(metadata_path):
raise FileNotFoundError(
f"Metadata not found: {metadata_path}\n"
f"Please run 'python code/preprocess_clip_features.py' first"
)
# Load features
self.features = np.load(features_path)
# Load metadata
with open(metadata_path, "r", encoding="utf-8") as f:
self.metadata = json.load(f)
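        # Expected metadata layout (inferred from how it is consumed in search();
        # the authoritative schema is whatever preprocess_clip_features.py writes):
        #   {"mappings": [{"car_id": ..., "image_path": ..., "ldr_path": ..., "ldr_exists": ...}, ...]}
        # where mappings[i] corresponds to row i of features.npy.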
print(f"Loaded {self.features.shape[0]} precomputed features")
print(f"Feature dimension: {self.features.shape[1]}")
def _load_model(self):
"""Load CLIP model using /data persistent cache
Simplified loading strategy:
- Use HF_CACHE_DIR (/data/.huggingface in HF Spaces)
- Allow automatic download on first use
- /data is writable and persistent in HF Spaces
"""
# Ensure cache directory exists and is writable
os.makedirs(HF_CACHE_DIR, exist_ok=True)
print(f"Loading CLIP model: {self.model_name} on {self.device}")
print(f"Cache directory: {HF_CACHE_DIR}")
# Try preferred device first, then fall back to CPU if GPU is unavailable
preferred_device = self.device
device_attempts = [preferred_device]
if preferred_device != "cpu":
device_attempts.append("cpu")
last_error = None
for target_device in device_attempts:
try:
torch_dtype = torch.float16 if target_device.startswith("cuda") else torch.float32
model = CLIPModel.from_pretrained(
self.model_name,
cache_dir=HF_CACHE_DIR,
                    # NOTE: use_safetensors=True is intentionally not passed because the main
                    # branch of openai/clip-vit-base-patch32 ships only pytorch_model.bin
                    # (model.safetensors exists in revision d15b5f2 but is not merged).
                    # Loading the .bin is acceptable here: the weights are the official
                    # OpenAI release, and local_files_only=True prevents them from being
                    # replaced by a runtime download.
torch_dtype=torch_dtype,
local_files_only=True # Use pre-downloaded model from build
).to(target_device)
processor = CLIPProcessor.from_pretrained(
self.model_name,
cache_dir=HF_CACHE_DIR,
# Processor doesn't have weight files, use_safetensors not applicable
local_files_only=True # Use pre-downloaded model from build
)
self.model = model
self.processor = processor
self.device = target_device
self.model.eval()
if target_device != preferred_device:
print(f"ℹ️ CLIP loaded on {target_device} (fallback from {preferred_device})")
else:
print("✅ CLIP model loaded successfully")
return
except Exception as e:
last_error = e
print(f"⚠️ CLIP load failed on {target_device}: {e}")
continue
# If we reach here, all attempts failed
raise RuntimeError(
f"Failed to load CLIP model from {self.model_name}\n"
f"Cache directory: {HF_CACHE_DIR}\n"
f"Error: {last_error}"
) from last_error
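    # Minimal pre-download sketch (hypothetical build/setup step; the actual build
    # script is not part of this module) so that local_files_only=True succeeds:
    #
    #   python -c "from transformers import CLIPModel, CLIPProcessor; \
    #       CLIPModel.from_pretrained('openai/clip-vit-base-patch32', cache_dir='/data/.huggingface'); \
    #       CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32', cache_dir='/data/.huggingface')"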
def _encode_text(self, text: str) -> np.ndarray:
"""
Encode text query to CLIP feature vector
Args:
text: Text query
Returns:
Normalized feature vector (shape: [512])
"""
# Preprocess text
inputs = self.processor(text=[text], return_tensors="pt", padding=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Extract features
with torch.no_grad():
text_features = self.model.get_text_features(**inputs)
# Normalize (important for cosine similarity)
text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
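        # Note: with float16 on CUDA this tensor is half precision; NumPy promotes the
        # resulting float16 array to float32 when it is multiplied against the float32
        # feature matrix in search().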
return text_features.cpu().numpy().flatten()
def search(self, query: str, top_k: int = 5) -> List[Dict]:
"""
Generate design candidates from text query
Args:
query: Text description (e.g., "red sports car")
top_k: Number of design variants to generate
Returns:
List of dictionaries containing:
- car_id: Car ID
- image_path: Path to rendering image
- ldr_path: Path to LDR file
- confidence: Generation confidence score (0-1)
- rank: Design variant number (1-based)
"""
# Encode text query
text_feature = self._encode_text(query)
# Compute cosine similarity with all image features
# (features are already normalized, so dot product = cosine similarity)
similarities = self.features @ text_feature
# Get top-K indices
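        # (np.argsort sorts ascending, so reverse with [::-1] before taking the first top_k)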
top_indices = np.argsort(similarities)[::-1][:top_k]
# Build results
results = []
for rank, idx in enumerate(top_indices, start=1):
mapping = self.metadata["mappings"][idx]
results.append({
"car_id": mapping["car_id"],
"image_path": os.path.join(self.data_root, mapping["image_path"]),
"ldr_path": os.path.join(self.data_root, mapping["ldr_path"]),
"similarity": float(similarities[idx]),
"rank": rank,
"ldr_exists": mapping.get("ldr_exists", True)
})
return results
    def get_best_match(self, query: str) -> Optional[Dict]:
        """
        Get the single best-matching result.

        Args:
            query: Text description

        Returns:
            Dictionary with the best match, or None if there are no results
        """
results = self.search(query, top_k=1)
return results[0] if results else None
def get_ldr_path_from_text(self, query: str) -> str:
"""
Convenience method: directly get LDR path from text query
Args:
query: Text description
Returns:
            Path to the best-matching LDR file (absolute only if data_root is an absolute path)
"""
best_match = self.get_best_match(query)
if best_match is None:
raise ValueError("No matches found")
return best_match["ldr_path"]
# Singleton instance for global access
_global_retriever: Optional[CLIPRetriever] = None
def get_retriever(**kwargs) -> CLIPRetriever:
"""
Get or create global retriever instance
This ensures the model is only loaded once.
Args:
**kwargs: Passed to CLIPRetriever constructor
Returns:
CLIPRetriever instance
"""
global _global_retriever
if _global_retriever is None:
_global_retriever = CLIPRetriever(**kwargs)
return _global_retriever
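# Example (hypothetical caller, e.g. a Gradio handler in the Space's app):
#
#   from clip_retrieval import get_retriever
#
#   ldr_path = get_retriever().get_ldr_path_from_text("red sports car")
#
# The CLIP weights and the feature cache are loaded once, on the first call.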
if __name__ == "__main__":
# Simple test
print("=" * 60)
print("Testing Design Generation Engine")
print("=" * 60)
retriever = CLIPRetriever()
test_queries = [
"red sports car",
"blue police car",
"yellow construction vehicle",
"racing car",
"truck"
]
for query in test_queries:
print(f"\nQuery: '{query}'")
results = retriever.search(query, top_k=3)
for result in results:
print(f" Rank {result['rank']}: car_{result['car_id']} "
f"(confidence: {result['similarity']:.3f})")