Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import threading | |
| import time | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Union, Any, Optional, Callable | |
| import gradio as gr | |
| from ..models.model_manager import ModelManager | |
| from ..utils.data_processing import extract_file_dict, validate_data, extract_binary_output | |
| from ..config.config_manager import ConfigManager | |
| from ..utils.metrics import create_accuracy_table | |
| from datetime import datetime | |
| import boto3 | |
class InferenceEngine:
    """Engine for handling batch inference and processing control.

    Runs the model held by a ``ModelManager`` over a batch of uploaded images,
    optionally paired with a CSV that supplies 'Image Name' / 'Ground Truth'
    columns. Processing can be interrupted cooperatively through a
    lock-protected stop flag that is checked before each image. The most
    recent results, including each image's local path, are kept in
    ``self.full_df``. Runs that reach the end of their loop are uploaded to
    S3; runs interrupted mid-loop are not.
    """

    # Extensions (compared case-insensitively) that mark an uploaded file as an image.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff')
    # Columns shown in the UI results table; ``full_df`` additionally carries 'Image Path'.
    _DISPLAY_COLUMNS = ('S.No', 'Image Name', 'Ground Truth', 'Binary Output', 'Model Output')

    def __init__(self, model_manager: ModelManager, config_manager: ConfigManager):
        """
        Initialize the inference engine.

        Args:
            model_manager: Model manager used to load models and run inference.
            config_manager: Configuration manager instance, stored as
                ``self.config_manager``.
        """
        self.model_manager = model_manager
        self.config_manager = config_manager
        # Guards ``stop_processing``. Presumably the flag is set from another UI
        # event (e.g. a stop button) while a processing loop is running.
        self.processing_lock = threading.Lock()
        self.stop_processing = False
        self.full_df: Optional[pd.DataFrame] = None  # Store full dataframe with image paths

    # ------------------------------------------------------------------ #
    # Stop-flag control
    # ------------------------------------------------------------------ #
    def set_stop_flag(self) -> str:
        """Ask the running loop to stop before it starts its next image.

        Returns:
            Status message to show in the UI.
        """
        with self.processing_lock:
            self.stop_processing = True
        print("π Stop signal received. Processing will halt after current image...")
        return "π Stopping process... Please wait for current image to complete."

    def reset_stop_flag(self) -> None:
        """Clear the stop flag; called at the start of every processing run."""
        with self.processing_lock:
            self.stop_processing = False

    def check_stop_flag(self) -> bool:
        """Return True if a stop has been requested since the last reset."""
        with self.processing_lock:
            return self.stop_processing

    # ------------------------------------------------------------------ #
    # Model loading
    # ------------------------------------------------------------------ #
    def _should_load_model(self, model_selection: str, quantization_type: str) -> bool:
        """
        Check whether the requested model/quantization must be (re)loaded.

        Args:
            model_selection: Selected model name.
            quantization_type: Selected quantization type.

        Returns:
            True if no model is loaded, a different model is loaded, or the
            loaded model uses a different quantization; otherwise False.
        """
        current = self.model_manager.current_model
        if not current or not current.is_model_loaded():
            return True
        return (
            self.model_manager.current_model_name != model_selection
            or current.current_quantization != quantization_type
        )

    def _ensure_correct_model_loaded(self, model_selection: str, quantization_type: str, progress: gr.Progress) -> None:
        """
        Ensure the requested model with the requested quantization is loaded.

        Args:
            model_selection: Selected model name.
            quantization_type: Selected quantization type.
            progress: Gradio progress tracker, called as ``progress(fraction, desc=...)``.

        Raises:
            RuntimeError: If ``ModelManager.load_model`` reports failure.
        """
        if not self._should_load_model(model_selection, quantization_type):
            print(f"β Correct model already loaded: {model_selection} with {quantization_type}")
            return
        progress(0, desc=f"π Loading {model_selection} ({quantization_type})...")
        print(f"π Loading {model_selection} with {quantization_type}...")
        if not self.model_manager.load_model(model_selection, quantization_type):
            # RuntimeError subclasses Exception, so existing ``except Exception`` callers still catch it.
            raise RuntimeError(f"Failed to load model {model_selection} with {quantization_type}")

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #
    def _filter_images(self, file_dict: Dict[str, Path]) -> Dict[str, Path]:
        """Return the entries of ``file_dict`` whose names end in an image extension."""
        return {name: path for name, path in file_dict.items()
                if name.lower().endswith(self._IMAGE_EXTENSIONS)}

    @staticmethod
    def _item_label(action: str, position: int, total: int, name: str) -> str:
        """Build '<action> <position>/<total>: <name>', truncating names longer than 30 chars."""
        shown = f"{name[:30]}..." if len(name) > 30 else name
        return f"{action} {position}/{total}: {shown}"

    @staticmethod
    def _make_row(serial: int, image_name: Any, ground_truth: Any,
                  binary_output: Any, model_output: Any, image_path: Any) -> Dict[str, Any]:
        """Build one results row in the column layout used by ``full_df``."""
        return {
            'S.No': serial,
            'Image Name': image_name,
            'Ground Truth': ground_truth,
            'Binary Output': binary_output,
            'Model Output': model_output,
            'Image Path': str(image_path),
        }

    def _publish_rows(self, rows: List[Dict[str, Any]]) -> pd.DataFrame:
        """Store ``rows`` as ``self.full_df`` and return the display-only copy (no 'Image Path')."""
        self.full_df = pd.DataFrame(rows, columns=[*self._DISPLAY_COLUMNS, 'Image Path'])
        return self.full_df[list(self._DISPLAY_COLUMNS)].copy()

    @staticmethod
    def _results_outputs(display_df: pd.DataFrame, message: str) -> Tuple[Any, ...]:
        """Six-tuple returned when a run produced a table.

        The tuple is (hidden, shown, shown, table, hidden, message); the
        ``gr.update`` entries toggle visibility of components wired by the caller.
        """
        return (gr.update(visible=False), gr.update(visible=True), gr.update(visible=True),
                display_df, gr.update(visible=False), message)

    @staticmethod
    def _error_outputs(message: str) -> Tuple[Any, ...]:
        """Six-tuple returned when a run could not start.

        The tuple is (shown, hidden, hidden, message, hidden, "").
        """
        return (gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
                message, gr.update(visible=False), "")

    # ------------------------------------------------------------------ #
    # Batch processing
    # ------------------------------------------------------------------ #
    def process_folder_input(
        self,
        folder_path: List[Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress = gr.Progress(),
    ) -> Tuple[Any, ...]:
        """
        Process an uploaded folder of images, with or without a ground-truth CSV.

        Args:
            folder_path: List of Path objects from Gradio.
            prompt: Text prompt for inference.
            quantization_type: Model quantization type.
            model_selection: Selected model name.
            progress: Gradio progress tracker. The ``gr.Progress()`` default lets
                Gradio inject a live tracker when this method is bound to an event
                directly; explicit callers may still pass their own.

        Returns:
            A six-tuple of UI updates, the results table (or an error message)
            and a status message. See ``_results_outputs`` / ``_error_outputs``.
        """
        # Reset stop flag at the beginning of processing
        self.reset_stop_flag()
        if not folder_path:
            return self._error_outputs("No files uploaded. Please select a folder containing images.")

        file_dict = extract_file_dict(folder_path)
        # Print all file names for debug
        for fname in file_dict:
            print(fname)

        validation_result, message = validate_data(file_dict)
        # validate_data signals a hard failure with literal False; other results
        # appear to be status strings such as 'no_csv' / 'multiple_csv'.
        if validation_result == False:  # noqa: E712 - must not match other falsy statuses
            return self._error_outputs(message)
        if validation_result in ("no_csv", "multiple_csv"):
            return self._process_without_csv(file_dict, prompt, quantization_type, model_selection, progress)
        return self._process_with_csv(file_dict, prompt, quantization_type, model_selection, progress)

    def _process_without_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress,
    ) -> Tuple[Any, ...]:
        """Run inference on every image in ``file_dict`` (no ground truth available)."""
        image_file_dict = self._filter_images(file_dict)
        total_images = len(image_file_dict)
        if total_images == 0:
            return self._error_outputs("No image files found.")

        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)
        progress(0, desc=f"π Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")

        items = list(image_file_dict.items())
        rows: List[Dict[str, Any]] = []
        for idx, (img_name, img_path) in enumerate(items):
            if self.check_stop_flag():
                print(f"π Processing stopped by user at image {idx + 1}/{total_images}")
                # Keep every remaining image in the table, marked as not processed.
                rows.extend(
                    self._make_row(pos + 1, name, '', 'Not processed (stopped)',
                                   'Processing stopped by user', path)
                    for pos, (name, path) in enumerate(items[idx:], start=idx)
                )
                display_df = self._publish_rows(rows)
                final_message = f"π Processing stopped by user. Completed {idx}/{total_images} images."
                print(final_message)
                return self._results_outputs(display_df, final_message)

            try:
                progress_msg = self._item_label("π Processing image", idx + 1, total_images, img_name)
                progress(idx / total_images, desc=progress_msg)
                print(progress_msg)
                model_output = self.model_manager.inference(str(img_path), prompt) if prompt else "No prompt provided"
                # No ground truth exists for folder-only uploads.
                binary_output = extract_binary_output(model_output, "", [])
                rows.append(self._make_row(idx + 1, img_name, '', binary_output, model_output, img_path))
                progress((idx + 1) / total_images, desc=f"β Completed {idx + 1}/{total_images} images")
                print(f"Successfully processed image {idx + 1} of {total_images}")
            except Exception as e:
                # Best-effort per image: record the error and continue with the batch.
                print(f"Error processing image {idx + 1} of {total_images}: {str(e)}")
                rows.append(self._make_row(idx + 1, img_name, '', 'Enter the output manually',
                                           f"Error: {str(e)}", img_path))
                progress((idx + 1) / total_images,
                         desc=f"β οΈ Processed {idx + 1}/{total_images} images (with errors)")

        if self.check_stop_flag():
            final_message = f"π Processing stopped by user. Completed {len(rows)}/{total_images} images."
        else:
            final_message = f"π Successfully completed processing all {total_images} images!"
        display_df = self._publish_rows(rows)
        self.save_results_to_s3(display_df)
        print(final_message)
        # The table is made editable by the caller for ground-truth input.
        return self._results_outputs(display_df, final_message)

    def _process_with_csv(
        self,
        file_dict: Dict[str, Path],
        prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress,
    ) -> Tuple[Any, ...]:
        """Run inference on the CSV rows whose 'Image Name' matches an uploaded image."""
        # process_folder_input routes here only when validation reported neither
        # 'no_csv' nor 'multiple_csv', so the first CSV is the only one.
        csv_files = [fname for fname in file_dict if fname.lower().endswith('.csv')]
        df = pd.read_csv(file_dict[csv_files[0]])

        # Every non-empty ground truth, used for keyword extraction in extract_binary_output.
        all_ground_truths = [str(gt) for gt in df['Ground Truth']
                             if pd.notna(gt) and str(gt).strip()]

        image_file_dict = self._filter_images(file_dict)
        total_images = sum(1 for name in df['Image Name'] if name in image_file_dict)
        if total_images == 0:
            return self._error_outputs("No matching images found for entries in CSV.")

        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)
        progress(0, desc=f"π Starting to process {total_images} images...")
        print(f"Starting to process {total_images} images with {model_selection}...")

        rows: List[Dict[str, Any]] = []
        processed_count = 0
        # enumerate() gives the positional row number; iterrows() yields index labels.
        for position, (_, row) in enumerate(df.iterrows()):
            img_name = row['Image Name']
            if img_name not in image_file_dict:
                continue
            img_path = image_file_dict[img_name]

            if self.check_stop_flag():
                print(f"π Processing stopped by user at image {processed_count + 1}/{total_images}")
                # Keep every remaining matched image in the table, marked as not processed.
                for _, remaining_row in df.iloc[position:].iterrows():
                    remaining_name = remaining_row['Image Name']
                    if remaining_name in image_file_dict:
                        rows.append(self._make_row(
                            len(rows) + 1, remaining_name, remaining_row['Ground Truth'],
                            'Not processed (stopped)', 'Processing stopped by user',
                            image_file_dict[remaining_name]))
                display_df = self._publish_rows(rows)
                final_message = f"π Processing stopped by user. Completed {processed_count}/{total_images} images."
                print(final_message)
                return self._results_outputs(display_df, final_message)

            processed_count += 1
            try:
                progress_msg = self._item_label("π Processing image", processed_count, total_images, img_name)
                progress((processed_count - 1) / total_images, desc=progress_msg)
                print(progress_msg)
                model_output = self.model_manager.inference(str(img_path), prompt)
                ground_truth = str(row['Ground Truth']) if pd.notna(row['Ground Truth']) else ""
                binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)
                rows.append(self._make_row(len(rows) + 1, img_name, row['Ground Truth'],
                                           binary_output, model_output, img_path))
                progress(processed_count / total_images, desc=f"β Completed {processed_count}/{total_images} images")
                print(f"Successfully processed image {processed_count} of {total_images}")
            except Exception as e:
                # Best-effort per image: record the error and continue with the batch.
                print(f"Error processing image {processed_count} of {total_images}: {str(e)}")
                rows.append(self._make_row(len(rows) + 1, img_name, row['Ground Truth'],
                                           'Enter the output manually', f"Error: {str(e)}", img_path))
                progress(processed_count / total_images,
                         desc=f"β οΈ Processed {processed_count}/{total_images} images (with errors)")

        if self.check_stop_flag():
            # Stop requested after the last image started: processed_count is exact,
            # unlike scanning output text for the word 'stopped'.
            final_message = f"π Processing stopped by user. Completed {processed_count}/{total_images} images."
        else:
            final_message = f"π Successfully completed processing all {total_images} images!"
        display_df = self._publish_rows(rows)
        self.save_results_to_s3(display_df)
        print(final_message)
        return self._results_outputs(display_df, final_message)

    def rerun_with_new_prompt(
        self,
        df: pd.DataFrame,
        new_prompt: str,
        quantization_type: str,
        model_selection: str,
        progress: gr.Progress = gr.Progress(),
    ) -> Tuple[Any, ...]:
        """
        Re-run inference on the current results table with a new prompt.

        Args:
            df: The results table from the UI (display columns, possibly with
                user-edited 'Ground Truth').
            new_prompt: Prompt to use for every image.
            quantization_type: Model quantization type.
            model_selection: Selected model name.
            progress: Gradio progress tracker. The default lets Gradio inject one
                when this method is bound to an event directly.

        Returns:
            A seven-tuple with these elements:
            - the updated table (or ``df`` unchanged on early exit);
            - three ``None`` values that clear the accuracy outputs;
            - two hidden ``gr.update`` values;
            - a status message.
        """
        if df is None or not new_prompt or not new_prompt.strip():
            return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "β οΈ Please provide a valid prompt"

        # Reset stop flag at the beginning of reprocessing
        self.reset_stop_flag()
        updated_df = df.copy()
        total_images = len(updated_df)
        # Every non-empty ground truth, used for keyword extraction in extract_binary_output.
        all_ground_truths = [str(gt) for gt in updated_df['Ground Truth']
                             if pd.notna(gt) and str(gt).strip()]

        # Image paths live only in full_df; the UI table does not carry them.
        if self.full_df is None:
            return df, None, None, None, gr.update(visible=False), gr.update(visible=False), "β οΈ No image data available"
        updated_full_df = self.full_df.copy()

        self._ensure_correct_model_loaded(model_selection, quantization_type, progress)
        progress(0, desc=f"π Starting to reprocess {total_images} images with new prompt...")
        print(f"π Starting to reprocess {total_images} images with new prompt...")

        display_model_col = updated_df.columns.get_loc("Model Output")
        display_binary_col = updated_df.columns.get_loc("Binary Output")
        full_model_col = updated_full_df.columns.get_loc("Model Output")
        full_binary_col = updated_full_df.columns.get_loc("Binary Output")
        full_row_count = len(updated_full_df)

        def write_outputs(pos: int, model_text: Any, binary_text: Any) -> None:
            """Write both output cells for row ``pos`` in the UI table and in full_df."""
            updated_df.iloc[pos, display_model_col] = model_text
            updated_df.iloc[pos, display_binary_col] = binary_text
            # The UI table may hold more rows than full_df; never index past its end.
            if pos < full_row_count:
                updated_full_df.iloc[pos, full_model_col] = model_text
                updated_full_df.iloc[pos, full_binary_col] = binary_text

        for i in range(total_images):
            if self.check_stop_flag():
                print(f"π Reprocessing stopped by user at image {i + 1}/{total_images}")
                for j in range(i, total_images):
                    write_outputs(j, "Reprocessing stopped by user", "Not reprocessed (stopped)")
                self.full_df = updated_full_df
                final_message = f"π Reprocessing stopped by user. Completed {i}/{total_images} images."
                print(final_message)
                return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

            try:
                image_path = self.full_df.iloc[i]['Image Path']
                image_name = updated_df.iloc[i]['Image Name']
                gt_value = updated_df.iloc[i]['Ground Truth']
                ground_truth = str(gt_value) if pd.notna(gt_value) else ""

                progress_msg = self._item_label("π Reprocessing image", i + 1, total_images, image_name)
                progress(i / total_images, desc=progress_msg)
                print(progress_msg)

                model_output = self.model_manager.inference(image_path, new_prompt)
                binary_output = extract_binary_output(model_output, ground_truth, all_ground_truths)
                write_outputs(i, model_output, binary_output)

                progress((i + 1) / total_images, desc=f"β Completed {i + 1}/{total_images} images")
                print(f"β Successfully reprocessed image {i + 1}/{total_images}")
            except Exception as e:
                # Best-effort per image: record the error and continue with the batch.
                print(f"β Error reprocessing image {i + 1}/{total_images}: {str(e)}")
                write_outputs(i, f"Error: {str(e)}", "Enter the output manually")
                progress((i + 1) / total_images,
                         desc=f"β οΈ Processed {i + 1}/{total_images} images (with errors)")

        self.full_df = updated_full_df
        if self.check_stop_flag():
            final_message = "π Reprocessing stopped by user. Completed reprocessing for some images."
        else:
            final_message = f"π Successfully completed reprocessing all {total_images} images with new prompt! Click 'Generate Metrics' to see accuracy data."
        # Upload the UI table, which has the same columns as the other paths upload.
        # Its 'Ground Truth' is the user-editable one. full_df still holds the
        # ground truth from the original run, so metrics computed on it would be stale.
        self.save_results_to_s3(updated_df)
        print(final_message)
        # Return updated dataframe and clear accuracy data (hide section 3)
        return updated_df, None, None, None, gr.update(visible=False), gr.update(visible=False), final_message

    # ------------------------------------------------------------------ #
    # S3 persistence
    # ------------------------------------------------------------------ #
    def save_results_to_s3(self, df: pd.DataFrame) -> None:
        """Upload ``df`` as a CSV plus an evaluation-metrics text report to S3.

        Files are written to a temporary directory that is always removed, and
        are uploaded under ``<AWS_PREFIX>/<YYYY-MM-DD>/`` in the ``AWS_BUCKET``
        bucket. When ``create_accuracy_table`` fails with a "No valid data" or
        "Need at least 2 different" error, only the CSV is uploaded. All
        errors are logged and swallowed, so a failed upload never aborts the
        processing run that produced the results.
        """
        import tempfile  # local import: only this method needs it

        s3_bucket = os.getenv('AWS_BUCKET')
        if not s3_bucket:
            print("AWS_BUCKET is not set; skipping upload of results to s3.")
            return
        prefix = os.getenv('AWS_PREFIX')
        now = datetime.now()  # single timestamp so the folder date and file names agree
        # Avoid a literal "None/" (or a leading "/") key segment when no prefix is configured.
        s3_path = f"{prefix}/{now.date()}" if prefix else str(now.date())
        date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
        csv_file_name = f'{date_time}_model_output.csv'
        text_file_name = f'{date_time}_evaluation_metrics.txt'

        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                csv_path = os.path.join(tmp_dir, csv_file_name)
                try:
                    metrics_df, _, cm_values = create_accuracy_table(df)
                except Exception as e:
                    print(f"Error saving results to s3: {e}")
                    # Metrics cannot be computed for this data; still keep the raw outputs.
                    if "No valid data" in str(e) or "Need at least 2 different" in str(e):
                        df.to_csv(csv_path, index=False)
                        status = self.upload_file(csv_path, s3_bucket, f"{s3_path}/{csv_file_name}")
                        print(f"Status of uploading only csv file to {s3_bucket}/{s3_path}/{csv_file_name}: {status}")
                    return

                text_path = os.path.join(tmp_dir, text_file_name)
                with open(text_path, 'w', encoding='utf-8') as f:
                    f.write(metrics_df.to_string() + '\n\n')
                    f.write(cm_values.to_string())
                df.to_csv(csv_path, index=False)

                for local_path, file_name in ((text_path, text_file_name), (csv_path, csv_file_name)):
                    status = self.upload_file(local_path, s3_bucket, f"{s3_path}/{file_name}")
                    print(f"Status of uploading {file_name} to {s3_bucket}/{s3_path}/{file_name}: {status}")
        except Exception as e:
            print(f"Error saving results to s3: {e}")

    def upload_file(self, file_name: str, bucket: str, object_name: Optional[str] = None) -> bool:
        """Upload a file to an S3 bucket.

        Credentials come from ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``.
        When these are unset, boto3 falls back to its default credential chain.

        :param file_name: File to upload
        :param bucket: Bucket to upload to
        :param object_name: S3 object name. If not specified then the base name of file_name is used
        :return: True if file was uploaded, else False (the error is printed)
        """
        if object_name is None:
            object_name = os.path.basename(file_name)
        try:
            # Client creation is inside the try so configuration errors also yield False.
            s3_client = boto3.client(
                's3',
                aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
                aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
            )
            s3_client.upload_file(file_name, bucket, object_name)
        except Exception as e:
            print(f"Error uploading {file_name} to s3: {e}")
            return False
        return True