import gradio as gr
import torch
from safetensors.torch import save_file as torch_save_file
import tensorflow as tf
from safetensors.tensorflow import save_file as tf_save_file
import os
import tempfile
import zipfile

def convert_to_safetensors(framework, model_file):
    """
    Convert an uploaded model file to the SafeTensors format.
    """
    if not model_file:
        raise gr.Error("Please upload a model file.")

    # Write the converted weights to a temporary location
    output_filename = os.path.join(tempfile.gettempdir(), "model.safetensors")

    try:
        if framework == "PyTorch":
            # Load PyTorch weights safely (weights_only avoids executing pickled code)
            state_dict = torch.load(
                model_file,
                map_location='cpu',
                weights_only=True
            )
            # Handle the case where a full model or a checkpoint dict was saved
            # instead of a bare state_dict
            if isinstance(state_dict, torch.nn.Module) or hasattr(state_dict, 'state_dict'):
                state_dict = state_dict.state_dict()
            elif isinstance(state_dict, dict) and 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']
            # SafeTensors can only store tensors, so drop any non-tensor entries
            state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
            # Save to SafeTensors format
            torch_save_file(state_dict, output_filename)
            return output_filename
        elif framework == "TensorFlow":
            model_path = model_file
            # A zipped SavedModel directory must be extracted before loading
            if model_file.endswith(".zip"):
                extract_dir = tempfile.mkdtemp()
                with zipfile.ZipFile(model_file) as archive:
                    archive.extractall(extract_dir)
                model_path = extract_dir
                for root, _dirs, files in os.walk(extract_dir):
                    if "saved_model.pb" in files:
                        model_path = root
                        break
            # Load the TensorFlow/Keras model
            model = tf.keras.models.load_model(model_path)
            # SafeTensors stores a flat dict of named tensors, so collect the model's variables
            tensors = {}
            for index, variable in enumerate(model.weights):
                name = getattr(variable, "path", None) or variable.name or f"weight_{index}"
                if name in tensors:
                    name = f"{name}_{index}"
                tensors[name] = tf.convert_to_tensor(variable)
            # Save to SafeTensors format
            tf_save_file(tensors, output_filename)
            return output_filename
        else:
            raise gr.Error("Please select a valid framework (PyTorch or TensorFlow).")
    except gr.Error:
        # Re-raise Gradio errors unchanged so their messages reach the UI
        raise
    except Exception as e:
        error_msg = f"{framework} Conversion Error: {str(e)}"
        if framework == "PyTorch":
            error_msg += "\n\nTips:\n• Ensure the file is a valid PyTorch model (.pt, .pth)\n• The file should contain a state_dict or be loadable with torch.load()"
        elif framework == "TensorFlow":
            error_msg += "\n\nTips:\n• Ensure the file is a valid TensorFlow model (.h5, SavedModel)\n• For SavedModel format, upload a zip file containing the model directory"
        raise gr.Error(error_msg)

# Create the Gradio interface
with gr.Blocks(
title="SafeTensors Model Converter",
theme=gr.themes.Soft()
) as iface:
gr.Markdown("""
# πŸ”’ No-Code SafeTensors Model Creator
Convert your machine learning models to the secure **SafeTensors** format with zero coding required!
## Why SafeTensors?
- **Security**: Prevents arbitrary code execution during model loading
- **Speed**: Faster loading times compared to pickle-based formats
- **Memory Efficiency**: Zero-copy deserialization
- **Cross-Platform**: Works across different ML frameworks
## Supported Formats
- **PyTorch**: `.pt`, `.pth` files containing model weights (see the example below)
- **TensorFlow**: `.h5` files or SavedModel directories (as zip)
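
Not sure what to upload? For PyTorch, a compatible weights file is usually produced like this (a minimal sketch; `model` stands for your own trained model):

```python
import torch

# Save only the weights (the state_dict), not the whole pickled model object
torch.save(model.state_dict(), "model.pth")
```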
""")
    with gr.Row():
        with gr.Column():
            framework_dropdown = gr.Dropdown(
                choices=["PyTorch", "TensorFlow"],
                label="🔧 Select Framework",
                info="Choose the framework your model was trained with",
                value="PyTorch"
            )
            model_upload = gr.File(
                label="📁 Upload Model File (.pt/.pth for PyTorch, .h5 or zipped SavedModel for TensorFlow)",
                file_types=[".pt", ".pth", ".h5", ".zip"]
            )
            convert_btn = gr.Button(
                "🚀 Convert to SafeTensors",
                variant="primary",
                size="lg"
            )
        with gr.Column():
            output_file = gr.File(
                label="💾 Download SafeTensors File (your converted model will appear here)"
            )
            gr.Markdown("""
### 📋 Usage Instructions
1. **Select Framework**: Choose PyTorch or TensorFlow
2. **Upload Model**: Select your model file from your computer
3. **Convert**: Click the convert button
4. **Download**: Get your secure SafeTensors file
### ⚠️ Important Notes
- Only model weights are converted (no training code)
- Original model architecture code is still needed for inference (see the loading example below)
- Conversion preserves tensor names, shapes, and values (optimizer state and training configuration are not carried over)
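
For example, a converted PyTorch model can usually be loaded back into your own architecture like this (a minimal sketch; `MyModel` is a placeholder for your model class):

```python
from safetensors.torch import load_file

model = MyModel()                            # your architecture code
state_dict = load_file("model.safetensors")  # safe: loads tensors only
model.load_state_dict(state_dict)
```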
""")
    # Set up the conversion event
    convert_btn.click(
        fn=convert_to_safetensors,
        inputs=[framework_dropdown, model_upload],
        outputs=output_file,
        show_progress=True
    )
gr.Markdown("""
---
### πŸ›‘οΈ Security Benefits
SafeTensors format eliminates security risks associated with pickle-based model formats by:
- Storing only tensor data (no executable code)
- Using a simple, well-defined file format
- Enabling safe model sharing and deployment
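
For example, you can list the tensors in a converted file without loading (or executing) anything beyond the file header (a minimal sketch using `safetensors.safe_open`):

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    print(list(f.keys()))  # tensor names only; no code is executed
```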
### 🔗 Learn More
- [SafeTensors Documentation](https://huggingface.co/docs/safetensors)
- [Hugging Face Model Hub](https://huggingface.co/models)
""")
# For Hugging Face Spaces deployment
if __name__ == "__main__":
    iface.launch()