"""Gradio app that converts uploaded PyTorch / TensorFlow models to SafeTensors."""

import os
import shutil
import tempfile

import gradio as gr
import tensorflow as tf
import torch
from safetensors.torch import save_file as torch_save_file

try:
    # NOTE(review): released safetensors versions do not appear to ship a
    # `safetensors.keras` module - confirm. The import is kept for
    # environments that provide it, but guarded so that its absence no longer
    # crashes the whole app at import time.
    from safetensors.keras import save_model as keras_save_model
except ImportError:
    keras_save_model = None

OUTPUT_FILENAME = "model.safetensors"
SUPPORTED_FRAMEWORKS = ("PyTorch", "TensorFlow")

# Wrapper keys commonly used by training checkpoints around the real weights.
_CHECKPOINT_KEYS = ("state_dict", "model_state_dict")


def _resolve_upload_path(model_file):
    """Return the filesystem path of a Gradio upload.

    Depending on the Gradio version, ``gr.File`` hands the callback either a
    path string or a tempfile-like object exposing the path as ``.name``.
    """
    if isinstance(model_file, (str, os.PathLike)):
        return os.fspath(model_file)
    return model_file.name


def _load_torch_state_dict(model_path):
    """Load ``model_path`` and return a ``{name: tensor}`` dict ready to save.

    Raises:
        ValueError: If the file does not contain any tensors.
    """
    # weights_only=True refuses to unpickle arbitrary objects, which matters
    # because the file comes from an untrusted upload.
    loaded = torch.load(model_path, map_location="cpu", weights_only=True)

    # A full module (or anything exposing state_dict()) was saved.
    if hasattr(loaded, "state_dict"):
        loaded = loaded.state_dict()

    # Training checkpoints usually nest the weights under a well-known key.
    if isinstance(loaded, dict):
        for key in _CHECKPOINT_KEYS:
            nested = loaded.get(key)
            if isinstance(nested, dict):
                loaded = nested
                break

    if not isinstance(loaded, dict):
        raise ValueError(
            f"Expected a state_dict, got {type(loaded).__name__}."
        )

    # safetensors only stores tensors, and rejects non-contiguous tensors or
    # tensors sharing memory (tied weights). Non-tensor entries such as an
    # epoch counter are dropped; clone() gives every tensor its own storage.
    tensors = {
        str(name): value.detach().clone().contiguous()
        for name, value in loaded.items()
        if isinstance(value, torch.Tensor)
    }
    if not tensors:
        raise ValueError("No tensors were found in the uploaded file.")
    return tensors


def _save_keras_model(model, output_path):
    """Write the weights of a Keras ``model`` to ``output_path``."""
    if keras_save_model is not None:
        keras_save_model(model, output_path)
        return

    # Fallback: export every weight as a NumPy array.
    from safetensors.numpy import save_file as numpy_save_file

    arrays = {}
    for index, weight in enumerate(model.weights):
        # Keras 3 exposes a unique `path`; older versions only have `name`.
        key = getattr(weight, "path", None) or weight.name
        if key in arrays:
            key = f"{key}#{index}"
        arrays[key] = weight.numpy()
    numpy_save_file(arrays, output_path)


def convert_to_safetensors(framework, model_file):
    """Convert an uploaded model file to the SafeTensors format.

    Args:
        framework: ``"PyTorch"`` or ``"TensorFlow"``.
        model_file: The value supplied by the ``gr.File`` input (a path
            string or a tempfile-like object with a ``.name`` attribute).

    Returns:
        Path of the generated ``model.safetensors`` file.

    Raises:
        gr.Error: If no file was uploaded, the framework is unknown, or the
            conversion fails. Gradio shows the message to the user.
    """
    # gr.Error is an exception: it must be raised, not returned, otherwise
    # Gradio tries to treat it as the output file.
    if not model_file:
        raise gr.Error("Please upload a model file.")
    if framework not in SUPPORTED_FRAMEWORKS:
        raise gr.Error(
            "Please select a valid framework (PyTorch or TensorFlow)."
        )

    model_path = _resolve_upload_path(model_file)

    # A fresh directory per request so concurrent conversions cannot
    # overwrite each other's output.
    output_dir = tempfile.mkdtemp(prefix="safetensors_")
    output_path = os.path.join(output_dir, OUTPUT_FILENAME)

    try:
        if framework == "PyTorch":
            torch_save_file(_load_torch_state_dict(model_path), output_path)
        else:
            # NOTE(review): the UI advertises zipped SavedModel uploads, but
            # nothing here extracts the archive before loading - confirm
            # whether that path is meant to be supported.
            model = tf.keras.models.load_model(model_path)
            _save_keras_model(model, output_path)
    except Exception as e:
        shutil.rmtree(output_dir, ignore_errors=True)
        error_msg = f"{framework} Conversion Error: {str(e)}"
        if framework == "PyTorch":
            error_msg += "\n\nTips:\n• Ensure the file is a valid PyTorch model (.pt, .pth)\n• Model should contain state_dict or be loadable with torch.load()"
        elif framework == "TensorFlow":
            error_msg += "\n\nTips:\n• Ensure the file is a valid TensorFlow model (.h5, SavedModel)\n• For SavedModel format, upload as a zip file containing the model directory"
        raise gr.Error(error_msg) from e

    return output_path


# Create the Gradio interface
with gr.Blocks(
    title="SafeTensors Model Converter",
    theme=gr.themes.Soft()
) as iface:
    gr.Markdown("""
    # 🔒 No-Code SafeTensors Model Creator

    Convert your machine learning models to the secure **SafeTensors** format with zero coding required!

    ## Why SafeTensors?
    - **Security**: Prevents arbitrary code execution during model loading
    - **Speed**: Faster loading times compared to pickle-based formats
    - **Memory Efficiency**: Zero-copy deserialization
    - **Cross-Platform**: Works across different ML frameworks

    ## Supported Formats
    - **PyTorch**: `.pt`, `.pth` files containing model weights
    - **TensorFlow**: `.h5` files or SavedModel directories (as zip)
    """)

    with gr.Row():
        with gr.Column():
            framework_dropdown = gr.Dropdown(
                choices=["PyTorch", "TensorFlow"],
                label="🔧 Select Framework",
                info="Choose the framework your model was trained with",
                value="PyTorch"
            )

            # gr.File does not take an `info` argument in recent Gradio
            # releases, so the hint is rendered as Markdown next to it.
            model_upload = gr.File(
                label="📁 Upload Model File",
                file_types=[".pt", ".pth", ".h5", ".zip"]
            )
            gr.Markdown(
                "Upload your model file (.pt/.pth for PyTorch, .h5 for TensorFlow)"
            )

            convert_btn = gr.Button(
                "🚀 Convert to SafeTensors",
                variant="primary",
                size="lg"
            )

        with gr.Column():
            output_file = gr.File(
                label="💾 Download SafeTensors File"
            )
            gr.Markdown("Your converted model will appear here")

            gr.Markdown("""
            ### 📋 Usage Instructions
            1. **Select Framework**: Choose PyTorch or TensorFlow
            2. **Upload Model**: Select your model file from your computer
            3. **Convert**: Click the convert button
            4. **Download**: Get your secure SafeTensors file

            ### ⚠️ Important Notes
            - Only model weights are converted (no training code)
            - Original model architecture code is still needed for inference
            - Conversion preserves all tensor data and metadata
            """)

    # Set up the conversion event
    convert_btn.click(
        fn=convert_to_safetensors,
        inputs=[framework_dropdown, model_upload],
        outputs=output_file,
        show_progress=True
    )

    gr.Markdown("""
    ---
    ### 🛡️ Security Benefits
    SafeTensors format eliminates security risks associated with pickle-based model formats by:
    - Storing only tensor data (no executable code)
    - Using a simple, well-defined file format
    - Enabling safe model sharing and deployment

    ### 🔗 Learn More
    - [SafeTensors Documentation](https://huggingface.co/docs/safetensors)
    - [Hugging Face Model Hub](https://huggingface.co/models)
    """)

# For Hugging Face Spaces deployment
if __name__ == "__main__":
    iface.launch()