# RAG / utils/engine.py
# (scraped page header preserved as comments so the module parses)
# Hanzo03's picture — Update utils/engine.py
# b28f276 verified
import os
import shutil
import cv2
import torch
import numpy as np
import zarr
from PIL import Image
from typing import Tuple, List
from utils.config import config, get_logger
from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
logger = get_logger("Engine")
def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
    """Extract frames from a video into a Zarr SSD cache, embed them with
    CLIP, and index the embeddings into ChromaDB.

    Args:
        video_path: Path to the uploaded video file.

    Returns:
        A ``(status_message, sample_frames)`` tuple, where ``sample_frames``
        holds up to the first three extracted frames as PIL images.
    """
    if not video_path:
        return "Please upload a video.", []

    # Drop any cache left over from a previously processed video.
    if os.path.exists(config.cache_dir):
        logger.info(f"Clearing old cache at {config.cache_dir}...")
        shutil.rmtree(config.cache_dir, ignore_errors=True)

    logger.info("Starting fast extraction process...")
    vidcap = cv2.VideoCapture(video_path)
    try:
        video_fps = vidcap.get(cv2.CAP_PROP_FPS)
        # BUGFIX: some containers report FPS as 0, which previously raised
        # ZeroDivisionError when computing timestamps. Fall back to 30 fps.
        if not video_fps or video_fps <= 0:
            logger.warning("Video reports no FPS; assuming 30.")
            video_fps = 30.0
        frame_interval = max(1, int(video_fps / config.default_fps))

        success, frame = vidcap.read()
        if not success:
            # BUGFIX: this early return previously leaked the capture handle;
            # the finally block below now always releases it.
            return "Failed to read video.", []

        h, w, c = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).shape
        logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
        frame_cache = zarr.create_array(
            config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
        )

        # Sample every `frame_interval`-th frame into the on-disk cache.
        timestamps: List[float] = []
        count = 0
        while success:
            if count % frame_interval == 0:
                rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
                timestamps.append(count / video_fps)
            success, frame = vidcap.read()
            count += 1
    finally:
        vidcap.release()

    logger.info("Generating CLIP embeddings in batches...")
    all_embeddings = []
    total_frames = frame_cache.shape[0]
    for i in range(0, total_frames, config.batch_size):
        batch_pil = [Image.fromarray(arr) for arr in frame_cache[i : i + config.batch_size]]
        inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            # Manually run the vision tower + projection (equivalent to
            # get_image_features) so the L2 normalization below is explicit.
            vision_outputs = clip_model.vision_model(**inputs)
            features = clip_model.visual_projection(vision_outputs.pooler_output)
        normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
        all_embeddings.extend(normalized)

    logger.info("Indexing into ChromaDB...")
    ids = [f"frame_{i}" for i in range(total_frames)]
    metadatas = [{"timestamp": ts, "frame_idx": i} for i, ts in enumerate(timestamps)]
    # Rebuild the collection from scratch so stale frames from the previous
    # video cannot leak into query results.
    global collection
    chroma_client.delete_collection(config.collection_name)
    collection = chroma_client.create_collection(config.collection_name)
    collection.add(embeddings=all_embeddings, metadatas=metadatas, ids=ids)

    sample_frames = [Image.fromarray(frame_cache[i]) for i in range(min(3, total_frames))]
    return f"Processed {total_frames} frames strictly on SSD cache.", sample_frames
def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
    """Answer a natural-language question about the most recently indexed video.

    Embeds the query with CLIP's text tower, retrieves the top-3 matching
    frames from ChromaDB, and asks the VLM about the best-matching frame.

    Args:
        query: Natural-language question about the video.

    Returns:
        A ``(answer, retrieved_frames)`` tuple; frames are PIL images.
    """
    # BUGFIX: an empty/whitespace query previously went straight into the
    # embedding model; reject it up front with a helpful message.
    if not query or not query.strip():
        return "Please enter a question.", []
    if collection.count() == 0:
        return "Please process a video first.", []

    logger.info(f"Processing query: '{query}'")
    inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        # Manually run the text tower + projection, then L2-normalize so the
        # query embedding matches the frame embeddings stored at index time.
        text_outputs = clip_model.text_model(**inputs)
        text_features = clip_model.text_projection(text_outputs.pooler_output)
    text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()

    results = collection.query(query_embeddings=text_embedding, n_results=3)

    # Re-open the SSD frame cache read-only and pull the matched frames.
    frame_cache = zarr.open_array(config.cache_dir, mode="r")
    retrieved_images = [
        Image.fromarray(frame_cache[int(metadata['frame_idx'])])
        for metadata in results['metadatas'][0]
    ]
    # BUGFIX: previously indexed retrieved_images[0] unconditionally, raising
    # IndexError when the query matched no frames.
    if not retrieved_images:
        return "No relevant frames found for this query.", []

    logger.info("Generating VLM answer...")
    encoded_image = vlm_model.encode_image(retrieved_images[0])
    answer = vlm_model.answer_question(encoded_image, query, vlm_tokenizer)
    return answer, retrieved_images