Ahmedbelaid1 committed on
Commit e065062 · verified · 1 Parent(s): 7e74ef2

Upload 41 files

Files changed (41)
  1. .dockerignore +55 -0
  2. .gitattributes +2 -35
  3. .gitignore +164 -0
  4. API_README.md +209 -0
  5. DEPLOYMENT_GUIDE.md +271 -0
  6. Dockerfile +56 -0
  7. LICENSE +21 -0
  8. REACT_INTEGRATION.md +549 -0
  9. README.md +143 -11
  10. README_DOCKER.md +158 -0
  11. README_HF_SPACES.md +143 -0
  12. api_example.html +194 -0
  13. api_server.py +403 -0
  14. app.py +171 -0
  15. deploy_colab.ipynb +268 -0
  16. docker-compose.yml +31 -0
  17. gradio_app.py +187 -0
  18. requirements.txt +16 -0
  19. run.py +197 -0
  20. tsr/__pycache__/bake_texture.cpython-313.pyc +0 -0
  21. tsr/__pycache__/system.cpython-313.pyc +0 -0
  22. tsr/__pycache__/utils.cpython-313.pyc +0 -0
  23. tsr/bake_texture.py +191 -0
  24. tsr/models/__pycache__/isosurface.cpython-313.pyc +0 -0
  25. tsr/models/__pycache__/nerf_renderer.cpython-313.pyc +0 -0
  26. tsr/models/__pycache__/network_utils.cpython-313.pyc +0 -0
  27. tsr/models/isosurface.py +64 -0
  28. tsr/models/nerf_renderer.py +180 -0
  29. tsr/models/network_utils.py +124 -0
  30. tsr/models/tokenizers/__pycache__/image.cpython-313.pyc +0 -0
  31. tsr/models/tokenizers/__pycache__/triplane.cpython-313.pyc +0 -0
  32. tsr/models/tokenizers/image.py +66 -0
  33. tsr/models/tokenizers/triplane.py +45 -0
  34. tsr/models/transformer/__pycache__/attention.cpython-313.pyc +0 -0
  35. tsr/models/transformer/__pycache__/basic_transformer_block.cpython-313.pyc +0 -0
  36. tsr/models/transformer/__pycache__/transformer_1d.cpython-313.pyc +0 -0
  37. tsr/models/transformer/attention.py +653 -0
  38. tsr/models/transformer/basic_transformer_block.py +334 -0
  39. tsr/models/transformer/transformer_1d.py +219 -0
  40. tsr/system.py +205 -0
  41. tsr/utils.py +510 -0
.dockerignore ADDED
@@ -0,0 +1,55 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ *.egg-info/
+ dist/
+ build/
+ *.egg
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # Git
+ .git/
+ .gitignore
+
+ # Docker
+ Dockerfile
+ .dockerignore
+
+ # Output files
+ output/
+ *.obj
+ *.glb
+ *.zip
+ *.png
+ *.jpg
+ *.jpeg
+
+ # Logs
+ *.log
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Model cache (will be downloaded at runtime)
+ .cache/
+
+ # Temporary files
+ *.tmp
+ temp/
.gitattributes CHANGED
@@ -1,35 +1,2 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Auto detect text files and perform LF normalization
+ * text=auto
.gitignore ADDED
@@ -0,0 +1,164 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # default output directory
+ output/
+ outputs/
API_README.md ADDED
@@ -0,0 +1,209 @@
+ # TripoSR REST API
+
+ A FastAPI-based REST API server for TripoSR 3D mesh generation.
+
+ ## Starting the Server
+
+ ```bash
+ python api_server.py
+ ```
+
+ Or with uvicorn directly:
+
+ ```bash
+ uvicorn api_server:app --host 0.0.0.0 --port 8000
+ ```
+
+ The API will be available at `http://localhost:8000`.
+
+ ## API Endpoints
+
+ ### 1. Health Check
+
+ **GET** `/health`
+
+ Check if the API is running and get device information.
+
+ **Response:**
+ ```json
+ {
+   "status": "healthy",
+   "device": "cuda:0",
+   "cuda_available": true
+ }
+ ```
+
+ ### 2. Generate Mesh (File Download)
+
+ **POST** `/generate`
+
+ Generate a 3D mesh from an uploaded image and download the mesh file.
+
+ **Parameters:**
+ - `image` (file, required): Image file (PNG, JPG, JPEG)
+ - `do_remove_background` (boolean, default: true): Whether to remove the background
+ - `foreground_ratio` (float, default: 0.85): Ratio of foreground size (0.5-1.0)
+ - `mc_resolution` (int, default: 256): Marching cubes resolution (128, 160, 192, 224, 256, 288, 320)
+ - `format` (string, default: "obj"): Output format - "obj" or "glb"
+
+ **Response:** Mesh file download
+
+ **Example (cURL):**
+ ```bash
+ curl -X POST "http://localhost:8000/generate" \
+   -F "image=@chair.png" \
+   -F "do_remove_background=true" \
+   -F "foreground_ratio=0.85" \
+   -F "mc_resolution=256" \
+   -F "format=obj" \
+   --output mesh.obj
+ ```
+
+ ### 3. Generate Mesh (Base64)
+
+ **POST** `/generate-base64`
+
+ Generate a 3D mesh and return it as a base64-encoded string.
+
+ **Parameters:** Same as `/generate`
+
+ **Response:**
+ ```json
+ {
+   "success": true,
+   "format": "obj",
+   "mesh": "base64_encoded_mesh_data...",
+   "size": 1234567
+ }
+ ```
+
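+ The base64 payload can be decoded back into a binary mesh file on the client. A minimal sketch using `requests`, assuming only the response shape shown above:
+
+ ```python
+ import base64
+ import requests
+
+ with open("chair.png", "rb") as f:
+     resp = requests.post(
+         "http://localhost:8000/generate-base64",
+         files={"image": f},
+         data={"format": "obj"},
+     )
+ resp.raise_for_status()
+
+ payload = resp.json()
+ if payload.get("success"):
+     # "mesh" holds the base64 string; decode it back to raw bytes
+     mesh_bytes = base64.b64decode(payload["mesh"])
+     with open(f"mesh.{payload['format']}", "wb") as out:
+         out.write(mesh_bytes)
+ ```
+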
+ ## Frontend Integration Examples
+
+ ### JavaScript/Fetch
+
+ ```javascript
+ const formData = new FormData();
+ formData.append('image', imageFile);
+ formData.append('do_remove_background', true);
+ formData.append('foreground_ratio', 0.85);
+ formData.append('mc_resolution', 256);
+ formData.append('format', 'obj');
+
+ const response = await fetch('http://localhost:8000/generate', {
+   method: 'POST',
+   body: formData
+ });
+
+ if (response.ok) {
+   const blob = await response.blob();
+   const url = window.URL.createObjectURL(blob);
+   // Download or use the mesh file
+   const a = document.createElement('a');
+   a.href = url;
+   a.download = 'mesh.obj';
+   a.click();
+ }
+ ```
+
+ ### React Example
+
+ ```jsx
+ import { useState } from 'react';
+
+ function MeshGenerator() {
+   const [loading, setLoading] = useState(false);
+
+   const generateMesh = async (imageFile) => {
+     setLoading(true);
+     const formData = new FormData();
+     formData.append('image', imageFile);
+     formData.append('do_remove_background', true);
+     formData.append('foreground_ratio', 0.85);
+     formData.append('mc_resolution', 256);
+     formData.append('format', 'obj');
+
+     try {
+       const response = await fetch('http://localhost:8000/generate', {
+         method: 'POST',
+         body: formData
+       });
+
+       if (response.ok) {
+         const blob = await response.blob();
+         // Handle the mesh file
+         const url = window.URL.createObjectURL(blob);
+         // Download or display
+       }
+     } catch (error) {
+       console.error('Error:', error);
+     } finally {
+       setLoading(false);
+     }
+   };
+
+   return (
+     <div>
+       <input
+         type="file"
+         accept="image/*"
+         onChange={(e) => generateMesh(e.target.files[0])}
+       />
+       {loading && <p>Generating mesh...</p>}
+     </div>
+   );
+ }
+ ```
+
+ ### Python Client Example
+
+ ```python
+ import requests
+
+ url = "http://localhost:8000/generate"
+
+ with open("chair.png", "rb") as f:
+     files = {"image": f}
+     data = {
+         "do_remove_background": True,
+         "foreground_ratio": 0.85,
+         "mc_resolution": 256,
+         "format": "obj"
+     }
+     response = requests.post(url, files=files, data=data)
+
+ if response.status_code == 200:
+     with open("mesh.obj", "wb") as out:
+         out.write(response.content)
+     print("Mesh saved to mesh.obj")
+ ```
+
+ ## CORS Configuration
+
+ The API is configured to allow CORS from all origins by default. For production, update `allow_origins` in `api_server.py`:
+
+ ```python
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["https://your-frontend-domain.com"],  # Your frontend URL
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+ ```
+
+ ## Performance Notes
+
+ - Model initialization takes ~18-20 seconds on the first request
+ - Mesh generation typically takes 30-60 seconds depending on resolution
+ - A GPU (CUDA) is recommended for faster processing
+ - Consider implementing request queuing for production use (see the sketch below)
+
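+ One simple way to approximate queuing is to cap concurrent generations with an asyncio semaphore, so overlapping requests wait their turn instead of exhausting GPU memory. This is a minimal sketch, not part of `api_server.py`; the `/generate-queued` route and the `run_pipeline` helper are placeholders:
+
+ ```python
+ import asyncio
+ from fastapi import FastAPI, UploadFile
+
+ app = FastAPI()
+
+ # Allow one generation at a time; extra requests queue up here
+ GPU_SLOTS = asyncio.Semaphore(1)
+
+ def run_pipeline(image_bytes: bytes) -> dict:
+     # Placeholder for the blocking TripoSR inference call
+     return {"received_bytes": len(image_bytes)}
+
+ @app.post("/generate-queued")  # hypothetical endpoint, for illustration only
+ async def generate_queued(image: UploadFile):
+     async with GPU_SLOTS:
+         # Run the blocking call in a thread so the event loop stays responsive
+         return await asyncio.to_thread(run_pipeline, await image.read())
+ ```
+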
+ ## Error Handling
+
+ All endpoints return appropriate HTTP status codes:
+ - `200`: Success
+ - `400`: Bad request (invalid parameters)
+ - `500`: Server error (model processing failed)
+
+ Error responses include a JSON body with a `detail` field describing the error.
+
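+ Client code can read that field to report what went wrong. A short sketch with `requests`, assuming the server above is running locally:
+
+ ```python
+ import requests
+
+ with open("chair.png", "rb") as f:
+     resp = requests.post("http://localhost:8000/generate", files={"image": f})
+
+ if resp.status_code == 200:
+     with open("mesh.obj", "wb") as out:
+         out.write(resp.content)
+ else:
+     # Error bodies are JSON with a human-readable "detail" field
+     try:
+         print("Generation failed:", resp.json().get("detail"))
+     except ValueError:
+         print("Generation failed with HTTP status", resp.status_code)
+ ```
+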
DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,271 @@
+ # 🚀 Free Cloud Deployment Guide for TripoSR API
+
+ This guide covers multiple **FREE** options to deploy your TripoSR API in the cloud.
+
+ ## 📋 Table of Contents
+ 1. [Hugging Face Spaces (Recommended)](#1-hugging-face-spaces-recommended)
+ 2. [Google Colab](#2-google-colab)
+ 3. [Render.com (CPU Only)](#3-rendercom-cpu-only)
+ 4. [Railway.app](#4-railwayapp)
+
+ ---
+
+ ## 1. Hugging Face Spaces (Recommended) ⭐
+
+ **Best option** for this project - offers free GPU access and is designed for ML models.
+
+ ### Features:
+ - ✅ Free GPU (T4) available
+ - ✅ Persistent deployment
+ - ✅ Built-in CI/CD
+ - ✅ Public API endpoint
+ - ✅ Great for ML models
+
+ ### Setup Steps:
+
+ 1. **Create a Hugging Face account** at [huggingface.co](https://huggingface.co)
+
+ 2. **Create a new Space:**
+    - Go to https://huggingface.co/spaces
+    - Click "Create new Space"
+    - Name: `triposr-api`
+    - License: MIT
+    - Select SDK: **Docker**
+    - Hardware: **CPU basic** (start here, upgrade to GPU if needed)
+
+ 3. **Clone your Space repository:**
+    ```bash
+    git clone https://huggingface.co/spaces/YOUR_USERNAME/triposr-api
+    cd triposr-api
+    ```
+
+ 4. **Copy these files to the Space:**
+    - Copy `Dockerfile.huggingface` as `Dockerfile`
+    - Copy `api_server.py`
+    - Copy `requirements.txt`
+    - Copy the entire `tsr/` directory
+    - Copy `README.md`
+
+ 5. **Create a `README.md` header** (required by HF Spaces):
+    ```markdown
+    ---
+    title: TripoSR API
+    emoji: 🎨
+    colorFrom: blue
+    colorTo: purple
+    sdk: docker
+    pinned: false
+    license: mit
+    ---
+
+    # TripoSR API
+
+    Fast 3D reconstruction from a single image.
+    ```
+
+ 6. **Push to Hugging Face:**
+    ```bash
+    git add .
+    git commit -m "Initial deployment"
+    git push
+    ```
+
+ 7. **Your API will be available at:**
+    ```
+    https://YOUR_USERNAME-triposr-api.hf.space
+    ```
+
+ ### Upgrade to GPU (if needed):
+ - Go to your Space settings
+ - Under "Hardware", select **T4 small** (free tier)
+ - The Space will rebuild automatically
+
+ ---
+
+ ## 2. Google Colab
+
+ **Good for testing** - Free GPU but sessions expire after inactivity.
+
+ ### Features:
+ - ✅ Free GPU (T4/K80)
+ - ❌ Sessions expire (12-hour limit)
+ - ❌ Not persistent
+ - ✅ Good for testing/demos
+
+ ### Setup Steps:
+
+ 1. **Upload the Colab notebook:**
+    - Open [Google Colab](https://colab.research.google.com)
+    - Upload `deploy_colab.ipynb` (provided in this repo)
+
+ 2. **Run the notebook:**
+    - Enable GPU: Runtime → Change runtime type → GPU
+    - Run all cells
+    - The notebook will install dependencies and start the server
+
+ 3. **Access via ngrok tunnel:**
+    - The notebook creates a public URL using ngrok
+    - The URL will be displayed in the output
+    - Example: `https://abc123.ngrok.io`
+
+ **Note:** The URL changes every time you restart the notebook.
+
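+ The tunnel cell typically looks something like the sketch below, assuming `pyngrok` is installed in the runtime (the actual cell in `deploy_colab.ipynb` may differ):
+
+ ```python
+ # Colab cell: start the API in the background, then open a public tunnel
+ import subprocess
+ from pyngrok import ngrok
+
+ server = subprocess.Popen(
+     ["python", "-m", "uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "8000"]
+ )
+
+ # ngrok forwards a public URL to localhost:8000
+ tunnel = ngrok.connect(8000)
+ print("Public API URL:", tunnel.public_url)
+ ```
+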
+ ---
+
+ ## 3. Render.com (CPU Only)
+
+ **Limited functionality** - The free tier is CPU only, so 3D generation will be VERY slow.
+
+ ### Features:
+ - ✅ Free tier available
+ - ✅ Persistent deployment
+ - ❌ CPU only (slow inference)
+ - ✅ Auto-deploy from GitHub
+
+ ### Setup Steps:
+
+ 1. **Push code to GitHub:**
+    ```bash
+    git init
+    git add .
+    git commit -m "Initial commit"
+    git remote add origin YOUR_GITHUB_REPO_URL
+    git push -u origin main
+    ```
+
+ 2. **Create a Render account** at [render.com](https://render.com)
+
+ 3. **Create a new Web Service:**
+    - Click "New +" → "Web Service"
+    - Connect your GitHub repository
+    - Name: `triposr-api`
+    - Environment: **Docker**
+    - Plan: **Free**
+
+ 4. **Configure:**
+    - Build Command: (leave empty - using Dockerfile)
+    - Start Command: (leave empty - using Dockerfile CMD)
+
+ 5. **Deploy:**
+    - Click "Create Web Service"
+    - Wait for the build to complete (~10-15 minutes)
+
+ 6. **Your API will be available at:**
+    ```
+    https://triposr-api.onrender.com
+    ```
+
+ **⚠️ Warning:** CPU-only inference will be 10-50x slower than GPU!
+
+ ---
+
+ ## 4. Railway.app
+
+ **Trial credits** - $5 in free credits, then paid.
+
+ ### Features:
+ - ✅ $5 free trial credits
+ - ✅ Easy deployment
+ - ❌ No free GPU
+ - ✅ Good for CPU testing
+
+ ### Setup Steps:
+
+ 1. **Create a Railway account** at [railway.app](https://railway.app)
+
+ 2. **Create a new project:**
+    - Click "New Project"
+    - Select "Deploy from GitHub repo"
+    - Connect your repository
+
+ 3. **Configure:**
+    - Railway auto-detects the Dockerfile
+    - Add environment variables if needed
+
+ 4. **Deploy:**
+    - Railway builds and deploys automatically
+    - Get your public URL from the dashboard
+
+ ---
+
+ ## 🎯 Recommendation
+
+ **For production use:** Hugging Face Spaces with GPU
+ - Best balance of features, performance, and cost
+ - Designed for ML models
+ - Free GPU tier available
+
+ **For testing/demos:** Google Colab
+ - Quick setup
+ - Free GPU
+ - Good for temporary demos
+
+ **For CPU-only:** Render.com
+ - Persistent deployment
+ - But very slow for 3D generation
+
+ ---
+
+ ## 📊 Comparison Table
+
+ | Platform | GPU | Persistent | Free Tier | Best For |
+ |----------|-----|------------|-----------|----------|
+ | **Hugging Face Spaces** | ✅ T4 | ✅ | ✅ | Production |
+ | **Google Colab** | ✅ T4/K80 | ❌ | ✅ | Testing |
+ | **Render.com** | ❌ | ✅ | ✅ | CPU demos |
+ | **Railway.app** | ❌ | ✅ | $5 credits | Trial |
+
+ ---
+
+ ## 🔧 Testing Your Deployment
+
+ Once deployed, test your API:
+
+ ```bash
+ # Health check
+ curl https://YOUR_DEPLOYMENT_URL/health
+
+ # Generate a 3D model
+ curl -X POST https://YOUR_DEPLOYMENT_URL/generate \
+   -F "image=@test_image.png" \
+   -F "format=obj" \
+   -F "bake_texture_flag=true" \
+   -o output.zip
+ ```
+
+ ---
+
+ ## 📝 Notes
+
+ - **Model size:** The TripoSR model is ~1.5GB and will be downloaded on first run
+ - **Memory requirements:** Minimum 8GB RAM, 6GB VRAM for GPU
+ - **Cold starts:** The first request may take 30-60 seconds to load the model (see the warm-up sketch below)
+ - **Rate limits:** Free tiers may have rate limits or usage quotas
+
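+ Cold starts can be trimmed by fetching the weights before the first request, e.g. at container start. A sketch using `huggingface_hub`, assuming the weights are hosted in the `stabilityai/TripoSR` repo:
+
+ ```python
+ # Run once at startup (or in a Docker build step) to warm the model cache
+ from huggingface_hub import snapshot_download
+
+ # Downloads the ~1.5GB of weights into the local HF cache, so the first
+ # /generate request only pays for model loading, not the download
+ snapshot_download(repo_id="stabilityai/TripoSR")
+ ```
+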
+ ---
+
+ ## 🆘 Troubleshooting
+
+ ### "Out of memory" errors
+ - Reduce the `mc_resolution` parameter (default: 256)
+ - Use smaller images
+ - Upgrade to a larger instance
+
+ ### Slow generation
+ - Ensure GPU is enabled
+ - Check if running on CPU (much slower)
+ - Monitor instance resources
+
+ ### Build failures
+ - Check the Docker logs
+ - Ensure all dependencies are in `requirements.txt`
+ - Verify CUDA compatibility
+
+ ---
+
+ ## 📚 Additional Resources
+
+ - [TripoSR Paper](https://arxiv.org/abs/2403.02151)
+ - [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces)
+ - [Render Docs](https://render.com/docs)
+ - [Railway Docs](https://docs.railway.app)
Dockerfile ADDED
@@ -0,0 +1,56 @@
+ # Hugging Face Spaces optimized Dockerfile
+ # This uses a lighter base image suitable for HF Spaces
+
+ FROM python:3.10-slim
+
+ # Set environment variables
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV PYTHONUNBUFFERED=1
+ ENV TRANSFORMERS_CACHE=/tmp/transformers_cache
+ ENV HF_HOME=/tmp/hf_home
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     gcc \
+     g++ \
+     git \
+     ffmpeg \
+     libsm6 \
+     libxext6 \
+     libxrender-dev \
+     libgomp1 \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set up the application directory
+ WORKDIR /app
+
+ # Copy the requirements file first for better Docker layer caching
+ COPY requirements.txt /app/
+
+ # Upgrade pip and setuptools
+ RUN pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ # Install the CPU build of PyTorch (HF Spaces runs on CPU by default)
+ # Using the CPU build to reduce image size; note the CPU wheel cannot use a
+ # GPU - switch to a CUDA build if the Space is upgraded to GPU hardware
+ RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
+
+ # Install other dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of the application code
+ COPY . /app
+
+ # Create an output directory for temporary files
+ RUN mkdir -p /app/output
+
+ # Expose the port (HF Spaces uses 7860 by default)
+ EXPOSE 7860
+
+ # Health check
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
+     CMD curl -f http://localhost:7860/health || exit 1
+
+ # Run the server on port 7860 (required by HF Spaces)
+ CMD ["python", "-m", "uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Tripo AI & Stability AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
REACT_INTEGRATION.md ADDED
@@ -0,0 +1,549 @@
+ # TripoSR API - React Integration Guide
+
+ This guide shows how to integrate the TripoSR API into your React + Supabase project.
+
+ ## 1. Update CORS in the API Server
+
+ First, update `api_server.py` to allow your React app's origin:
+
+ ```python
+ # In api_server.py, update the CORS middleware:
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=[
+         "http://localhost:3000",  # React dev server
+         "http://localhost:5173",  # Vite dev server
+         "http://localhost:8080",  # Other common ports
+         # Add your production domain here
+     ],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+ ```
+
+ ## 2. React Hook for the TripoSR API
+
+ Create a custom hook: `src/hooks/useTripoSR.js` or `src/hooks/useTripoSR.ts`
+
+ ### TypeScript Version:
+
+ ```typescript
+ import { useState } from 'react';
+
+ interface GenerateMeshParams {
+   doRemoveBackground?: boolean;
+   foregroundRatio?: number;
+   mcResolution?: number;
+   format?: 'obj' | 'glb';
+ }
+
+ interface UseTripoSRReturn {
+   generateMesh: (imageFile: File, params?: GenerateMeshParams) => Promise<Blob | null>;
+   generateMeshBase64: (imageFile: File, params?: GenerateMeshParams) => Promise<string | null>;
+   loading: boolean;
+   error: string | null;
+ }
+
+ const API_URL = process.env.REACT_APP_TRIPOSR_API_URL || 'http://localhost:8000';
+
+ export const useTripoSR = (): UseTripoSRReturn => {
+   const [loading, setLoading] = useState(false);
+   const [error, setError] = useState<string | null>(null);
+
+   const generateMesh = async (
+     imageFile: File,
+     params: GenerateMeshParams = {}
+   ): Promise<Blob | null> => {
+     setLoading(true);
+     setError(null);
+
+     try {
+       const formData = new FormData();
+       formData.append('image', imageFile);
+       formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
+       formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
+       formData.append('mc_resolution', String(params.mcResolution ?? 256));
+       formData.append('format', params.format ?? 'obj');
+
+       const response = await fetch(`${API_URL}/generate`, {
+         method: 'POST',
+         body: formData,
+       });
+
+       if (!response.ok) {
+         const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
+         throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
+       }
+
+       const blob = await response.blob();
+       return blob;
+     } catch (err) {
+       const errorMessage = err instanceof Error ? err.message : 'Failed to generate mesh';
+       setError(errorMessage);
+       console.error('TripoSR API error:', err);
+       return null;
+     } finally {
+       setLoading(false);
+     }
+   };
+
+   const generateMeshBase64 = async (
+     imageFile: File,
+     params: GenerateMeshParams = {}
+   ): Promise<string | null> => {
+     setLoading(true);
+     setError(null);
+
+     try {
+       const formData = new FormData();
+       formData.append('image', imageFile);
+       formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
+       formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
+       formData.append('mc_resolution', String(params.mcResolution ?? 256));
+       formData.append('format', params.format ?? 'obj');
+
+       const response = await fetch(`${API_URL}/generate-base64`, {
+         method: 'POST',
+         body: formData,
+       });
+
+       if (!response.ok) {
+         const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
+         throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
+       }
+
+       const data = await response.json();
+       return data.mesh; // Base64-encoded mesh
+     } catch (err) {
+       const errorMessage = err instanceof Error ? err.message : 'Failed to generate mesh';
+       setError(errorMessage);
+       console.error('TripoSR API error:', err);
+       return null;
+     } finally {
+       setLoading(false);
+     }
+   };
+
+   return { generateMesh, generateMeshBase64, loading, error };
+ };
+ ```
+
+ ### JavaScript Version:
+
+ ```javascript
+ import { useState } from 'react';
+
+ const API_URL = process.env.REACT_APP_TRIPOSR_API_URL || 'http://localhost:8000';
+
+ export const useTripoSR = () => {
+   const [loading, setLoading] = useState(false);
+   const [error, setError] = useState(null);
+
+   const generateMesh = async (imageFile, params = {}) => {
+     setLoading(true);
+     setError(null);
+
+     try {
+       const formData = new FormData();
+       formData.append('image', imageFile);
+       formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
+       formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
+       formData.append('mc_resolution', String(params.mcResolution ?? 256));
+       formData.append('format', params.format ?? 'obj');
+
+       const response = await fetch(`${API_URL}/generate`, {
+         method: 'POST',
+         body: formData,
+       });
+
+       if (!response.ok) {
+         const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
+         throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
+       }
+
+       const blob = await response.blob();
+       return blob;
+     } catch (err) {
+       const errorMessage = err.message || 'Failed to generate mesh';
+       setError(errorMessage);
+       console.error('TripoSR API error:', err);
+       return null;
+     } finally {
+       setLoading(false);
+     }
+   };
+
+   const generateMeshBase64 = async (imageFile, params = {}) => {
+     setLoading(true);
+     setError(null);
+
+     try {
+       const formData = new FormData();
+       formData.append('image', imageFile);
+       formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
+       formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
+       formData.append('mc_resolution', String(params.mcResolution ?? 256));
+       formData.append('format', params.format ?? 'obj');
+
+       const response = await fetch(`${API_URL}/generate-base64`, {
+         method: 'POST',
+         body: formData,
+       });
+
+       if (!response.ok) {
+         const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
+         throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
+       }
+
+       const data = await response.json();
+       return data.mesh;
+     } catch (err) {
+       const errorMessage = err.message || 'Failed to generate mesh';
+       setError(errorMessage);
+       console.error('TripoSR API error:', err);
+       return null;
+     } finally {
+       setLoading(false);
+     }
+   };
+
+   return { generateMesh, generateMeshBase64, loading, error };
+ };
+ ```
+
+ ## 3. React Component Example
+
+ Create `src/components/MeshGenerator.tsx` or `.jsx`:
+
+ ```tsx
+ import React, { useState } from 'react';
+ import { useTripoSR } from '../hooks/useTripoSR';
+
+ const MeshGenerator: React.FC = () => {
+   const [selectedFile, setSelectedFile] = useState<File | null>(null);
+   const [preview, setPreview] = useState<string | null>(null);
+   const [meshUrl, setMeshUrl] = useState<string | null>(null);
+   const { generateMesh, loading, error } = useTripoSR();
+
+   const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
+     const file = e.target.files?.[0];
+     if (file) {
+       setSelectedFile(file);
+       const reader = new FileReader();
+       reader.onloadend = () => {
+         setPreview(reader.result as string);
+       };
+       reader.readAsDataURL(file);
+     }
+   };
+
+   const handleGenerate = async () => {
+     if (!selectedFile) return;
+
+     const blob = await generateMesh(selectedFile, {
+       doRemoveBackground: true,
+       foregroundRatio: 0.85,
+       mcResolution: 256,
+       format: 'obj',
+     });
+
+     if (blob) {
+       const url = window.URL.createObjectURL(blob);
+       setMeshUrl(url);
+     }
+   };
+
+   const handleDownload = () => {
+     if (meshUrl) {
+       const a = document.createElement('a');
+       a.href = meshUrl;
+       a.download = 'mesh.obj';
+       document.body.appendChild(a);
+       a.click();
+       document.body.removeChild(a);
+     }
+   };
+
+   return (
+     <div className="mesh-generator">
+       <h2>3D Mesh Generator</h2>
+
+       <div className="upload-section">
+         <input
+           type="file"
+           accept="image/*"
+           onChange={handleFileChange}
+           disabled={loading}
+         />
+         {preview && (
+           <img src={preview} alt="Preview" style={{ maxWidth: '300px', marginTop: '10px' }} />
+         )}
+       </div>
+
+       <button onClick={handleGenerate} disabled={!selectedFile || loading}>
+         {loading ? 'Generating...' : 'Generate Mesh'}
+       </button>
+
+       {error && <div className="error">Error: {error}</div>}
+
+       {meshUrl && (
+         <div className="result-section">
+           <p>Mesh generated successfully!</p>
+           <button onClick={handleDownload}>Download Mesh</button>
+           {/* You can also use a 3D viewer here */}
+         </div>
+       )}
+     </div>
+   );
+ };
+
+ export default MeshGenerator;
+ ```
+
+ ## 4. Integration with Supabase Storage
+
+ If you want to store mesh files in Supabase:
+
+ ```typescript
+ import React, { useState } from 'react';
+ import { useTripoSR } from '../hooks/useTripoSR';
+ import { supabase } from '../lib/supabase'; // Your Supabase client
+
+ const MeshGeneratorWithSupabase: React.FC = () => {
+   const { generateMesh, loading, error } = useTripoSR();
+   const [uploading, setUploading] = useState(false);
+
+   const handleGenerateAndUpload = async (imageFile: File, userId: string) => {
+     // Generate the mesh
+     const blob = await generateMesh(imageFile);
+     if (!blob) return;
+
+     // Upload to Supabase Storage
+     setUploading(true);
+     try {
+       const fileName = `${userId}/${Date.now()}_mesh.obj`;
+       const { data, error: uploadError } = await supabase.storage
+         .from('meshes') // Your storage bucket name
+         .upload(fileName, blob, {
+           contentType: 'application/octet-stream',
+           upsert: false,
+         });
+
+       if (uploadError) throw uploadError;
+
+       // Get the public URL
+       const { data: urlData } = supabase.storage
+         .from('meshes')
+         .getPublicUrl(fileName);
+
+       console.log('Mesh uploaded:', urlData.publicUrl);
+       return urlData.publicUrl;
+     } catch (err) {
+       console.error('Upload error:', err);
+     } finally {
+       setUploading(false);
+     }
+   };
+
+   // ... rest of component
+ };
+ ```
+
+ ## 5. Environment Variables
+
+ Create `.env.local` in your React project:
+
+ ```env
+ REACT_APP_TRIPOSR_API_URL=http://localhost:8000
+ ```
+
+ For production, update it to your deployed API URL.
+
+ ## 6. 3D Mesh Viewer Integration
+
+ To display the mesh in your React app, you can use libraries like:
+
+ - **react-three-fiber** + **drei**
+ - **@react-three/viewer**
+ - **model-viewer** (web component)
+
+ Example with `model-viewer`:
+
+ ```tsx
+ import '@google/model-viewer';
+
+ const MeshViewer: React.FC<{ meshUrl: string }> = ({ meshUrl }) => {
+   return (
+     <model-viewer
+       src={meshUrl}
+       alt="3D Mesh"
+       auto-rotate
+       camera-controls
+       style={{ width: '100%', height: '500px' }}
+     />
+   );
+ };
+ ```
+
+ ## 7. Complete Example with All Features
+
+ ```tsx
+ import React, { useState } from 'react';
+ import { useTripoSR } from '../hooks/useTripoSR';
+ import { supabase } from '../lib/supabase';
+
+ const CompleteMeshGenerator: React.FC = () => {
+   const [file, setFile] = useState<File | null>(null);
+   const [preview, setPreview] = useState<string | null>(null);
+   const [meshUrl, setMeshUrl] = useState<string | null>(null);
+   const [foregroundRatio, setForegroundRatio] = useState(0.85);
+   const [resolution, setResolution] = useState(256);
+   const [format, setFormat] = useState<'obj' | 'glb'>('obj');
+
+   const { generateMesh, loading, error } = useTripoSR();
+
+   const handleFileSelect = (e: React.ChangeEvent<HTMLInputElement>) => {
+     const selected = e.target.files?.[0];
+     if (selected) {
+       setFile(selected);
+       const reader = new FileReader();
+       reader.onload = () => setPreview(reader.result as string);
+       reader.readAsDataURL(selected);
+     }
+   };
+
+   const handleGenerate = async () => {
+     if (!file) return;
+
+     const blob = await generateMesh(file, {
+       doRemoveBackground: true,
+       foregroundRatio,
+       mcResolution: resolution,
+       format,
+     });
+
+     if (blob) {
+       const url = window.URL.createObjectURL(blob);
+       setMeshUrl(url);
+     }
+   };
+
+   const handleSaveToSupabase = async () => {
+     if (!meshUrl || !file) return;
+
+     const { data: { user } } = await supabase.auth.getUser();
+     if (!user) {
+       alert('Please log in to save meshes');
+       return;
+     }
+
+     const response = await fetch(meshUrl);
+     const blob = await response.blob();
+     const fileName = `${user.id}/${Date.now()}_mesh.${format}`;
+
+     const { error } = await supabase.storage
+       .from('meshes')
+       .upload(fileName, blob);
+
+     if (error) {
+       console.error('Upload error:', error);
+     } else {
+       alert('Mesh saved to Supabase!');
+     }
+   };
+
+   return (
+     <div style={{ padding: '20px' }}>
+       <h1>3D Mesh Generator</h1>
+
+       <div>
+         <input type="file" accept="image/*" onChange={handleFileSelect} />
+         {preview && <img src={preview} alt="Preview" style={{ maxWidth: '300px' }} />}
+       </div>
+
+       <div style={{ margin: '20px 0' }}>
+         <label>
+           Foreground Ratio: {foregroundRatio}
+           <input
+             type="range"
+             min="0.5"
+             max="1.0"
+             step="0.05"
+             value={foregroundRatio}
+             onChange={(e) => setForegroundRatio(parseFloat(e.target.value))}
+           />
+         </label>
+       </div>
+
+       <div style={{ margin: '20px 0' }}>
+         <label>
+           Resolution: {resolution}
+           <input
+             type="range"
+             min="128"
+             max="320"
+             step="32"
+             value={resolution}
+             onChange={(e) => setResolution(parseInt(e.target.value))}
+           />
+         </label>
+       </div>
+
+       <div style={{ margin: '20px 0' }}>
+         <label>
+           Format:
+           <select value={format} onChange={(e) => setFormat(e.target.value as 'obj' | 'glb')}>
+             <option value="obj">OBJ</option>
+             <option value="glb">GLB</option>
+           </select>
+         </label>
+       </div>
+
+       <button onClick={handleGenerate} disabled={!file || loading}>
+         {loading ? 'Generating...' : 'Generate Mesh'}
+       </button>
+
+       {error && <div style={{ color: 'red' }}>Error: {error}</div>}
+
+       {meshUrl && (
+         <div>
+           <p>Mesh generated!</p>
+           <a href={meshUrl} download={`mesh.${format}`}>
+             <button>Download</button>
+           </a>
+           <button onClick={handleSaveToSupabase}>Save to Supabase</button>
+         </div>
+       )}
+     </div>
+   );
+ };
+
+ export default CompleteMeshGenerator;
+ ```
+
+ ## 8. API Health Check
+
+ Add a health check on app load:
+
+ ```typescript
+ useEffect(() => {
+   const checkAPI = async () => {
+     try {
+       const response = await fetch(`${API_URL}/health`);
+       const data = await response.json();
+       console.log('TripoSR API status:', data);
+     } catch (err) {
+       console.error('TripoSR API is not available:', err);
+     }
+   };
+   checkAPI();
+ }, []);
+ ```
+
+ ## Notes
+
+ - Make sure your TripoSR API server is running before using the React app
+ - The API takes 30-60 seconds to generate a mesh, so show appropriate loading states
+ - Consider implementing request cancellation for better UX
+ - For production, deploy the API server and update the `REACT_APP_TRIPOSR_API_URL` environment variable
README.md CHANGED
@@ -1,11 +1,143 @@
- ---
- title: Triposr
- emoji: 🌍
- colorFrom: green
- colorTo: blue
- sdk: docker
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: TripoSR API
+ emoji: 🎨
+ colorFrom: blue
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: mit
+ ---
+
+ # TripoSR API 🎨
+
+ Fast 3D reconstruction from a single image using TripoSR.
+
+ ## 🚀 Features
+
+ - **Fast 3D Generation**: Convert images to 3D models in seconds
+ - **Multiple Formats**: Export as OBJ or GLB
+ - **Texture Baking**: Optional texture atlas generation
+ - **Background Removal**: Automatic background removal
+ - **REST API**: Easy integration with any frontend
+
+ ## 📡 API Endpoints
+
+ ### Health Check
+ ```bash
+ GET /health
+ ```
+
+ ### Generate 3D Model
+ ```bash
+ POST /generate
+ ```
+
+ **Parameters:**
+ - `image`: Image file (PNG, JPG, JPEG)
+ - `do_remove_background`: Remove background (default: true)
+ - `foreground_ratio`: Foreground size ratio (default: 0.85)
+ - `mc_resolution`: Mesh resolution (default: 256)
+ - `format`: Output format - "obj" or "glb" (default: "obj")
+ - `bake_texture_flag`: Bake texture (default: true)
+ - `texture_resolution`: Texture resolution (default: 2048)
+ - `orientation`: Mesh orientation - "standard", "gradio", or "none" (default: "standard")
+
+ **Returns:**
+ - ZIP file containing mesh and texture (if texture baking enabled)
+ - Or mesh file only (if texture baking disabled)
+
+ ### Generate 3D Model (Base64)
+ ```bash
+ POST /generate-base64
+ ```
+
+ Same parameters as `/generate`, but returns JSON with base64-encoded mesh and texture.
+
+ ## 🧪 Example Usage
+
+ ### cURL
+ ```bash
+ curl -X POST https://YOUR-SPACE-URL/generate \
+   -F "image=@your_image.png" \
+   -F "format=obj" \
+   -F "bake_texture_flag=true" \
+   -o output.zip
+ ```
+
+ ### Python
+ ```python
+ import requests
+
+ url = "https://YOUR-SPACE-URL/generate"
+ files = {"image": open("your_image.png", "rb")}
+ data = {
+     "format": "obj",
+     "bake_texture_flag": True,
+     "mc_resolution": 256
+ }
+
+ response = requests.post(url, files=files, data=data)
+ with open("output.zip", "wb") as f:
+     f.write(response.content)
+ ```
+
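+ With `bake_texture_flag` enabled, `output.zip` holds the mesh plus its texture atlas. A short follow-up to unpack it (the file names inside the archive are illustrative):
+
+ ```python
+ import zipfile
+
+ # Extract the mesh and texture next to the script
+ with zipfile.ZipFile("output.zip") as zf:
+     zf.extractall("model/")
+     print(zf.namelist())  # e.g. ['mesh.obj', 'texture.png']
+ ```
+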
+ ### JavaScript
+ ```javascript
+ const formData = new FormData();
+ formData.append('image', fileInput.files[0]);
+ formData.append('format', 'obj');
+ formData.append('bake_texture_flag', 'true');
+
+ const response = await fetch('https://YOUR-SPACE-URL/generate', {
+   method: 'POST',
+   body: formData
+ });
+
+ const blob = await response.blob();
+ // Download or process the blob
+ ```
+
+ ## ⚙️ Configuration
+
+ ### GPU Support
+ This Space can run on CPU or GPU:
+ - **CPU**: Slower but free
+ - **GPU (T4)**: Much faster, may require upgrade in Space settings
+
+ To upgrade to GPU:
+ 1. Go to Space settings
+ 2. Select "T4 small" under Hardware
+ 3. Space will rebuild automatically
+
+ ### Performance
+ - **CPU**: ~30-60 seconds per image
+ - **GPU (T4)**: ~5-10 seconds per image
+ - **GPU (A100)**: ~1-2 seconds per image
+
+ ## 📚 Documentation
+
+ - [TripoSR Paper](https://arxiv.org/abs/2403.02151)
+ - [API Documentation](API_README.md)
+ - [React Integration Guide](REACT_INTEGRATION.md)
+
+ ## 🔗 Links
+
+ - [GitHub Repository](https://github.com/Ahmedbelaid/TripoSR-api)
+ - [Original TripoSR](https://github.com/VAST-AI-Research/TripoSR)
+ - [Stability AI](https://stability.ai/)
+ - [Tripo AI](https://www.tripo3d.ai/)
+
+ ## 📄 License
+
+ MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## 🙏 Credits
+
+ - **TripoSR Model**: [Stability AI](https://stability.ai/) and [Tripo AI](https://www.tripo3d.ai/)
+ - **API Implementation**: Community contribution
+
+ ## 🆘 Support
+
+ For issues or questions:
+ - Open an issue on [GitHub](https://github.com/Ahmedbelaid/TripoSR-api/issues)
+ - Join the [Discord](https://discord.gg/mvS9mCfMnQ)
README_DOCKER.md ADDED
@@ -0,0 +1,158 @@
+ # Docker Setup for TripoSR API
+
+ This guide explains how to build and run the TripoSR API using Docker.
+
+ ## Prerequisites
+
+ - Docker installed on your system
+ - (Optional) NVIDIA Docker runtime (nvidia-docker2) for GPU support
+ - At least 8GB of available RAM
+ - (For GPU) NVIDIA GPU with CUDA support
+
+ **Note:** This Dockerfile uses the `-devel` PyTorch image (instead of `-runtime`) because `torchmcubes` requires the CUDA development toolkit to compile with CUDA support. This results in a larger image size (~8-10GB) but is necessary for proper compilation.
+
+ ## Building the Docker Image
+
+ ### Basic Build
+
+ ```bash
+ docker build -t triposr-api:latest .
+ ```
+
+ ### Build with Specific Tag
+
+ ```bash
+ docker build -t triposr-api:v1.0 .
+ ```
+
+ ## Running the Container
+
+ ### CPU Mode
+
+ ```bash
+ docker run -d \
+   --name triposr-api \
+   -p 8000:8000 \
+   triposr-api:latest
+ ```
+
+ ### GPU Mode (NVIDIA)
+
+ First, ensure you have `nvidia-docker2` installed. Then run:
+
+ ```bash
+ docker run -d \
+   --name triposr-api \
+   --gpus all \
+   -p 8000:8000 \
+   -e CUDA_VISIBLE_DEVICES=0 \
+   triposr-api:latest
+ ```
+
+ ### Using Docker Compose
+
+ For easier management, use Docker Compose:
+
+ ```bash
+ # Start the service
+ docker-compose up -d
+
+ # View logs
+ docker-compose logs -f
+
+ # Stop the service
+ docker-compose down
+ ```
+
+ For GPU support with Docker Compose, uncomment the GPU-related lines in `docker-compose.yml`.
+
+ ## Verifying the Installation
+
+ Once the container is running, check the health endpoint:
+
+ ```bash
+ curl http://localhost:8000/health
+ ```
+
+ You should see a response like:
+ ```json
+ {
+   "status": "healthy",
+   "device": "cuda:0",
+   "cuda_available": true
+ }
+ ```
+
+ ## API Usage
+
+ The API will be available at `http://localhost:8000`. See `API_README.md` for detailed API documentation.
+
+ ### Example Request
+
+ ```bash
+ curl -X POST "http://localhost:8000/generate" \
+   -F "image=@your_image.jpg" \
+   -F "orientation=standard" \
+   -F "format=obj" \
+   --output mesh.zip
+ ```
+
+ ## Troubleshooting
+
+ ### Container Fails to Start
+
+ 1. Check logs: `docker logs triposr-api`
+ 2. Ensure port 8000 is not already in use
+ 3. Verify you have enough memory (at least 8GB recommended)
+
+ ### CUDA/GPU Issues
+
+ 1. Verify NVIDIA Docker runtime: `docker run --rm --gpus all nvidia/cuda:11.7.0-base-ubuntu20.04 nvidia-smi`
+ 2. Check CUDA availability in container: `docker exec triposr-api python -c "import torch; print(torch.cuda.is_available())"`
+
+ ### Out of Memory
+
+ If you encounter OOM errors:
+ - Reduce `mc_resolution` parameter (default: 256, try 128 or 64)
+ - Reduce `texture_resolution` parameter (default: 2048, try 1024)
+ - Use CPU mode if GPU memory is limited
+
+ ### Build Errors
+
+ If the build fails:
+ - Ensure you have a stable internet connection (the model will be downloaded)
+ - Check that all dependencies in `requirements.txt` are valid
+ - Try building with `--no-cache`: `docker build --no-cache -t triposr-api:latest .`
+
+ ## Environment Variables
+
+ - `CUDA_VISIBLE_DEVICES`: Set to a specific GPU ID (e.g., "0") or an empty string for CPU mode
+
+ ## Volumes
+
+ The container can mount volumes for:
+ - `/app/output`: Output directory for generated meshes (optional)
+
+ Example:
+ ```bash
+ docker run -d \
+   --name triposr-api \
+   -p 8000:8000 \
+   -v $(pwd)/output:/app/output \
+   triposr-api:latest
+ ```
+
+ ## Stopping and Removing
+
+ ```bash
+ # Stop the container
+ docker stop triposr-api
+
+ # Remove the container
+ docker rm triposr-api
+
+ # Remove the image
+ docker rmi triposr-api:latest
+ ```
README_HF_SPACES.md ADDED
@@ -0,0 +1,143 @@
+ ---
+ title: TripoSR API
+ emoji: 🎨
+ colorFrom: blue
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: mit
+ ---
+
+ # TripoSR API 🎨
+
+ Fast 3D reconstruction from a single image using TripoSR.
+
+ ## 🚀 Features
+
+ - **Fast 3D Generation**: Convert images to 3D models in seconds
+ - **Multiple Formats**: Export as OBJ or GLB
+ - **Texture Baking**: Optional texture atlas generation
+ - **Background Removal**: Automatic background removal
+ - **REST API**: Easy integration with any frontend
+
+ ## 📡 API Endpoints
+
+ ### Health Check
+ ```bash
+ GET /health
+ ```
+
+ ### Generate 3D Model
+ ```bash
+ POST /generate
+ ```
+
+ **Parameters:**
+ - `image`: Image file (PNG, JPG, JPEG)
+ - `do_remove_background`: Remove background (default: true)
+ - `foreground_ratio`: Foreground size ratio (default: 0.85)
+ - `mc_resolution`: Mesh resolution (default: 256)
+ - `format`: Output format - "obj" or "glb" (default: "obj")
+ - `bake_texture_flag`: Bake texture (default: true)
+ - `texture_resolution`: Texture resolution (default: 2048)
+ - `orientation`: Mesh orientation - "standard", "gradio", or "none" (default: "standard")
+
+ **Returns:**
+ - ZIP file containing mesh and texture (if texture baking enabled)
+ - Or mesh file only (if texture baking disabled)
+
+ ### Generate 3D Model (Base64)
+ ```bash
+ POST /generate-base64
+ ```
+
+ Same parameters as `/generate`, but returns JSON with base64-encoded mesh and texture.
+
+ ## 🧪 Example Usage
+
+ ### cURL
+ ```bash
+ curl -X POST https://YOUR-SPACE-URL/generate \
+   -F "image=@your_image.png" \
+   -F "format=obj" \
+   -F "bake_texture_flag=true" \
+   -o output.zip
+ ```
+
+ ### Python
+ ```python
+ import requests
+
+ url = "https://YOUR-SPACE-URL/generate"
+ files = {"image": open("your_image.png", "rb")}
+ data = {
+     "format": "obj",
+     "bake_texture_flag": True,
+     "mc_resolution": 256
+ }
+
+ response = requests.post(url, files=files, data=data)
+ with open("output.zip", "wb") as f:
+     f.write(response.content)
+ ```
+
+ ### JavaScript
+ ```javascript
+ const formData = new FormData();
+ formData.append('image', fileInput.files[0]);
+ formData.append('format', 'obj');
+ formData.append('bake_texture_flag', 'true');
+
+ const response = await fetch('https://YOUR-SPACE-URL/generate', {
+   method: 'POST',
+   body: formData
+ });
+
+ const blob = await response.blob();
+ // Download or process the blob
+ ```
+
+ ## ⚙️ Configuration
+
+ ### GPU Support
+ This Space can run on CPU or GPU:
+ - **CPU**: Slower but free
+ - **GPU (T4)**: Much faster, may require upgrade in Space settings
+
+ To upgrade to GPU:
+ 1. Go to Space settings
+ 2. Select "T4 small" under Hardware
+ 3. Space will rebuild automatically
+
+ ### Performance
+ - **CPU**: ~30-60 seconds per image
+ - **GPU (T4)**: ~5-10 seconds per image
+ - **GPU (A100)**: ~1-2 seconds per image
+
+ ## 📚 Documentation
+
+ - [TripoSR Paper](https://arxiv.org/abs/2403.02151)
+ - [API Documentation](API_README.md)
+ - [React Integration Guide](REACT_INTEGRATION.md)
+
+ ## 🔗 Links
+
+ - [GitHub Repository](https://github.com/Ahmedbelaid/TripoSR-api)
+ - [Original TripoSR](https://github.com/VAST-AI-Research/TripoSR)
+ - [Stability AI](https://stability.ai/)
+ - [Tripo AI](https://www.tripo3d.ai/)
+
+ ## 📄 License
+
+ MIT License - see [LICENSE](LICENSE) file for details.
+
+ ## 🙏 Credits
+
+ - **TripoSR Model**: [Stability AI](https://stability.ai/) and [Tripo AI](https://www.tripo3d.ai/)
+ - **API Implementation**: Community contribution
+
+ ## 🆘 Support
+
+ For issues or questions:
+ - Open an issue on [GitHub](https://github.com/Ahmedbelaid/TripoSR-api/issues)
+ - Join the [Discord](https://discord.gg/mvS9mCfMnQ)
api_example.html ADDED
@@ -0,0 +1,194 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>TripoSR API Example</title>
7
+ <style>
8
+ body {
9
+ font-family: Arial, sans-serif;
10
+ max-width: 800px;
11
+ margin: 50px auto;
12
+ padding: 20px;
13
+ }
14
+ .container {
15
+ border: 1px solid #ddd;
16
+ padding: 20px;
17
+ border-radius: 8px;
18
+ }
19
+ input[type="file"] {
20
+ margin: 10px 0;
21
+ }
22
+ button {
23
+ background-color: #4CAF50;
24
+ color: white;
25
+ padding: 10px 20px;
26
+ border: none;
27
+ border-radius: 4px;
28
+ cursor: pointer;
29
+ font-size: 16px;
30
+ }
31
+ button:hover {
32
+ background-color: #45a049;
33
+ }
34
+ button:disabled {
35
+ background-color: #cccccc;
36
+ cursor: not-allowed;
37
+ }
38
+ .preview {
39
+ margin: 20px 0;
40
+ }
41
+ .preview img {
42
+ max-width: 100%;
43
+ border: 1px solid #ddd;
44
+ border-radius: 4px;
45
+ }
46
+ .result {
47
+ margin-top: 20px;
48
+ padding: 10px;
49
+ background-color: #f0f0f0;
50
+ border-radius: 4px;
51
+ }
52
+ .error {
53
+ color: red;
54
+ margin-top: 10px;
55
+ }
56
+ .loading {
57
+ color: #666;
58
+ font-style: italic;
59
+ }
60
+ </style>
61
+ </head>
62
+ <body>
63
+ <div class="container">
64
+ <h1>TripoSR API Example</h1>
65
+ <p>Upload an image to generate a 3D mesh</p>
66
+
67
+ <form id="meshForm">
68
+ <div>
69
+ <label for="imageInput">Select Image:</label><br>
70
+ <input type="file" id="imageInput" accept="image/*" required>
71
+ </div>
72
+
73
+ <div style="margin: 15px 0;">
74
+ <label>
75
+ <input type="checkbox" id="removeBg" checked>
76
+ Remove Background
77
+ </label>
78
+ </div>
79
+
80
+ <div style="margin: 15px 0;">
81
+ <label for="foregroundRatio">Foreground Ratio: <span id="ratioValue">0.85</span></label><br>
82
+ <input type="range" id="foregroundRatio" min="0.5" max="1.0" step="0.05" value="0.85">
83
+ </div>
84
+
85
+ <div style="margin: 15px 0;">
86
+ <label for="resolution">Resolution: <span id="resValue">256</span></label><br>
87
+ <input type="range" id="resolution" min="128" max="320" step="32" value="256">
88
+ </div>
89
+
90
+ <div style="margin: 15px 0;">
91
+ <label for="format">Format:</label>
92
+ <select id="format">
93
+ <option value="obj">OBJ</option>
94
+ <option value="glb">GLB</option>
95
+ </select>
96
+ </div>
97
+
98
+ <button type="submit" id="generateBtn">Generate Mesh</button>
99
+ </form>
100
+
101
+ <div class="preview" id="preview"></div>
102
+ <div id="result"></div>
103
+ </div>
104
+
105
+ <script>
106
+ const API_URL = 'http://localhost:8000';
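+ // Point this at your deployed endpoint (e.g. your HF Space URL) when not running locally.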
107
+
108
+ // Update slider values
109
+ document.getElementById('foregroundRatio').addEventListener('input', (e) => {
110
+ document.getElementById('ratioValue').textContent = e.target.value;
111
+ });
112
+
113
+ document.getElementById('resolution').addEventListener('input', (e) => {
114
+ document.getElementById('resValue').textContent = e.target.value;
115
+ });
116
+
117
+ // Preview image
118
+ document.getElementById('imageInput').addEventListener('change', (e) => {
119
+ const file = e.target.files[0];
120
+ if (file) {
121
+ const reader = new FileReader();
122
+ reader.onload = (e) => {
123
+ const preview = document.getElementById('preview');
124
+ preview.innerHTML = `<img src="${e.target.result}" alt="Preview">`;
125
+ };
126
+ reader.readAsDataURL(file);
127
+ }
128
+ });
129
+
130
+ // Handle form submission
131
+ document.getElementById('meshForm').addEventListener('submit', async (e) => {
132
+ e.preventDefault();
133
+
134
+ const formData = new FormData();
135
+ const imageInput = document.getElementById('imageInput');
136
+ const removeBg = document.getElementById('removeBg').checked;
137
+ const foregroundRatio = parseFloat(document.getElementById('foregroundRatio').value);
138
+ const resolution = parseInt(document.getElementById('resolution').value);
139
+ const format = document.getElementById('format').value;
140
+
141
+ if (!imageInput.files[0]) {
142
+ alert('Please select an image');
143
+ return;
144
+ }
145
+
146
+ formData.append('image', imageInput.files[0]);
147
+ formData.append('do_remove_background', removeBg);
148
+ formData.append('foreground_ratio', foregroundRatio);
149
+ formData.append('mc_resolution', resolution);
150
+ formData.append('format', format);
+ // The server bakes a texture by default and returns a ZIP archive; request
+ // the plain mesh file so the download filename below matches the content.
+ formData.append('bake_texture_flag', 'false');
151
+
152
+ const generateBtn = document.getElementById('generateBtn');
153
+ const resultDiv = document.getElementById('result');
154
+
155
+ generateBtn.disabled = true;
156
+ generateBtn.textContent = 'Generating...';
157
+ resultDiv.innerHTML = '<div class="loading">Processing image and generating mesh. This may take 30-60 seconds...</div>';
158
+
159
+ try {
160
+ const response = await fetch(`${API_URL}/generate`, {
161
+ method: 'POST',
162
+ body: formData
163
+ });
164
+
165
+ if (!response.ok) {
166
+ const error = await response.json();
167
+ throw new Error(error.detail || 'Failed to generate mesh');
168
+ }
169
+
170
+ // Get the mesh file
171
+ const blob = await response.blob();
172
+ const url = window.URL.createObjectURL(blob);
173
+
174
+ resultDiv.innerHTML = `
175
+ <div class="result">
176
+ <h3>Success! Mesh generated</h3>
177
+ <p>Format: ${format.toUpperCase()}</p>
178
+ <a href="${url}" download="mesh.${format}">
179
+ <button>Download Mesh</button>
180
+ </a>
181
+ </div>
182
+ `;
183
+
184
+ } catch (error) {
185
+ resultDiv.innerHTML = `<div class="error">Error: ${error.message}</div>`;
186
+ } finally {
187
+ generateBtn.disabled = false;
188
+ generateBtn.textContent = 'Generate Mesh';
189
+ }
190
+ });
191
+ </script>
192
+ </body>
193
+ </html>
194
+
api_server.py ADDED
@@ -0,0 +1,403 @@
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import base64
5
+ import zipfile
6
+ from io import BytesIO
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ import rembg
11
+ import torch
12
+ from PIL import Image
13
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Form
14
+ from fastapi.responses import JSONResponse, FileResponse
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from pydantic import BaseModel
17
+
18
+ from tsr.system import TSR
19
+ from tsr.utils import remove_background, resize_foreground, apply_mesh_orientation
20
+ from tsr.bake_texture import bake_texture
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
25
+ )
26
+
27
+ # Initialize FastAPI app
28
+ app = FastAPI(title="TripoSR API", version="1.0.0")
29
+
30
+ # Enable CORS for frontend
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=[
34
+ "http://localhost:3000", # React default port
35
+ "http://localhost:5173", # Vite default port
36
+ "http://localhost:8080", # Vue default port
37
+ "http://127.0.0.1:3000",
38
+ "http://127.0.0.1:5173",
39
+ "http://127.0.0.1:8080",
40
+ "https://huggingface.co",
41
+ "https://*.hf.space", # Add this
42
+ # Add your production frontend URL here
43
+ # "https://your-frontend-domain.com",
44
+ ],
45
+ allow_credentials=True,
46
+ allow_methods=["*"],
47
+ allow_headers=["*"],
48
+ )
49
+
50
+ # Initialize model
51
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
52
+ logging.info(f"Using device: {device}")
53
+
54
+ logging.info("Loading TripoSR model...")
55
+ model = TSR.from_pretrained(
56
+ "stabilityai/TripoSR",
57
+ config_name="config.yaml",
58
+ weight_name="model.ckpt",
59
+ )
60
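+ # Smaller chunks reduce peak VRAM at the cost of speed (see --chunk-size in run.py).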
+ model.renderer.set_chunk_size(8192)
61
+ model.to(device)
62
+ logging.info("Model loaded successfully!")
63
+
64
+ rembg_session = rembg.new_session()
65
+
66
+
67
+ class GenerateRequest(BaseModel):
68
+ do_remove_background: bool = True
69
+ foreground_ratio: float = 0.85
70
+ mc_resolution: int = 256
71
+ format: str = "obj" # obj or glb
72
+
73
+
74
+ def preprocess_image(image: Image.Image, do_remove_background: bool, foreground_ratio: float) -> Image.Image:
75
+ """Preprocess the input image."""
76
+ def fill_background(img):
77
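+ # Alpha-composite the RGBA image onto a neutral 0.5 gray background.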
+ img = np.array(img).astype(np.float32) / 255.0
78
+ img = img[:, :, :3] * img[:, :, 3:4] + (1 - img[:, :, 3:4]) * 0.5
79
+ return Image.fromarray((img * 255.0).astype(np.uint8))
80
+
81
+ if do_remove_background:
82
+ image = image.convert("RGB")
83
+ image = remove_background(image, rembg_session)
84
+ image = resize_foreground(image, foreground_ratio)
85
+ image = fill_background(image)
86
+ else:
87
+ if image.mode == "RGBA":
88
+ image = fill_background(image)
89
+
90
+ return image
91
+
92
+
93
+ @app.get("/")
94
+ async def root():
95
+ return {"message": "TripoSR API is running", "device": device}
96
+
97
+
98
+ @app.get("/health")
99
+ async def health():
100
+ return {"status": "healthy", "device": device, "cuda_available": torch.cuda.is_available()}
101
+
102
+
103
+ @app.post("/generate")
104
+ async def generate_mesh(
105
+ image: UploadFile = File(...),
106
+ do_remove_background: bool = Form(True),
107
+ foreground_ratio: float = Form(0.85),
108
+ mc_resolution: int = Form(256),
109
+ format: str = Form("obj"),
110
+ bake_texture_flag: bool = Form(True),
111
+ texture_resolution: int = Form(2048),
112
+ orientation: str = Form("standard")
113
+ ):
114
+ """
115
+ Generate a 3D mesh from an uploaded image with optional texture baking.
116
+
117
+ Parameters:
118
+ - image: Image file (PNG, JPG, JPEG)
119
+ - do_remove_background: Whether to remove background (default: True)
120
+ - foreground_ratio: Ratio of foreground size (default: 0.85)
121
+ - mc_resolution: Marching cubes resolution (default: 256)
122
+ - format: Output format - "obj" or "glb" (default: "obj")
123
+ - bake_texture_flag: Whether to bake texture (default: True)
124
+ - texture_resolution: Texture atlas resolution (default: 2048)
125
+ - orientation: Mesh orientation - "standard" (Y-up, Z-forward), "gradio" (Gradio viewer), or "none" (original) (default: "standard")
126
+
127
+ Returns:
128
+ - If bake_texture=True: ZIP file containing mesh and texture.png
129
+ - If bake_texture=False: Mesh file only
130
+ """
131
+ try:
132
+ # Validate format
133
+ if format not in ["obj", "glb"]:
134
+ raise HTTPException(status_code=400, detail="Format must be 'obj' or 'glb'")
135
+
136
+ # Read and validate image
137
+ image_data = await image.read()
138
+ input_image = Image.open(BytesIO(image_data))
139
+
140
+ if input_image.mode not in ["RGB", "RGBA"]:
141
+ input_image = input_image.convert("RGB")
142
+
143
+ logging.info(f"Processing image: {image.filename}, size: {input_image.size}")
144
+
145
+ # Preprocess image
146
+ processed_image = preprocess_image(input_image, do_remove_background, foreground_ratio)
147
+
148
+ # Generate mesh
149
+ logging.info("Running model...")
150
+ with torch.no_grad():
151
+ scene_codes = model([processed_image], device=device)
152
+
153
+ # Check if xatlas is available for texture baking
154
+ xatlas_available = False
155
+ if bake_texture_flag:
156
+ try:
157
+ import xatlas
158
+ xatlas_available = True
159
+ logging.info("xatlas found, texture baking enabled")
160
+ except ImportError:
161
+ logging.warning("xatlas not available - texture baking requires xatlas. Using vertex colors instead.")
162
+ logging.warning("To enable texture baking, install: pip install xatlas==0.0.9 moderngl==5.10.0")
163
+ bake_texture_flag = False
164
+
165
+ # Always extract the mesh with vertex colors: has_vertex_color=True returns
+ # the per-vertex colors the model inferred from the input image.
168
+ logging.info("Extracting mesh with vertex colors (true colors from image)...")
169
+ meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=mc_resolution)
170
+ mesh = meshes[0]
171
+
172
+ # Apply orientation transformation for better 3D viewing
173
+ # Options: "standard" (Y-up, Z-forward), "gradio" (Gradio viewer), "none" (original)
174
+ if orientation not in ["standard", "gradio", "none"]:
175
+ orientation = "standard"
176
+ logging.info(f"Applying mesh orientation: {orientation}")
177
+ mesh = apply_mesh_orientation(mesh, orientation=orientation)
178
+
179
+ if bake_texture_flag and xatlas_available:
180
+ # Bake texture
181
+ logging.info("Baking texture...")
182
+ bake_output = bake_texture(mesh, model, scene_codes[0], texture_resolution)
183
+
184
+ # Save mesh with UV mapping
185
+ mesh_temp = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
186
+ xatlas.export(
187
+ mesh_temp.name,
188
+ mesh.vertices[bake_output["vmapping"]],
189
+ bake_output["indices"],
190
+ bake_output["uvs"],
191
+ mesh.vertex_normals[bake_output["vmapping"]]
192
+ )
193
+ mesh_temp.close()
194
+
195
+ # Save texture
196
+ texture_temp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
197
+ Image.fromarray(
198
+ (bake_output["colors"] * 255.0).astype(np.uint8)
199
+ ).transpose(Image.FLIP_TOP_BOTTOM).save(texture_temp.name)
200
+ texture_temp.close()
201
+
202
+ # Create ZIP file with both mesh and texture
203
+ zip_temp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
204
+ with zipfile.ZipFile(zip_temp.name, 'w', zipfile.ZIP_DEFLATED) as zipf:
205
+ zipf.write(mesh_temp.name, f"mesh.{format}")
206
+ zipf.write(texture_temp.name, "texture.png")
207
+ zip_temp.close()
208
+
209
+ logging.info(f"Mesh and texture saved to: {zip_temp.name}")
210
+
211
+ # Clean up individual files
212
+ os.unlink(mesh_temp.name)
213
+ os.unlink(texture_temp.name)
214
+
215
+ return FileResponse(
216
+ zip_temp.name,
217
+ media_type="application/zip",
218
+ filename="mesh_with_texture.zip",
219
+ headers={"X-File-Path": zip_temp.name}
220
+ )
221
+ else:
222
+ # Save the mesh with the vertex colors the model inferred from the input image.
224
+ temp_file = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
225
+ try:
226
+ mesh.export(temp_file.name)
227
+ except AttributeError as e:
228
+ if "ptp" in str(e) and format == "glb":
229
+ # Fallback to OBJ if GLB export fails due to NumPy 2.0 compatibility
230
+ logging.warning(f"GLB export failed due to NumPy compatibility, falling back to OBJ format")
231
+ temp_file.close()
232
+ os.unlink(temp_file.name)
233
+ temp_file = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
234
+ format = "obj"
235
+ mesh.export(temp_file.name)
236
+ else:
237
+ raise
238
+ temp_file.close()
239
+
240
+ logging.info(f"Mesh with vertex colors (from image) saved to: {temp_file.name}")
241
+
242
+ return FileResponse(
243
+ temp_file.name,
244
+ media_type="application/octet-stream",
245
+ filename=f"mesh.{format}",
246
+ headers={"X-File-Path": temp_file.name}
247
+ )
248
+
249
+ except Exception as e:
250
+ logging.error(f"Error generating mesh: {str(e)}", exc_info=True)
251
+ raise HTTPException(status_code=500, detail=f"Error generating mesh: {str(e)}")
252
+
253
+
254
+ @app.post("/generate-base64")
255
+ async def generate_mesh_base64(
256
+ image: UploadFile = File(...),
257
+ do_remove_background: bool = Form(True),
258
+ foreground_ratio: float = Form(0.85),
259
+ mc_resolution: int = Form(256),
260
+ format: str = Form("obj"),
261
+ bake_texture_flag: bool = Form(True),
262
+ texture_resolution: int = Form(2048),
263
+ orientation: str = Form("standard")
264
+ ):
265
+ """
266
+ Generate a 3D mesh with texture and return as base64 encoded strings.
267
+ Useful for frontend that wants to handle the mesh data directly.
268
+
269
+ Returns JSON with mesh and texture (if baked) as base64 strings.
270
+ """
271
+ try:
272
+ if format not in ["obj", "glb"]:
273
+ raise HTTPException(status_code=400, detail="Format must be 'obj' or 'glb'")
274
+
275
+ # Read and validate image
276
+ image_data = await image.read()
277
+ input_image = Image.open(BytesIO(image_data))
278
+
279
+ if input_image.mode not in ["RGB", "RGBA"]:
280
+ input_image = input_image.convert("RGB")
281
+
282
+ logging.info(f"Processing image: {image.filename}")
283
+
284
+ # Preprocess image
285
+ processed_image = preprocess_image(input_image, do_remove_background, foreground_ratio)
286
+
287
+ # Generate mesh
288
+ logging.info("Running model...")
289
+ with torch.no_grad():
290
+ scene_codes = model([processed_image], device=device)
291
+
292
+ # Check if xatlas is available for texture baking
293
+ xatlas_available = False
294
+ if bake_texture_flag:
295
+ try:
296
+ import xatlas
297
+ xatlas_available = True
298
+ logging.info("xatlas found, texture baking enabled")
299
+ except ImportError:
300
+ logging.warning("xatlas not available - texture baking requires xatlas. Using vertex colors instead.")
301
+ bake_texture_flag = False
302
+
303
+ # Always extract the mesh with vertex colors: has_vertex_color=True returns
+ # the per-vertex colors the model inferred from the input image.
305
+ logging.info("Extracting mesh with vertex colors (true colors from image)...")
306
+ meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=mc_resolution)
307
+ mesh = meshes[0]
308
+
309
+ # Apply orientation transformation for better 3D viewing
310
+ # Options: "standard" (Y-up, Z-forward), "gradio" (Gradio viewer), "none" (original)
311
+ if orientation not in ["standard", "gradio", "none"]:
312
+ orientation = "standard"
313
+ logging.info(f"Applying mesh orientation: {orientation}")
314
+ mesh = apply_mesh_orientation(mesh, orientation=orientation)
315
+
316
+ if bake_texture_flag and xatlas_available:
317
+ # Bake texture (creates texture atlas from model colors)
318
+ logging.info("Baking texture...")
319
+ bake_output = bake_texture(mesh, model, scene_codes[0], texture_resolution)
320
+
321
+ # Save mesh with UV mapping
322
+ mesh_temp = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
323
+ xatlas.export(
324
+ mesh_temp.name,
325
+ mesh.vertices[bake_output["vmapping"]],
326
+ bake_output["indices"],
327
+ bake_output["uvs"],
328
+ mesh.vertex_normals[bake_output["vmapping"]]
329
+ )
330
+
331
+ # Save texture
332
+ texture_temp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
333
+ Image.fromarray(
334
+ (bake_output["colors"] * 255.0).astype(np.uint8)
335
+ ).transpose(Image.FLIP_TOP_BOTTOM).save(texture_temp.name)
336
+
337
+ # Read and encode to base64
338
+ with open(mesh_temp.name, "rb") as f:
339
+ mesh_data = f.read()
340
+ with open(texture_temp.name, "rb") as f:
341
+ texture_data = f.read()
342
+
343
+ mesh_base64 = base64.b64encode(mesh_data).decode("utf-8")
344
+ texture_base64 = base64.b64encode(texture_data).decode("utf-8")
345
+
346
+ # Clean up
347
+ os.unlink(mesh_temp.name)
348
+ os.unlink(texture_temp.name)
349
+
350
+ return JSONResponse({
351
+ "success": True,
352
+ "format": format,
353
+ "mesh": mesh_base64,
354
+ "texture": texture_base64,
355
+ "mesh_size": len(mesh_data),
356
+ "texture_size": len(texture_data),
357
+ "has_texture": True
358
+ })
359
+ else:
360
+ # Save mesh without texture
361
+ temp_file = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
362
+ try:
363
+ mesh.export(temp_file.name)
364
+ except AttributeError as e:
365
+ if "ptp" in str(e) and format == "glb":
366
+ # Fallback to OBJ if GLB export fails due to NumPy 2.0 compatibility
367
+ logging.warning(f"GLB export failed due to NumPy compatibility, falling back to OBJ format")
368
+ temp_file.close()
369
+ os.unlink(temp_file.name)
370
+ temp_file = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
371
+ format = "obj"
372
+ mesh.export(temp_file.name)
373
+ else:
374
+ raise
375
+
376
+ # Read file and encode to base64
377
+ with open(temp_file.name, "rb") as f:
378
+ mesh_data = f.read()
379
+
380
+ mesh_base64 = base64.b64encode(mesh_data).decode("utf-8")
381
+
382
+ # Clean up
383
+ os.unlink(temp_file.name)
384
+
385
+ return JSONResponse({
386
+ "success": True,
387
+ "format": format,
388
+ "mesh": mesh_base64,
389
+ "mesh_size": len(mesh_data),
390
+ "has_texture": False
391
+ })
392
+
393
+ except Exception as e:
394
+ logging.error(f"Error generating mesh: {str(e)}", exc_info=True)
395
+ raise HTTPException(status_code=500, detail=f"Error generating mesh: {str(e)}")
396
+
397
+
398
+ if __name__ == "__main__":
399
+ import uvicorn
400
+ # Use 127.0.0.1 for localhost access, or 0.0.0.0 for network access
401
+ # For local development, use 127.0.0.1
402
+ uvicorn.run(app, host="127.0.0.1", port=8000, reload=False)
403
+
app.py ADDED
@@ -0,0 +1,171 @@
1
+ import logging
2
+ import os
3
+ import tempfile
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from PIL import Image
10
+ from functools import partial
11
+
12
+ from tsr.system import TSR
13
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
14
+
15
+
16
+ if torch.cuda.is_available():
17
+ device = "cuda:0"
18
+ else:
19
+ device = "cpu"
20
+
21
+ model = TSR.from_pretrained(
22
+ "stabilityai/TripoSR",
23
+ config_name="config.yaml",
24
+ weight_name="model.ckpt",
25
+ )
26
+
27
+ # adjust the chunk size to balance between speed and memory usage
28
+ model.renderer.set_chunk_size(8192)
29
+ model.to(device)
30
+
31
+ rembg_session = rembg.new_session()
32
+
33
+
34
+ def check_input_image(input_image):
35
+ if input_image is None:
36
+ raise gr.Error("No image uploaded!")
37
+
38
+
39
+ def preprocess(input_image, do_remove_background, foreground_ratio):
40
+ def fill_background(image):
41
+ image = np.array(image).astype(np.float32) / 255.0
42
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
43
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
44
+ return image
45
+
46
+ if do_remove_background:
47
+ image = input_image.convert("RGB")
48
+ image = remove_background(image, rembg_session)
49
+ image = resize_foreground(image, foreground_ratio)
50
+ image = fill_background(image)
51
+ else:
52
+ image = input_image
53
+ if image.mode == "RGBA":
54
+ image = fill_background(image)
55
+ return image
56
+
57
+
58
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
59
+ scene_codes = model(image, device=device)
60
+ mesh = model.extract_mesh(scene_codes, True, resolution=mc_resolution)[0]
61
+ mesh = to_gradio_3d_orientation(mesh)
62
+ rv = []
63
+ for format in formats:
64
+ mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
65
+ mesh.export(mesh_path.name)
66
+ rv.append(mesh_path.name)
67
+ return rv
68
+
69
+
70
+ def run_example(image_pil):
71
+ preprocessed = preprocess(image_pil, False, 0.9)
72
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
73
+ return preprocessed, mesh_name_obj, mesh_name_glb
74
+
75
+
76
+ with gr.Blocks(title="TripoSR") as interface:
77
+ gr.Markdown(
78
+ """
79
+ # TripoSR Demo
80
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
81
+
82
+ **Tips:**
83
+ 1. If the result is unsatisfactory, try changing the foreground ratio; it may improve the results.
84
+ 2. It's better to disable "Remove Background" for the provided examples (except for the last one), since they have already been preprocessed.
85
+ 3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background and its contents are centered, occupying more than 70% of the image width or height.
86
+ """
87
+ )
88
+ with gr.Row(variant="panel"):
89
+ with gr.Column():
90
+ with gr.Row():
91
+ input_image = gr.Image(
92
+ label="Input Image",
93
+ image_mode="RGBA",
94
+ sources="upload",
95
+ type="pil",
96
+ elem_id="content_image",
97
+ )
98
+ processed_image = gr.Image(label="Processed Image", interactive=False)
99
+ with gr.Row():
100
+ with gr.Group():
101
+ do_remove_background = gr.Checkbox(
102
+ label="Remove Background", value=True
103
+ )
104
+ foreground_ratio = gr.Slider(
105
+ label="Foreground Ratio",
106
+ minimum=0.5,
107
+ maximum=1.0,
108
+ value=0.85,
109
+ step=0.05,
110
+ )
111
+ mc_resolution = gr.Slider(
112
+ label="Marching Cubes Resolution",
113
+ minimum=32,
114
+ maximum=320,
115
+ value=256,
116
+ step=32
117
+ )
118
+ with gr.Row():
119
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
120
+ with gr.Column():
121
+ with gr.Tab("OBJ"):
122
+ output_model_obj = gr.Model3D(
123
+ label="Output Model (OBJ Format)",
124
+ interactive=False,
125
+ )
126
+ gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
127
+ with gr.Tab("GLB"):
128
+ output_model_glb = gr.Model3D(
129
+ label="Output Model (GLB Format)",
130
+ interactive=False,
131
+ )
132
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
133
+ with gr.Row(variant="panel"):
134
+ gr.Examples(
135
+ examples=[
136
+ "examples/hamburger.png",
137
+ "examples/poly_fox.png",
138
+ "examples/robot.png",
139
+ "examples/teapot.png",
140
+ "examples/tiger_girl.png",
141
+ "examples/horse.png",
142
+ "examples/flamingo.png",
143
+ "examples/unicorn.png",
144
+ "examples/chair.png",
145
+ "examples/iso_house.png",
146
+ "examples/marble.png",
147
+ "examples/police_woman.png",
148
+ "examples/captured.jpeg",
149
+ ],
150
+ inputs=[input_image],
151
+ outputs=[processed_image, output_model_obj, output_model_glb],
152
+ cache_examples=False,
153
+ fn=partial(run_example),
154
+ label="Examples",
155
+ examples_per_page=20,
156
+ )
157
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
158
+ fn=preprocess,
159
+ inputs=[input_image, do_remove_background, foreground_ratio],
160
+ outputs=[processed_image],
161
+ ).success(
162
+ fn=generate,
163
+ inputs=[processed_image, mc_resolution],
164
+ outputs=[output_model_obj, output_model_glb],
165
+ )
166
+
167
+
168
+ # For Hugging Face Spaces, we just need to assign the interface
169
+ # The launch() is handled automatically by Spaces
170
+ app = interface
171
+
deploy_colab.ipynb ADDED
@@ -0,0 +1,268 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "# 🎨 TripoSR API on Google Colab\n",
23
+ "\n",
24
+ "This notebook deploys the TripoSR API on Google Colab with a public URL using ngrok.\n",
25
+ "\n",
26
+ "## 📋 Setup Instructions\n",
27
+ "\n",
28
+ "1. **Enable GPU**: Runtime → Change runtime type → GPU (T4)\n",
29
+ "2. **Get ngrok token**: Sign up at [ngrok.com](https://ngrok.com) and get your authtoken\n",
30
+ "3. **Run all cells** in order\n",
31
+ "4. **Copy the public URL** from the output\n",
32
+ "\n",
33
+ "⚠️ **Note**: The session will expire after 12 hours of inactivity or 24 hours maximum."
34
+ ],
35
+ "metadata": {
36
+ "id": "header"
37
+ }
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "source": [
42
+ "# Check GPU availability\n",
43
+ "!nvidia-smi"
44
+ ],
45
+ "metadata": {
46
+ "id": "check_gpu"
47
+ },
48
+ "execution_count": null,
49
+ "outputs": []
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "source": [
54
+ "# Clone the repository\n",
55
+ "!git clone https://github.com/Ahmedbelaid/TripoSR-api.git\n",
56
+ "%cd TripoSR-api"
57
+ ],
58
+ "metadata": {
59
+ "id": "clone_repo"
60
+ },
61
+ "execution_count": null,
62
+ "outputs": []
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "source": [
67
+ "# Install dependencies\n",
68
+ "!pip install -q torch torchvision\n",
69
+ "!pip install -q -r requirements.txt"
70
+ ],
71
+ "metadata": {
72
+ "id": "install_deps"
73
+ },
74
+ "execution_count": null,
75
+ "outputs": []
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "source": [
80
+ "# Install ngrok for public URL\n",
81
+ "!pip install -q pyngrok\n",
82
+ "\n",
83
+ "from pyngrok import ngrok, conf\n",
84
+ "import getpass\n",
85
+ "\n",
86
+ "# Get ngrok authtoken\n",
87
+ "print(\"Get your authtoken from: https://dashboard.ngrok.com/get-started/your-authtoken\")\n",
88
+ "authtoken = getpass.getpass(\"Enter your ngrok authtoken: \")\n",
89
+ "\n",
90
+ "# Set ngrok authtoken\n",
91
+ "conf.get_default().auth_token = authtoken"
92
+ ],
93
+ "metadata": {
94
+ "id": "setup_ngrok"
95
+ },
96
+ "execution_count": null,
97
+ "outputs": []
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "source": [
102
+ "# Start the API server in the background\n",
103
+ "import subprocess\n",
104
+ "import time\n",
105
+ "\n",
106
+ "# Start server\n",
107
+ "process = subprocess.Popen(\n",
108
+ " [\"python\", \"-m\", \"uvicorn\", \"api_server:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"],\n",
109
+ " stdout=subprocess.PIPE,\n",
110
+ " stderr=subprocess.PIPE\n",
111
+ ")\n",
112
+ "\n",
113
+ "# Wait for server to start\n",
114
+ "print(\"Starting server...\")\n",
115
+ "time.sleep(10)\n",
116
+ "\n",
117
+ "# Create ngrok tunnel\n",
118
+ "public_url = ngrok.connect(8000)\n",
119
+ "\n",
120
+ "print(\"\\n\" + \"=\"*60)\n",
121
+ "print(\"🚀 TripoSR API is now running!\")\n",
122
+ "print(\"=\"*60)\n",
123
+ "print(f\"\\n📡 Public URL: {public_url}\")\n",
124
+ "print(f\"\\n🔍 Health Check: {public_url}/health\")\n",
125
+ "print(f\"\\n📝 API Docs: {public_url}/docs\")\n",
126
+ "print(\"\\n\" + \"=\"*60)\n",
127
+ "print(\"\\n⚠️ Keep this notebook running to keep the API active\")\n",
128
+ "print(\"⚠️ Session will expire after 12 hours of inactivity\\n\")"
129
+ ],
130
+ "metadata": {
131
+ "id": "start_server"
132
+ },
133
+ "execution_count": null,
134
+ "outputs": []
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "source": [
139
+ "# Test the API\n",
140
+ "import requests\n",
141
+ "\n",
142
+ "# Health check\n",
143
+ "response = requests.get(f\"{public_url}/health\")\n",
144
+ "print(\"Health Check Response:\")\n",
145
+ "print(response.json())"
146
+ ],
147
+ "metadata": {
148
+ "id": "test_api"
149
+ },
150
+ "execution_count": null,
151
+ "outputs": []
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "source": [
156
+ "## 🧪 Example: Generate 3D Model\n",
157
+ "\n",
158
+ "Upload an image and generate a 3D model:"
159
+ ],
160
+ "metadata": {
161
+ "id": "example_header"
162
+ }
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "source": [
167
+ "from google.colab import files\n",
168
+ "import requests\n",
169
+ "\n",
170
+ "# Upload an image\n",
171
+ "print(\"Upload an image file:\")\n",
172
+ "uploaded = files.upload()\n",
173
+ "\n",
174
+ "# Get the filename\n",
175
+ "filename = list(uploaded.keys())[0]\n",
176
+ "\n",
177
+ "# Generate 3D model\n",
178
+ "print(f\"\\nGenerating 3D model from {filename}...\")\n",
179
+ "with open(filename, 'rb') as f:\n",
180
+ " files_dict = {'image': f}\n",
181
+ " data = {\n",
182
+ " 'format': 'obj',\n",
183
+ " 'bake_texture_flag': True,\n",
184
+ " 'mc_resolution': 256\n",
185
+ " }\n",
186
+ " response = requests.post(f\"{public_url}/generate\", files=files_dict, data=data)\n",
187
+ "\n",
188
+ "if response.status_code == 200:\n",
189
+ " # Save the output\n",
190
+ " output_filename = 'output.zip'\n",
191
+ " with open(output_filename, 'wb') as f:\n",
192
+ " f.write(response.content)\n",
193
+ " print(f\"\\n✅ Success! 3D model saved to {output_filename}\")\n",
194
+ " \n",
195
+ " # Download the file\n",
196
+ " files.download(output_filename)\n",
197
+ "else:\n",
198
+ " print(f\"\\n❌ Error: {response.status_code}\")\n",
199
+ " print(response.text)"
200
+ ],
201
+ "metadata": {
202
+ "id": "generate_example"
203
+ },
204
+ "execution_count": null,
205
+ "outputs": []
206
+ },
207
+ {
208
+ "cell_type": "markdown",
209
+ "source": [
210
+ "## 📊 Monitor Server Logs\n",
211
+ "\n",
212
+ "Run this cell to see server logs:"
213
+ ],
214
+ "metadata": {
215
+ "id": "logs_header"
216
+ }
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "source": [
221
+ "# View server logs\n",
222
+ "import time\n",
223
+ "\n",
224
+ "print(\"Server logs (press Stop to exit):\")\n",
225
+ "print(\"=\"*60)\n",
226
+ "\n",
227
+ "while True:\n",
228
+ " output = process.stdout.readline()\n",
229
+ " if output:\n",
230
+ " print(output.decode().strip())\n",
231
+ " error = process.stderr.readline()\n",
232
+ " if error:\n",
233
+ " print(error.decode().strip())\n",
234
+ " time.sleep(0.1)"
235
+ ],
236
+ "metadata": {
237
+ "id": "view_logs"
238
+ },
239
+ "execution_count": null,
240
+ "outputs": []
241
+ },
242
+ {
243
+ "cell_type": "markdown",
244
+ "source": [
245
+ "## 🛑 Stop Server\n",
246
+ "\n",
247
+ "Run this cell to stop the server:"
248
+ ],
249
+ "metadata": {
250
+ "id": "stop_header"
251
+ }
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "source": [
256
+ "# Stop the server\n",
257
+ "process.terminate()\n",
258
+ "ngrok.disconnect(public_url)\n",
259
+ "print(\"Server stopped.\")"
260
+ ],
261
+ "metadata": {
262
+ "id": "stop_server"
263
+ },
264
+ "execution_count": null,
265
+ "outputs": []
266
+ }
267
+ ]
268
+ }
docker-compose.yml ADDED
@@ -0,0 +1,31 @@
1
+ version: '3.8'
2
+
3
+ services:
4
+ triposr-api:
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ container_name: triposr-api
9
+ ports:
10
+ - "8000:8000"
11
+ environment:
12
+ - CUDA_VISIBLE_DEVICES=0 # Use first GPU, set to empty string to use CPU
13
+ volumes:
14
+ # Optional: Mount a volume for output files if you want to persist them
15
+ - ./output:/app/output
16
+ # Uncomment the following lines if you have NVIDIA GPU and want to use it
17
+ # deploy:
18
+ # resources:
19
+ # reservations:
20
+ # devices:
21
+ # - driver: nvidia
22
+ # count: 1
23
+ # capabilities: [gpu]
24
+ restart: unless-stopped
25
+ healthcheck:
26
+ test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
27
+ interval: 30s
28
+ timeout: 10s
29
+ retries: 3
30
+ start_period: 60s
31
+
gradio_app.py ADDED
@@ -0,0 +1,187 @@
1
+ import logging
2
+ import os
3
+ import tempfile
4
+ import time
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import rembg
9
+ import torch
10
+ from PIL import Image
11
+ from functools import partial
12
+
13
+ from tsr.system import TSR
14
+ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
15
+
16
+ import argparse
17
+
18
+
19
+ if torch.cuda.is_available():
20
+ device = "cuda:0"
21
+ else:
22
+ device = "cpu"
23
+
24
+ model = TSR.from_pretrained(
25
+ "stabilityai/TripoSR",
26
+ config_name="config.yaml",
27
+ weight_name="model.ckpt",
28
+ )
29
+
30
+ # adjust the chunk size to balance between speed and memory usage
31
+ model.renderer.set_chunk_size(8192)
32
+ model.to(device)
33
+
34
+ rembg_session = rembg.new_session()
35
+
36
+
37
+ def check_input_image(input_image):
38
+ if input_image is None:
39
+ raise gr.Error("No image uploaded!")
40
+
41
+
42
+ def preprocess(input_image, do_remove_background, foreground_ratio):
43
+ def fill_background(image):
44
+ image = np.array(image).astype(np.float32) / 255.0
45
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
46
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
47
+ return image
48
+
49
+ if do_remove_background:
50
+ image = input_image.convert("RGB")
51
+ image = remove_background(image, rembg_session)
52
+ image = resize_foreground(image, foreground_ratio)
53
+ image = fill_background(image)
54
+ else:
55
+ image = input_image
56
+ if image.mode == "RGBA":
57
+ image = fill_background(image)
58
+ return image
59
+
60
+
61
+ def generate(image, mc_resolution, formats=["obj", "glb"]):
62
+ scene_codes = model(image, device=device)
63
+ mesh = model.extract_mesh(scene_codes, True, resolution=mc_resolution)[0]
64
+ mesh = to_gradio_3d_orientation(mesh)
65
+ rv = []
66
+ for format in formats:
67
+ mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
68
+ mesh.export(mesh_path.name)
69
+ rv.append(mesh_path.name)
70
+ return rv
71
+
72
+
73
+ def run_example(image_pil):
74
+ preprocessed = preprocess(image_pil, False, 0.9)
75
+ mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
76
+ return preprocessed, mesh_name_obj, mesh_name_glb
77
+
78
+
79
+ with gr.Blocks(title="TripoSR") as interface:
80
+ gr.Markdown(
81
+ """
82
+ # TripoSR Demo
83
+ [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
84
+
85
+ **Tips:**
86
+ 1. If the result is unsatisfactory, try changing the foreground ratio; it may improve the results.
87
+ 2. It's better to disable "Remove Background" for the provided examples (except for the last one), since they have already been preprocessed.
88
+ 3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background and its contents are centered, occupying more than 70% of the image width or height.
89
+ """
90
+ )
91
+ with gr.Row(variant="panel"):
92
+ with gr.Column():
93
+ with gr.Row():
94
+ input_image = gr.Image(
95
+ label="Input Image",
96
+ image_mode="RGBA",
97
+ sources="upload",
98
+ type="pil",
99
+ elem_id="content_image",
100
+ )
101
+ processed_image = gr.Image(label="Processed Image", interactive=False)
102
+ with gr.Row():
103
+ with gr.Group():
104
+ do_remove_background = gr.Checkbox(
105
+ label="Remove Background", value=True
106
+ )
107
+ foreground_ratio = gr.Slider(
108
+ label="Foreground Ratio",
109
+ minimum=0.5,
110
+ maximum=1.0,
111
+ value=0.85,
112
+ step=0.05,
113
+ )
114
+ mc_resolution = gr.Slider(
115
+ label="Marching Cubes Resolution",
116
+ minimum=32,
117
+ maximum=320,
118
+ value=256,
119
+ step=32
120
+ )
121
+ with gr.Row():
122
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
123
+ with gr.Column():
124
+ with gr.Tab("OBJ"):
125
+ output_model_obj = gr.Model3D(
126
+ label="Output Model (OBJ Format)",
127
+ interactive=False,
128
+ )
129
+ gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
130
+ with gr.Tab("GLB"):
131
+ output_model_glb = gr.Model3D(
132
+ label="Output Model (GLB Format)",
133
+ interactive=False,
134
+ )
135
+ gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
136
+ with gr.Row(variant="panel"):
137
+ gr.Examples(
138
+ examples=[
139
+ "examples/hamburger.png",
140
+ "examples/poly_fox.png",
141
+ "examples/robot.png",
142
+ "examples/teapot.png",
143
+ "examples/tiger_girl.png",
144
+ "examples/horse.png",
145
+ "examples/flamingo.png",
146
+ "examples/unicorn.png",
147
+ "examples/chair.png",
148
+ "examples/iso_house.png",
149
+ "examples/marble.png",
150
+ "examples/police_woman.png",
151
+ "examples/captured.jpeg",
152
+ ],
153
+ inputs=[input_image],
154
+ outputs=[processed_image, output_model_obj, output_model_glb],
155
+ cache_examples=False,
156
+ fn=partial(run_example),
157
+ label="Examples",
158
+ examples_per_page=20,
159
+ )
160
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
161
+ fn=preprocess,
162
+ inputs=[input_image, do_remove_background, foreground_ratio],
163
+ outputs=[processed_image],
164
+ ).success(
165
+ fn=generate,
166
+ inputs=[processed_image, mc_resolution],
167
+ outputs=[output_model_obj, output_model_glb],
168
+ )
169
+
170
+
171
+
172
+ if __name__ == '__main__':
173
+ parser = argparse.ArgumentParser()
174
+ parser.add_argument('--username', type=str, default=None, help='Username for authentication')
175
+ parser.add_argument('--password', type=str, default=None, help='Password for authentication')
176
+ parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
177
+ parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
178
+ parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
179
+ parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
180
+ args = parser.parse_args()
181
+ interface.queue(max_size=args.queuesize)
182
+ interface.launch(
183
+ auth=(args.username, args.password) if (args.username and args.password) else None,
184
+ share=args.share,
185
+ server_name="0.0.0.0" if args.listen else None,
186
+ server_port=args.port
187
+ )
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ torch>=2.0.0
2
+ omegaconf==2.3.0
3
+ Pillow==10.1.0
4
+ einops==0.7.0
5
+ git+https://github.com/tatsy/torchmcubes.git
6
+ transformers==4.35.0
7
+ trimesh==4.0.5
8
+ rembg
9
+ huggingface-hub
10
+ imageio[ffmpeg]
11
+ gradio==3.50.2
12
+ xatlas==0.0.9
13
+ moderngl==5.10.0
14
+ fastapi>=0.100.0
15
+ uvicorn[standard]>=0.23.0
16
+ python-multipart
run.py ADDED
@@ -0,0 +1,197 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import rembg
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from tsr.system import TSR
12
+ from tsr.utils import remove_background, resize_foreground, save_video
13
+ from tsr.bake_texture import bake_texture
14
+
15
+
16
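+ # Simple wall-clock timer for the log; cuda.synchronize() ensures queued GPU work is included in the measurement.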
+ class Timer:
17
+ def __init__(self):
18
+ self.items = {}
19
+ self.time_scale = 1000.0 # ms
20
+ self.time_unit = "ms"
21
+
22
+ def start(self, name: str) -> None:
23
+ if torch.cuda.is_available():
24
+ torch.cuda.synchronize()
25
+ self.items[name] = time.time()
26
+ logging.info(f"{name} ...")
27
+
28
+ def end(self, name: str) -> None:
29
+ if name not in self.items:
30
+ return
31
+ if torch.cuda.is_available():
32
+ torch.cuda.synchronize()
33
+ start_time = self.items.pop(name)
34
+ delta = time.time() - start_time
35
+ t = delta * self.time_scale
36
+ logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
37
+
38
+
39
+ timer = Timer()
40
+
41
+
42
+ logging.basicConfig(
43
+ format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
44
+ )
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
47
+ parser.add_argument(
48
+ "--device",
49
+ default="cuda:0",
50
+ type=str,
51
+ help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
52
+ )
53
+ parser.add_argument(
54
+ "--pretrained-model-name-or-path",
55
+ default="stabilityai/TripoSR",
56
+ type=str,
57
+ help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/TripoSR'",
58
+ )
59
+ parser.add_argument(
60
+ "--chunk-size",
61
+ default=8192,
62
+ type=int,
63
+ help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
64
+ )
65
+ parser.add_argument(
66
+ "--mc-resolution",
67
+ default=256,
68
+ type=int,
69
+ help="Marching cubes grid resolution. Default: 256"
70
+ )
71
+ parser.add_argument(
72
+ "--no-remove-bg",
73
+ action="store_true",
74
+ help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
75
+ )
76
+ parser.add_argument(
77
+ "--foreground-ratio",
78
+ default=0.85,
79
+ type=float,
80
+ help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
81
+ )
82
+ parser.add_argument(
83
+ "--output-dir",
84
+ default="output/",
85
+ type=str,
86
+ help="Output directory to save the results. Default: 'output/'",
87
+ )
88
+ parser.add_argument(
89
+ "--model-save-format",
90
+ default="obj",
91
+ type=str,
92
+ choices=["obj", "glb"],
93
+ help="Format to save the extracted mesh. Default: 'obj'",
94
+ )
95
+ parser.add_argument(
96
+ "--bake-texture",
97
+ action="store_true",
98
+ help="Bake a texture atlas for the extracted mesh, instead of vertex colors",
99
+ )
100
+ parser.add_argument(
101
+ "--texture-resolution",
102
+ default=2048,
103
+ type=int,
104
+ help="Texture atlas resolution, only useful with --bake-texture. Default: 2048"
105
+ )
106
+ parser.add_argument(
107
+ "--render",
108
+ action="store_true",
109
+ help="If specified, save a NeRF-rendered video. Default: false",
110
+ )
111
+ args = parser.parse_args()
112
+
113
+ output_dir = args.output_dir
114
+ os.makedirs(output_dir, exist_ok=True)
115
+
116
+ device = args.device
117
+ if not torch.cuda.is_available():
118
+ device = "cpu"
119
+
120
+ timer.start("Initializing model")
121
+ model = TSR.from_pretrained(
122
+ args.pretrained_model_name_or_path,
123
+ config_name="config.yaml",
124
+ weight_name="model.ckpt",
125
+ )
126
+ model.renderer.set_chunk_size(args.chunk_size)
127
+ model.to(device)
128
+ timer.end("Initializing model")
129
+
130
+ timer.start("Processing images")
131
+ images = []
132
+
133
+ if args.no_remove_bg:
134
+ rembg_session = None
135
+ else:
136
+ rembg_session = rembg.new_session()
137
+
138
+ for i, image_path in enumerate(args.image):
139
+ if args.no_remove_bg:
140
+ image = np.array(Image.open(image_path).convert("RGB"))
141
+ else:
142
+ image = remove_background(Image.open(image_path), rembg_session)
143
+ image = resize_foreground(image, args.foreground_ratio)
144
+ image = np.array(image).astype(np.float32) / 255.0
145
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
146
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
147
+ if not os.path.exists(os.path.join(output_dir, str(i))):
148
+ os.makedirs(os.path.join(output_dir, str(i)))
149
+ image.save(os.path.join(output_dir, str(i), f"input.png"))
150
+ images.append(image)
151
+ timer.end("Processing images")
152
+
153
+ for i, image in enumerate(images):
154
+ logging.info(f"Running image {i + 1}/{len(images)} ...")
155
+
156
+ timer.start("Running model")
157
+ with torch.no_grad():
158
+ scene_codes = model([image], device=device)
159
+ timer.end("Running model")
160
+
161
+ if args.render:
162
+ timer.start("Rendering")
163
+ render_images = model.render(scene_codes, n_views=30, return_type="pil")
164
+ for ri, render_image in enumerate(render_images[0]):
165
+ render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
166
+ save_video(
167
+ render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
168
+ )
169
+ timer.end("Rendering")
170
+
171
+ timer.start("Extracting mesh")
172
+ meshes = model.extract_mesh(scene_codes, not args.bake_texture, resolution=args.mc_resolution)
173
+ timer.end("Extracting mesh")
174
+
175
+ out_mesh_path = os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}")
176
+ if args.bake_texture:
177
+ try:
178
+ import xatlas
179
+ except ImportError:
180
+ raise ImportError(
181
+ "xatlas is required for texture baking. Please install it with: pip install xatlas==0.0.9\n"
182
+ "Note: This requires Microsoft Visual C++ Build Tools to compile."
183
+ )
184
+ out_texture_path = os.path.join(output_dir, str(i), "texture.png")
185
+
186
+ timer.start("Baking texture")
187
+ bake_output = bake_texture(meshes[0], model, scene_codes[0], args.texture_resolution)
188
+ timer.end("Baking texture")
189
+
190
+ timer.start("Exporting mesh and texture")
191
+ xatlas.export(out_mesh_path, meshes[0].vertices[bake_output["vmapping"]], bake_output["indices"], bake_output["uvs"], meshes[0].vertex_normals[bake_output["vmapping"]])
192
+ Image.fromarray((bake_output["colors"] * 255.0).astype(np.uint8)).transpose(Image.FLIP_TOP_BOTTOM).save(out_texture_path)
193
+ timer.end("Exporting mesh and texture")
194
+ else:
195
+ timer.start("Exporting mesh")
196
+ meshes[0].export(out_mesh_path)
197
+ timer.end("Exporting mesh")
tsr/__pycache__/bake_texture.cpython-313.pyc ADDED
Binary file (7.68 kB).
 
tsr/__pycache__/system.cpython-313.pyc ADDED
Binary file (9.9 kB).
 
tsr/__pycache__/utils.cpython-313.pyc ADDED
Binary file (24 kB).
 
tsr/bake_texture.py ADDED
@@ -0,0 +1,191 @@
1
+ import numpy as np
2
+ import torch
3
+ import trimesh
4
+ from PIL import Image
5
+
6
+ try:
7
+ import xatlas
8
+ import moderngl
9
+ _HAS_XATLAS = True
10
+ except ImportError:
11
+ _HAS_XATLAS = False
12
+
13
+
14
+ def make_atlas(mesh, texture_resolution, texture_padding):
15
+ if not _HAS_XATLAS:
16
+ raise ImportError(
17
+ "xatlas is required for texture baking. Please install it with: pip install xatlas==0.0.9\n"
18
+ "Note: This requires Microsoft Visual C++ Build Tools to compile."
19
+ )
20
+ atlas = xatlas.Atlas()
21
+ atlas.add_mesh(mesh.vertices, mesh.faces)
22
+ options = xatlas.PackOptions()
23
+ options.resolution = texture_resolution
24
+ options.padding = texture_padding
25
+ options.bilinear = True
26
+ atlas.generate(pack_options=options)
27
+ vmapping, indices, uvs = atlas[0]
28
+ return {
29
+ "vmapping": vmapping,
30
+ "indices": indices,
31
+ "uvs": uvs,
32
+ }
33
+
34
+
35
+ def rasterize_position_atlas(
36
+ mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
37
+ ):
38
+ if not _HAS_XATLAS:
39
+ raise ImportError(
40
+ "moderngl is required for texture baking. Please install it with: pip install moderngl==5.10.0\n"
41
+ "Note: This requires Microsoft Visual C++ Build Tools to compile."
42
+ )
43
+ ctx = moderngl.create_context(standalone=True)
44
+ basic_prog = ctx.program(
45
+ vertex_shader="""
46
+ #version 330
47
+ in vec2 in_uv;
48
+ in vec3 in_pos;
49
+ out vec3 v_pos;
50
+ void main() {
51
+ v_pos = in_pos;
52
+ gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
53
+ }
54
+ """,
55
+ fragment_shader="""
56
+ #version 330
57
+ in vec3 v_pos;
58
+ out vec4 o_col;
59
+ void main() {
60
+ o_col = vec4(v_pos, 1.0);
61
+ }
62
+ """,
63
+ )
64
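+ # Second pass: a geometry shader extrudes each triangle edge outward by
+ # u_dilation texels, padding UV island borders to prevent seam bleeding.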
+ gs_prog = ctx.program(
65
+ vertex_shader="""
66
+ #version 330
67
+ in vec2 in_uv;
68
+ in vec3 in_pos;
69
+ out vec3 vg_pos;
70
+ void main() {
71
+ vg_pos = in_pos;
72
+ gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
73
+ }
74
+ """,
75
+ geometry_shader="""
76
+ #version 330
77
+ uniform float u_resolution;
78
+ uniform float u_dilation;
79
+ layout (triangles) in;
80
+ layout (triangle_strip, max_vertices = 12) out;
81
+ in vec3 vg_pos[];
82
+ out vec3 vf_pos;
83
+ void lineSegment(int aidx, int bidx) {
84
+ vec2 a = gl_in[aidx].gl_Position.xy;
85
+ vec2 b = gl_in[bidx].gl_Position.xy;
86
+ vec3 aCol = vg_pos[aidx];
87
+ vec3 bCol = vg_pos[bidx];
88
+
89
+ vec2 dir = normalize((b - a) * u_resolution);
90
+ vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
91
+
92
+ gl_Position = vec4(a + offset, 0.0, 1.0);
93
+ vf_pos = aCol;
94
+ EmitVertex();
95
+ gl_Position = vec4(a - offset, 0.0, 1.0);
96
+ vf_pos = aCol;
97
+ EmitVertex();
98
+ gl_Position = vec4(b + offset, 0.0, 1.0);
99
+ vf_pos = bCol;
100
+ EmitVertex();
101
+ gl_Position = vec4(b - offset, 0.0, 1.0);
102
+ vf_pos = bCol;
103
+ EmitVertex();
104
+ }
105
+ void main() {
106
+ lineSegment(0, 1);
107
+ lineSegment(1, 2);
108
+ lineSegment(2, 0);
109
+ EndPrimitive();
110
+ }
111
+ """,
112
+ fragment_shader="""
113
+ #version 330
114
+ in vec3 vf_pos;
115
+ out vec4 o_col;
116
+ void main() {
117
+ o_col = vec4(vf_pos, 1.0);
118
+ }
119
+ """,
120
+ )
121
+ uvs = atlas_uvs.flatten().astype("f4")
122
+ pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
123
+ indices = atlas_indices.flatten().astype("i4")
124
+ vbo_uvs = ctx.buffer(uvs)
125
+ vbo_pos = ctx.buffer(pos)
126
+ ibo = ctx.buffer(indices)
127
+ vao_content = [
128
+ vbo_uvs.bind("in_uv", layout="2f"),
129
+ vbo_pos.bind("in_pos", layout="3f"),
130
+ ]
131
+ basic_vao = ctx.vertex_array(basic_prog, vao_content, ibo)
132
+ gs_vao = ctx.vertex_array(gs_prog, vao_content, ibo)
133
+ fbo = ctx.framebuffer(
134
+ color_attachments=[
135
+ ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
136
+ ]
137
+ )
138
+ fbo.use()
139
+ fbo.clear(0.0, 0.0, 0.0, 0.0)
140
+ gs_prog["u_resolution"].value = texture_resolution
141
+ gs_prog["u_dilation"].value = texture_padding
142
+ gs_vao.render()
143
+ basic_vao.render()
144
+
145
+ fbo_bytes = fbo.color_attachments[0].read()
146
+ fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
147
+ texture_resolution, texture_resolution, 4
148
+ )
149
+ return fbo_np
150
+
151
+
152
+ def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
153
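+ # Each atlas texel stores a 3D position (xyz) plus a coverage flag in the
+ # alpha channel; query the triplane decoder at those positions for colors.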
+ positions = torch.tensor(positions_texture.reshape(-1, 4)[:, :-1])
154
+ with torch.no_grad():
155
+ queried_grid = model.renderer.query_triplane(
156
+ model.decoder,
157
+ positions,
158
+ scene_code,
159
+ )
160
+ rgb_f = queried_grid["color"].numpy().reshape(-1, 3)
161
+ rgba_f = np.insert(rgb_f, 3, positions_texture.reshape(-1, 4)[:, -1], axis=1)
162
+ rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
163
+ return rgba_f.reshape(texture_resolution, texture_resolution, 4)
164
+
165
+
166
+ def bake_texture(mesh, model, scene_code, texture_resolution):
167
+ if not _HAS_XATLAS:
168
+ raise ImportError(
169
+ "xatlas and moderngl are required for texture baking. Please install them with:\n"
170
+ " pip install xatlas==0.0.9 moderngl==5.10.0\n"
171
+ "Note: These require Microsoft Visual C++ Build Tools to compile."
172
+ )
173
+ texture_padding = round(max(2, texture_resolution / 256))
174
+ atlas = make_atlas(mesh, texture_resolution, texture_padding)
175
+ positions_texture = rasterize_position_atlas(
176
+ mesh,
177
+ atlas["vmapping"],
178
+ atlas["indices"],
179
+ atlas["uvs"],
180
+ texture_resolution,
181
+ texture_padding,
182
+ )
183
+ colors_texture = positions_to_colors(
184
+ model, scene_code, positions_texture, texture_resolution
185
+ )
186
+ return {
187
+ "vmapping": atlas["vmapping"],
188
+ "indices": atlas["indices"],
189
+ "uvs": atlas["uvs"],
190
+ "colors": colors_texture,
191
+ }
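A minimal usage sketch for the baking pipeline above. It assumes a loaded tsr.system.TSR model and scene codes from a forward pass, as wired up in this repo's run.py; the extract_mesh call is an assumption taken from that flow, and all variable names are illustrative:

import trimesh
from tsr.bake_texture import bake_texture

# Assumed setup (see run.py): `model` is a loaded TSR instance and
# `scene_codes` came from model(image, device=...).
meshes = model.extract_mesh(scene_codes, has_vertex_color=False)
mesh: trimesh.Trimesh = meshes[0]

baked = bake_texture(mesh, model, scene_codes[0], texture_resolution=2048)
# baked["colors"] is an (H, W, 4) float32 RGBA atlas; baked["uvs"],
# baked["indices"], and baked["vmapping"] describe the UV layout.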
tsr/models/__pycache__/isosurface.cpython-313.pyc ADDED
Binary file (4.26 kB)
tsr/models/__pycache__/nerf_renderer.cpython-313.pyc ADDED
Binary file (8.46 kB)
tsr/models/__pycache__/network_utils.cpython-313.pyc ADDED
Binary file (5.88 kB)
tsr/models/isosurface.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ try:
8
+ from torchmcubes import marching_cubes
9
+ _HAS_TORCHMCUBES = True
10
+ except ImportError:
11
+ _HAS_TORCHMCUBES = False
12
+ marching_cubes = None
13
+
14
+
15
+ class IsosurfaceHelper(nn.Module):
16
+ points_range: Tuple[float, float] = (0, 1)
17
+
18
+ @property
19
+ def grid_vertices(self) -> torch.FloatTensor:
20
+ raise NotImplementedError
21
+
22
+
23
+ class MarchingCubeHelper(IsosurfaceHelper):
24
+ def __init__(self, resolution: int) -> None:
25
+ super().__init__()
26
+ if not _HAS_TORCHMCUBES:
27
+ raise ImportError(
28
+ "torchmcubes is required for mesh extraction. Please install it with:\n"
29
+ " pip install git+https://github.com/tatsy/torchmcubes.git\n"
30
+ "Note: This requires Microsoft Visual C++ Build Tools to compile."
31
+ )
32
+ self.resolution = resolution
33
+ self.mc_func: Callable = marching_cubes
34
+ self._grid_vertices: Optional[torch.FloatTensor] = None
35
+
36
+ @property
37
+ def grid_vertices(self) -> torch.FloatTensor:
38
+ if self._grid_vertices is None:
39
+ # keep the vertices on CPU so that we can support very large resolution
40
+ x, y, z = (
41
+ torch.linspace(*self.points_range, self.resolution),
42
+ torch.linspace(*self.points_range, self.resolution),
43
+ torch.linspace(*self.points_range, self.resolution),
44
+ )
45
+ x, y, z = torch.meshgrid(x, y, z, indexing="ij")
46
+ verts = torch.cat(
47
+ [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
48
+ ).reshape(-1, 3)
49
+ self._grid_vertices = verts
50
+ return self._grid_vertices
51
+
52
+ def forward(
53
+ self,
54
+ level: torch.FloatTensor,
55
+ ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
56
+ level = -level.view(self.resolution, self.resolution, self.resolution)
57
+ try:
58
+ v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
59
+ except AttributeError:
60
+ print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
61
+ v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
62
+ v_pos = v_pos[..., [2, 1, 0]]
63
+ v_pos = v_pos / (self.resolution - 1.0)
64
+ return v_pos.to(level.device), t_pos_idx.to(level.device)
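A short sketch of driving MarchingCubeHelper with a synthetic signed field (a sphere), assuming torchmcubes is installed; only the helper itself comes from the file above:

import torch
from tsr.models.isosurface import MarchingCubeHelper

helper = MarchingCubeHelper(resolution=64)
pts = helper.grid_vertices                # (64**3, 3), coordinates in [0, 1]
level = (pts - 0.5).norm(dim=-1) - 0.3    # zero level set = sphere of radius 0.3
v_pos, t_pos_idx = helper(level)          # mesh vertices in [0, 1] and faces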
tsr/models/nerf_renderer.py ADDED
@@ -0,0 +1,180 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, reduce
7
+
8
+ from ..utils import (
9
+ BaseModule,
10
+ chunk_batch,
11
+ get_activation,
12
+ rays_intersect_bbox,
13
+ scale_tensor,
14
+ )
15
+
16
+
17
+ class TriplaneNeRFRenderer(BaseModule):
18
+ @dataclass
19
+ class Config(BaseModule.Config):
20
+ radius: float
21
+
22
+ feature_reduction: str = "concat"
23
+ density_activation: str = "trunc_exp"
24
+ density_bias: float = -1.0
25
+ color_activation: str = "sigmoid"
26
+ num_samples_per_ray: int = 128
27
+ randomized: bool = False
28
+
29
+ cfg: Config
30
+
31
+ def configure(self) -> None:
32
+ assert self.cfg.feature_reduction in ["concat", "mean"]
33
+ self.chunk_size = 0
34
+
35
+ def set_chunk_size(self, chunk_size: int):
36
+ assert (
37
+ chunk_size >= 0
38
+ ), "chunk_size must be a non-negative integer (0 for no chunking)."
39
+ self.chunk_size = chunk_size
40
+
41
+ def query_triplane(
42
+ self,
43
+ decoder: torch.nn.Module,
44
+ positions: torch.Tensor,
45
+ triplane: torch.Tensor,
46
+ ) -> Dict[str, torch.Tensor]:
47
+ input_shape = positions.shape[:-1]
48
+ positions = positions.view(-1, 3)
49
+
50
+ # positions in (-radius, radius)
51
+ # normalized to (-1, 1) for grid sample
52
+ positions = scale_tensor(
53
+ positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
54
+ )
55
+
56
+ def _query_chunk(x):
57
+ indices2D: torch.Tensor = torch.stack(
58
+ (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
59
+ dim=-3,
60
+ )
61
+ out: torch.Tensor = F.grid_sample(
62
+ rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
63
+ rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
64
+ align_corners=False,
65
+ mode="bilinear",
66
+ )
67
+ if self.cfg.feature_reduction == "concat":
68
+ out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
69
+ elif self.cfg.feature_reduction == "mean":
70
+ out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ net_out: Dict[str, torch.Tensor] = decoder(out)
75
+ return net_out
76
+
77
+ if self.chunk_size > 0:
78
+ net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
79
+ else:
80
+ net_out = _query_chunk(positions)
81
+
82
+ net_out["density_act"] = get_activation(self.cfg.density_activation)(
83
+ net_out["density"] + self.cfg.density_bias
84
+ )
85
+ net_out["color"] = get_activation(self.cfg.color_activation)(
86
+ net_out["features"]
87
+ )
88
+
89
+ net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
90
+
91
+ return net_out
92
+
93
+ def _forward(
94
+ self,
95
+ decoder: torch.nn.Module,
96
+ triplane: torch.Tensor,
97
+ rays_o: torch.Tensor,
98
+ rays_d: torch.Tensor,
99
+ **kwargs,
100
+ ):
101
+ rays_shape = rays_o.shape[:-1]
102
+ rays_o = rays_o.view(-1, 3)
103
+ rays_d = rays_d.view(-1, 3)
104
+ n_rays = rays_o.shape[0]
105
+
106
+ t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
107
+ t_near, t_far = t_near[rays_valid], t_far[rays_valid]
108
+
109
+ t_vals = torch.linspace(
110
+ 0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
111
+ )
112
+ t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
113
+ z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
114
+
115
+ xyz = (
116
+ rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
117
+ ) # (N_rays, N_sample, 3)
118
+
119
+ mlp_out = self.query_triplane(
120
+ decoder=decoder,
121
+ positions=xyz,
122
+ triplane=triplane,
123
+ )
124
+
125
+ eps = 1e-10
126
+ # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
127
+ deltas = t_vals[1:] - t_vals[:-1] # (N_samples,)
128
+ alpha = 1 - torch.exp(
129
+ -deltas * mlp_out["density_act"][..., 0]
130
+ ) # (N_rays, N_samples)
131
+ accum_prod = torch.cat(
132
+ [
133
+ torch.ones_like(alpha[:, :1]),
134
+ torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
135
+ ],
136
+ dim=-1,
137
+ )
138
+ weights = alpha * accum_prod # (N_rays, N_samples)
139
+ comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
140
+ opacity_ = weights.sum(dim=-1) # (N_rays)
141
+
142
+ comp_rgb = torch.zeros(
143
+ n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
144
+ )
145
+ opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
146
+ comp_rgb[rays_valid] = comp_rgb_
147
+ opacity[rays_valid] = opacity_
148
+
149
+ comp_rgb += 1 - opacity[..., None]
150
+ comp_rgb = comp_rgb.view(*rays_shape, 3)
151
+
152
+ return comp_rgb
153
+
154
+ def forward(
155
+ self,
156
+ decoder: torch.nn.Module,
157
+ triplane: torch.Tensor,
158
+ rays_o: torch.Tensor,
159
+ rays_d: torch.Tensor,
160
+ ) -> Dict[str, torch.Tensor]:
161
+ if triplane.ndim == 4:
162
+ comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
163
+ else:
164
+ comp_rgb = torch.stack(
165
+ [
166
+ self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
167
+ for i in range(triplane.shape[0])
168
+ ],
169
+ dim=0,
170
+ )
171
+
172
+ return comp_rgb
173
+
174
+ def train(self, mode=True):
175
+ self.randomized = mode and self.cfg.randomized
176
+ return super().train(mode=mode)
177
+
178
+ def eval(self):
179
+ self.randomized = False
180
+ return super().eval()
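For reference, the alpha/accum_prod/weights block in _forward above is the standard discrete volume-rendering quadrature. With per-sample density \sigma_i and step \delta_i:

\alpha_i = 1 - \exp(-\sigma_i \delta_i), \qquad T_i = \prod_{j<i} (1 - \alpha_j), \qquad w_i = T_i \, \alpha_i, \qquad C = \sum_i w_i c_i + \Big(1 - \sum_i w_i\Big)

The trailing (1 - \sum_i w_i) term is the white-background compositing performed by `comp_rgb += 1 - opacity[..., None]`.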
tsr/models/network_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+
8
+ from ..utils import BaseModule
9
+
10
+
11
+ class TriplaneUpsampleNetwork(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ in_channels: int
15
+ out_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.upsample = nn.ConvTranspose2d(
21
+ self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
22
+ )
23
+
24
+ def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
25
+ triplanes_up = rearrange(
26
+ self.upsample(
27
+ rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
28
+ ),
29
+ "(B Np) Co Hp Wp -> B Np Co Hp Wp",
30
+ Np=3,
31
+ )
32
+ return triplanes_up
33
+
34
+
35
+ class NeRFMLP(BaseModule):
36
+ @dataclass
37
+ class Config(BaseModule.Config):
38
+ in_channels: int
39
+ n_neurons: int
40
+ n_hidden_layers: int
41
+ activation: str = "relu"
42
+ bias: bool = True
43
+ weight_init: Optional[str] = "kaiming_uniform"
44
+ bias_init: Optional[str] = None
45
+
46
+ cfg: Config
47
+
48
+ def configure(self) -> None:
49
+ layers = [
50
+ self.make_linear(
51
+ self.cfg.in_channels,
52
+ self.cfg.n_neurons,
53
+ bias=self.cfg.bias,
54
+ weight_init=self.cfg.weight_init,
55
+ bias_init=self.cfg.bias_init,
56
+ ),
57
+ self.make_activation(self.cfg.activation),
58
+ ]
59
+ for i in range(self.cfg.n_hidden_layers - 1):
60
+ layers += [
61
+ self.make_linear(
62
+ self.cfg.n_neurons,
63
+ self.cfg.n_neurons,
64
+ bias=self.cfg.bias,
65
+ weight_init=self.cfg.weight_init,
66
+ bias_init=self.cfg.bias_init,
67
+ ),
68
+ self.make_activation(self.cfg.activation),
69
+ ]
70
+ layers += [
71
+ self.make_linear(
72
+ self.cfg.n_neurons,
73
+ 4, # density 1 + features 3
74
+ bias=self.cfg.bias,
75
+ weight_init=self.cfg.weight_init,
76
+ bias_init=self.cfg.bias_init,
77
+ )
78
+ ]
79
+ self.layers = nn.Sequential(*layers)
80
+
81
+ def make_linear(
82
+ self,
83
+ dim_in,
84
+ dim_out,
85
+ bias=True,
86
+ weight_init=None,
87
+ bias_init=None,
88
+ ):
89
+ layer = nn.Linear(dim_in, dim_out, bias=bias)
90
+
91
+ if weight_init is None:
92
+ pass
93
+ elif weight_init == "kaiming_uniform":
94
+ torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ if bias:
99
+ if bias_init is None:
100
+ pass
101
+ elif bias_init == "zero":
102
+ torch.nn.init.zeros_(layer.bias)
103
+ else:
104
+ raise NotImplementedError
105
+
106
+ return layer
107
+
108
+ def make_activation(self, activation):
109
+ if activation == "relu":
110
+ return nn.ReLU(inplace=True)
111
+ elif activation == "silu":
112
+ return nn.SiLU(inplace=True)
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ def forward(self, x):
117
+ inp_shape = x.shape[:-1]
118
+ x = x.reshape(-1, x.shape[-1])
119
+
120
+ features = self.layers(x)
121
+ features = features.reshape(*inp_shape, -1)
122
+ out = {"density": features[..., 0:1], "features": features[..., 1:4]}
123
+
124
+ return out
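A small shape-check sketch for NeRFMLP. The config values are illustrative (the real ones come from the model config), and passing a plain dict assumes BaseModule parses it through its Config dataclass as in tsr/utils.py:

import torch
from tsr.models.network_utils import NeRFMLP

mlp = NeRFMLP({
    "in_channels": 120, "n_neurons": 64, "n_hidden_layers": 9,
    "activation": "silu",
})
x = torch.randn(2, 4096, 120)                  # (..., in_channels)
out = mlp(x)
assert out["density"].shape == (2, 4096, 1)    # 1 density channel
assert out["features"].shape == (2, 4096, 3)   # 3 color-feature channels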
tsr/models/tokenizers/__pycache__/image.cpython-313.pyc ADDED
Binary file (3.59 kB)
tsr/models/tokenizers/__pycache__/triplane.cpython-313.pyc ADDED
Binary file (2.77 kB)
tsr/models/tokenizers/image.py ADDED
@@ -0,0 +1,66 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers.models.vit.modeling_vit import ViTModel
8
+
9
+ from ...utils import BaseModule
10
+
11
+
12
+ class DINOSingleImageTokenizer(BaseModule):
13
+ @dataclass
14
+ class Config(BaseModule.Config):
15
+ pretrained_model_name_or_path: str = "facebook/dino-vitb16"
16
+ enable_gradient_checkpointing: bool = False
17
+
18
+ cfg: Config
19
+
20
+ def configure(self) -> None:
21
+ self.model: ViTModel = ViTModel(
22
+ ViTModel.config_class.from_pretrained(
23
+ hf_hub_download(
24
+ repo_id=self.cfg.pretrained_model_name_or_path,
25
+ filename="config.json",
26
+ )
27
+ )
28
+ )
29
+
30
+ if self.cfg.enable_gradient_checkpointing:
31
+ self.model.encoder.gradient_checkpointing = True
32
+
33
+ self.register_buffer(
34
+ "image_mean",
35
+ torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
36
+ persistent=False,
37
+ )
38
+ self.register_buffer(
39
+ "image_std",
40
+ torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
41
+ persistent=False,
42
+ )
43
+
44
+ def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
45
+ packed = False
46
+ if images.ndim == 4:
47
+ packed = True
48
+ images = images.unsqueeze(1)
49
+
50
+ batch_size, n_input_views = images.shape[:2]
51
+ images = (images - self.image_mean) / self.image_std
52
+ out = self.model(
53
+ rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
54
+ )
55
+ local_features, global_features = out.last_hidden_state, out.pooler_output
56
+ local_features = local_features.permute(0, 2, 1)
57
+ local_features = rearrange(
58
+ local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
59
+ )
60
+ if packed:
61
+ local_features = local_features.squeeze(1)
62
+
63
+ return local_features
64
+
65
+ def detokenize(self, *args, **kwargs):
66
+ raise NotImplementedError
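A usage sketch for the DINO tokenizer. Constructing with an empty dict again assumes the BaseModule config parsing; note that configure() builds the ViT from config.json only, so meaningful weights arrive when the enclosing TSR checkpoint is loaded:

import torch
from tsr.models.tokenizers.image import DINOSingleImageTokenizer

tok = DINOSingleImageTokenizer({})    # defaults to facebook/dino-vitb16
imgs = torch.rand(1, 3, 512, 512)     # RGB in [0, 1]; the resolution is
feats = tok(imgs)                     # flexible via interpolate_pos_encoding
# feats: (B, 768, num_tokens) local features; the pooled output is discarded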
tsr/models/tokenizers/triplane.py ADDED
@@ -0,0 +1,45 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from ...utils import BaseModule
9
+
10
+
11
+ class Triplane1DTokenizer(BaseModule):
12
+ @dataclass
13
+ class Config(BaseModule.Config):
14
+ plane_size: int
15
+ num_channels: int
16
+
17
+ cfg: Config
18
+
19
+ def configure(self) -> None:
20
+ self.embeddings = nn.Parameter(
21
+ torch.randn(
22
+ (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
23
+ dtype=torch.float32,
24
+ )
25
+ * 1
26
+ / math.sqrt(self.cfg.num_channels)
27
+ )
28
+
29
+ def forward(self, batch_size: int) -> torch.Tensor:
30
+ return rearrange(
31
+ repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
32
+ "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
33
+ )
34
+
35
+ def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
36
+ batch_size, Ct, Nt = tokens.shape
37
+ assert Nt == self.cfg.plane_size**2 * 3
38
+ assert Ct == self.cfg.num_channels
39
+ return rearrange(
40
+ tokens,
41
+ "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
42
+ Np=3,
43
+ Hp=self.cfg.plane_size,
44
+ Wp=self.cfg.plane_size,
45
+ )
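A round-trip sketch for Triplane1DTokenizer; the plane size and channel count are illustrative, and the dict-style config assumes BaseModule parsing:

import torch
from tsr.models.tokenizers.triplane import Triplane1DTokenizer

tok = Triplane1DTokenizer({"plane_size": 32, "num_channels": 1024})
tokens = tok(batch_size=2)         # (2, 1024, 3 * 32 * 32) learned queries
planes = tok.detokenize(tokens)    # (2, 3, 1024, 32, 32) per-plane features
assert planes.shape == (2, 3, 1024, 32, 32)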
tsr/models/transformer/__pycache__/attention.cpython-313.pyc ADDED
Binary file (23.5 kB)
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-313.pyc ADDED
Binary file (13.6 kB)
tsr/models/transformer/__pycache__/transformer_1d.cpython-313.pyc ADDED
Binary file (7.52 kB)
tsr/models/transformer/attention.py ADDED
@@ -0,0 +1,653 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+
46
+ class Attention(nn.Module):
47
+ r"""
48
+ A cross attention layer.
49
+
50
+ Parameters:
51
+ query_dim (`int`):
52
+ The number of channels in the query.
53
+ cross_attention_dim (`int`, *optional*):
54
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
55
+ heads (`int`, *optional*, defaults to 8):
56
+ The number of heads to use for multi-head attention.
57
+ dim_head (`int`, *optional*, defaults to 64):
58
+ The number of channels in each head.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout probability to use.
61
+ bias (`bool`, *optional*, defaults to False):
62
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
63
+ upcast_attention (`bool`, *optional*, defaults to False):
64
+ Set to `True` to upcast the attention computation to `float32`.
65
+ upcast_softmax (`bool`, *optional*, defaults to False):
66
+ Set to `True` to upcast the softmax computation to `float32`.
67
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
68
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
69
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
70
+ The number of groups to use for the group norm in the cross attention.
71
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
72
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
73
+ norm_num_groups (`int`, *optional*, defaults to `None`):
74
+ The number of groups to use for the group norm in the attention.
75
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
76
+ The number of channels to use for the spatial normalization.
77
+ out_bias (`bool`, *optional*, defaults to `True`):
78
+ Set to `True` to use a bias in the output linear layer.
79
+ scale_qk (`bool`, *optional*, defaults to `True`):
80
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
81
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
82
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
83
+ `added_kv_proj_dim` is not `None`.
84
+ eps (`float`, *optional*, defaults to 1e-5):
85
+ An additional value added to the denominator in group normalization that is used for numerical stability.
86
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
87
+ A factor to rescale the output by dividing it with this value.
88
+ residual_connection (`bool`, *optional*, defaults to `False`):
89
+ Set to `True` to add the residual connection to the output.
90
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
91
+ Set to `True` if the attention block is loaded from a deprecated state dict.
92
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
93
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
94
+ `AttnProcessor` otherwise.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ query_dim: int,
100
+ cross_attention_dim: Optional[int] = None,
101
+ heads: int = 8,
102
+ dim_head: int = 64,
103
+ dropout: float = 0.0,
104
+ bias: bool = False,
105
+ upcast_attention: bool = False,
106
+ upcast_softmax: bool = False,
107
+ cross_attention_norm: Optional[str] = None,
108
+ cross_attention_norm_num_groups: int = 32,
109
+ added_kv_proj_dim: Optional[int] = None,
110
+ norm_num_groups: Optional[int] = None,
111
+ out_bias: bool = True,
112
+ scale_qk: bool = True,
113
+ only_cross_attention: bool = False,
114
+ eps: float = 1e-5,
115
+ rescale_output_factor: float = 1.0,
116
+ residual_connection: bool = False,
117
+ _from_deprecated_attn_block: bool = False,
118
+ processor: Optional["AttnProcessor"] = None,
119
+ out_dim: Optional[int] = None,
120
+ ):
121
+ super().__init__()
122
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
123
+ self.query_dim = query_dim
124
+ self.cross_attention_dim = (
125
+ cross_attention_dim if cross_attention_dim is not None else query_dim
126
+ )
127
+ self.upcast_attention = upcast_attention
128
+ self.upcast_softmax = upcast_softmax
129
+ self.rescale_output_factor = rescale_output_factor
130
+ self.residual_connection = residual_connection
131
+ self.dropout = dropout
132
+ self.fused_projections = False
133
+ self.out_dim = out_dim if out_dim is not None else query_dim
134
+
135
+ # we make use of this private variable to know whether this class is loaded
136
+ # with a deprecated state dict so that we can convert it on the fly
137
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
138
+
139
+ self.scale_qk = scale_qk
140
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
141
+
142
+ self.heads = out_dim // dim_head if out_dim is not None else heads
143
+ # for slice_size > 0 the attention score computation
144
+ # is split across the batch axis to save memory
145
+ # You can set slice_size with `set_attention_slice`
146
+ self.sliceable_head_dim = heads
147
+
148
+ self.added_kv_proj_dim = added_kv_proj_dim
149
+ self.only_cross_attention = only_cross_attention
150
+
151
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
152
+ raise ValueError(
153
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
154
+ )
155
+
156
+ if norm_num_groups is not None:
157
+ self.group_norm = nn.GroupNorm(
158
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.group_norm = None
162
+
163
+ self.spatial_norm = None
164
+
165
+ if cross_attention_norm is None:
166
+ self.norm_cross = None
167
+ elif cross_attention_norm == "layer_norm":
168
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
169
+ elif cross_attention_norm == "group_norm":
170
+ if self.added_kv_proj_dim is not None:
171
+ # The given `encoder_hidden_states` are initially of shape
172
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
173
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
174
+ # before the projection, so we need to use `added_kv_proj_dim` as
175
+ # the number of channels for the group norm.
176
+ norm_cross_num_channels = added_kv_proj_dim
177
+ else:
178
+ norm_cross_num_channels = self.cross_attention_dim
179
+
180
+ self.norm_cross = nn.GroupNorm(
181
+ num_channels=norm_cross_num_channels,
182
+ num_groups=cross_attention_norm_num_groups,
183
+ eps=1e-5,
184
+ affine=True,
185
+ )
186
+ else:
187
+ raise ValueError(
188
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
189
+ )
190
+
191
+ linear_cls = nn.Linear
192
+
193
+ self.linear_cls = linear_cls
194
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
195
+
196
+ if not self.only_cross_attention:
197
+ # only relevant for the `AddedKVProcessor` classes
198
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
199
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
200
+ else:
201
+ self.to_k = None
202
+ self.to_v = None
203
+
204
+ if self.added_kv_proj_dim is not None:
205
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
206
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
207
+
208
+ self.to_out = nn.ModuleList([])
209
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
210
+ self.to_out.append(nn.Dropout(dropout))
211
+
212
+ # set attention processor
213
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
214
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
215
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
216
+ if processor is None:
217
+ processor = (
218
+ AttnProcessor2_0()
219
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
220
+ else AttnProcessor()
221
+ )
222
+ self.set_processor(processor)
223
+
224
+ def set_processor(self, processor: "AttnProcessor") -> None:
225
+ self.processor = processor
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states: torch.FloatTensor,
230
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
231
+ attention_mask: Optional[torch.FloatTensor] = None,
232
+ **cross_attention_kwargs,
233
+ ) -> torch.Tensor:
234
+ r"""
235
+ The forward method of the `Attention` class.
236
+
237
+ Args:
238
+ hidden_states (`torch.Tensor`):
239
+ The hidden states of the query.
240
+ encoder_hidden_states (`torch.Tensor`, *optional*):
241
+ The hidden states of the encoder.
242
+ attention_mask (`torch.Tensor`, *optional*):
243
+ The attention mask to use. If `None`, no mask is applied.
244
+ **cross_attention_kwargs:
245
+ Additional keyword arguments to pass along to the cross attention.
246
+
247
+ Returns:
248
+ `torch.Tensor`: The output of the attention layer.
249
+ """
250
+ # The `Attention` class can call different attention processors / attention functions
251
+ # here we simply pass along all tensors to the selected processor class
252
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
253
+ return self.processor(
254
+ self,
255
+ hidden_states,
256
+ encoder_hidden_states=encoder_hidden_states,
257
+ attention_mask=attention_mask,
258
+ **cross_attention_kwargs,
259
+ )
260
+
261
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
262
+ r"""
263
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
264
+ is the number of heads initialized while constructing the `Attention` class.
265
+
266
+ Args:
267
+ tensor (`torch.Tensor`): The tensor to reshape.
268
+
269
+ Returns:
270
+ `torch.Tensor`: The reshaped tensor.
271
+ """
272
+ head_size = self.heads
273
+ batch_size, seq_len, dim = tensor.shape
274
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
275
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
276
+ batch_size // head_size, seq_len, dim * head_size
277
+ )
278
+ return tensor
279
+
280
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
281
+ r"""
282
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
283
+ the number of heads initialized while constructing the `Attention` class.
284
+
285
+ Args:
286
+ tensor (`torch.Tensor`): The tensor to reshape.
287
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
288
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
289
+
290
+ Returns:
291
+ `torch.Tensor`: The reshaped tensor.
292
+ """
293
+ head_size = self.heads
294
+ batch_size, seq_len, dim = tensor.shape
295
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
296
+ tensor = tensor.permute(0, 2, 1, 3)
297
+
298
+ if out_dim == 3:
299
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
300
+
301
+ return tensor
302
+
303
+ def get_attention_scores(
304
+ self,
305
+ query: torch.Tensor,
306
+ key: torch.Tensor,
307
+ attention_mask: torch.Tensor = None,
308
+ ) -> torch.Tensor:
309
+ r"""
310
+ Compute the attention scores.
311
+
312
+ Args:
313
+ query (`torch.Tensor`): The query tensor.
314
+ key (`torch.Tensor`): The key tensor.
315
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
316
+
317
+ Returns:
318
+ `torch.Tensor`: The attention probabilities/scores.
319
+ """
320
+ dtype = query.dtype
321
+ if self.upcast_attention:
322
+ query = query.float()
323
+ key = key.float()
324
+
325
+ if attention_mask is None:
326
+ baddbmm_input = torch.empty(
327
+ query.shape[0],
328
+ query.shape[1],
329
+ key.shape[1],
330
+ dtype=query.dtype,
331
+ device=query.device,
332
+ )
333
+ beta = 0
334
+ else:
335
+ baddbmm_input = attention_mask
336
+ beta = 1
337
+
338
+ attention_scores = torch.baddbmm(
339
+ baddbmm_input,
340
+ query,
341
+ key.transpose(-1, -2),
342
+ beta=beta,
343
+ alpha=self.scale,
344
+ )
345
+ del baddbmm_input
346
+
347
+ if self.upcast_softmax:
348
+ attention_scores = attention_scores.float()
349
+
350
+ attention_probs = attention_scores.softmax(dim=-1)
351
+ del attention_scores
352
+
353
+ attention_probs = attention_probs.to(dtype)
354
+
355
+ return attention_probs
356
+
357
+ def prepare_attention_mask(
358
+ self,
359
+ attention_mask: torch.Tensor,
360
+ target_length: int,
361
+ batch_size: int,
362
+ out_dim: int = 3,
363
+ ) -> torch.Tensor:
364
+ r"""
365
+ Prepare the attention mask for the attention computation.
366
+
367
+ Args:
368
+ attention_mask (`torch.Tensor`):
369
+ The attention mask to prepare.
370
+ target_length (`int`):
371
+ The target length of the attention mask. This is the length of the attention mask after padding.
372
+ batch_size (`int`):
373
+ The batch size, which is used to repeat the attention mask.
374
+ out_dim (`int`, *optional*, defaults to `3`):
375
+ The output dimension of the attention mask. Can be either `3` or `4`.
376
+
377
+ Returns:
378
+ `torch.Tensor`: The prepared attention mask.
379
+ """
380
+ head_size = self.heads
381
+ if attention_mask is None:
382
+ return attention_mask
383
+
384
+ current_length: int = attention_mask.shape[-1]
385
+ if current_length != target_length:
386
+ if attention_mask.device.type == "mps":
387
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
388
+ # Instead, we can manually construct the padding tensor.
389
+ padding_shape = (
390
+ attention_mask.shape[0],
391
+ attention_mask.shape[1],
392
+ target_length,
393
+ )
394
+ padding = torch.zeros(
395
+ padding_shape,
396
+ dtype=attention_mask.dtype,
397
+ device=attention_mask.device,
398
+ )
399
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
400
+ else:
401
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
402
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
403
+ # remaining_length: int = target_length - current_length
404
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
405
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
406
+
407
+ if out_dim == 3:
408
+ if attention_mask.shape[0] < batch_size * head_size:
409
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
410
+ elif out_dim == 4:
411
+ attention_mask = attention_mask.unsqueeze(1)
412
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
413
+
414
+ return attention_mask
415
+
416
+ def norm_encoder_hidden_states(
417
+ self, encoder_hidden_states: torch.Tensor
418
+ ) -> torch.Tensor:
419
+ r"""
420
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
421
+ `Attention` class.
422
+
423
+ Args:
424
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
425
+
426
+ Returns:
427
+ `torch.Tensor`: The normalized encoder hidden states.
428
+ """
429
+ assert (
430
+ self.norm_cross is not None
431
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
432
+
433
+ if isinstance(self.norm_cross, nn.LayerNorm):
434
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
435
+ elif isinstance(self.norm_cross, nn.GroupNorm):
436
+ # Group norm norms along the channels dimension and expects
437
+ # input to be in the shape of (N, C, *). In this case, we want
438
+ # to norm along the hidden dimension, so we need to move
439
+ # (batch_size, sequence_length, hidden_size) ->
440
+ # (batch_size, hidden_size, sequence_length)
441
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
442
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
443
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
444
+ else:
445
+ assert False
446
+
447
+ return encoder_hidden_states
448
+
449
+ @torch.no_grad()
450
+ def fuse_projections(self, fuse=True):
451
+ is_cross_attention = self.cross_attention_dim != self.query_dim
452
+ device = self.to_q.weight.data.device
453
+ dtype = self.to_q.weight.data.dtype
454
+
455
+ if not is_cross_attention:
456
+ # fetch weight matrices.
457
+ concatenated_weights = torch.cat(
458
+ [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
459
+ )
460
+ in_features = concatenated_weights.shape[1]
461
+ out_features = concatenated_weights.shape[0]
462
+
463
+ # create a new single projection layer and copy over the weights.
464
+ self.to_qkv = self.linear_cls(
465
+ in_features, out_features, bias=False, device=device, dtype=dtype
466
+ )
467
+ self.to_qkv.weight.copy_(concatenated_weights)
468
+
469
+ else:
470
+ concatenated_weights = torch.cat(
471
+ [self.to_k.weight.data, self.to_v.weight.data]
472
+ )
473
+ in_features = concatenated_weights.shape[1]
474
+ out_features = concatenated_weights.shape[0]
475
+
476
+ self.to_kv = self.linear_cls(
477
+ in_features, out_features, bias=False, device=device, dtype=dtype
478
+ )
479
+ self.to_kv.weight.copy_(concatenated_weights)
480
+
481
+ self.fused_projections = fuse
482
+
483
+
484
+ class AttnProcessor:
485
+ r"""
486
+ Default processor for performing attention-related computations.
487
+ """
488
+
489
+ def __call__(
490
+ self,
491
+ attn: Attention,
492
+ hidden_states: torch.FloatTensor,
493
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
494
+ attention_mask: Optional[torch.FloatTensor] = None,
495
+ ) -> torch.Tensor:
496
+ residual = hidden_states
497
+
498
+ input_ndim = hidden_states.ndim
499
+
500
+ if input_ndim == 4:
501
+ batch_size, channel, height, width = hidden_states.shape
502
+ hidden_states = hidden_states.view(
503
+ batch_size, channel, height * width
504
+ ).transpose(1, 2)
505
+
506
+ batch_size, sequence_length, _ = (
507
+ hidden_states.shape
508
+ if encoder_hidden_states is None
509
+ else encoder_hidden_states.shape
510
+ )
511
+ attention_mask = attn.prepare_attention_mask(
512
+ attention_mask, sequence_length, batch_size
513
+ )
514
+
515
+ if attn.group_norm is not None:
516
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
517
+ 1, 2
518
+ )
519
+
520
+ query = attn.to_q(hidden_states)
521
+
522
+ if encoder_hidden_states is None:
523
+ encoder_hidden_states = hidden_states
524
+ elif attn.norm_cross:
525
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
526
+ encoder_hidden_states
527
+ )
528
+
529
+ key = attn.to_k(encoder_hidden_states)
530
+ value = attn.to_v(encoder_hidden_states)
531
+
532
+ query = attn.head_to_batch_dim(query)
533
+ key = attn.head_to_batch_dim(key)
534
+ value = attn.head_to_batch_dim(value)
535
+
536
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
537
+ hidden_states = torch.bmm(attention_probs, value)
538
+ hidden_states = attn.batch_to_head_dim(hidden_states)
539
+
540
+ # linear proj
541
+ hidden_states = attn.to_out[0](hidden_states)
542
+ # dropout
543
+ hidden_states = attn.to_out[1](hidden_states)
544
+
545
+ if input_ndim == 4:
546
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
547
+ batch_size, channel, height, width
548
+ )
549
+
550
+ if attn.residual_connection:
551
+ hidden_states = hidden_states + residual
552
+
553
+ hidden_states = hidden_states / attn.rescale_output_factor
554
+
555
+ return hidden_states
556
+
557
+
558
+ class AttnProcessor2_0:
559
+ r"""
560
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
561
+ """
562
+
563
+ def __init__(self):
564
+ if not hasattr(F, "scaled_dot_product_attention"):
565
+ raise ImportError(
566
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
567
+ )
568
+
569
+ def __call__(
570
+ self,
571
+ attn: Attention,
572
+ hidden_states: torch.FloatTensor,
573
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
574
+ attention_mask: Optional[torch.FloatTensor] = None,
575
+ ) -> torch.FloatTensor:
576
+ residual = hidden_states
577
+
578
+ input_ndim = hidden_states.ndim
579
+
580
+ if input_ndim == 4:
581
+ batch_size, channel, height, width = hidden_states.shape
582
+ hidden_states = hidden_states.view(
583
+ batch_size, channel, height * width
584
+ ).transpose(1, 2)
585
+
586
+ batch_size, sequence_length, _ = (
587
+ hidden_states.shape
588
+ if encoder_hidden_states is None
589
+ else encoder_hidden_states.shape
590
+ )
591
+
592
+ if attention_mask is not None:
593
+ attention_mask = attn.prepare_attention_mask(
594
+ attention_mask, sequence_length, batch_size
595
+ )
596
+ # scaled_dot_product_attention expects attention_mask shape to be
597
+ # (batch, heads, source_length, target_length)
598
+ attention_mask = attention_mask.view(
599
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
600
+ )
601
+
602
+ if attn.group_norm is not None:
603
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
604
+ 1, 2
605
+ )
606
+
607
+ query = attn.to_q(hidden_states)
608
+
609
+ if encoder_hidden_states is None:
610
+ encoder_hidden_states = hidden_states
611
+ elif attn.norm_cross:
612
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
613
+ encoder_hidden_states
614
+ )
615
+
616
+ key = attn.to_k(encoder_hidden_states)
617
+ value = attn.to_v(encoder_hidden_states)
618
+
619
+ inner_dim = key.shape[-1]
620
+ head_dim = inner_dim // attn.heads
621
+
622
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
623
+
624
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
625
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
626
+
627
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
628
+ # TODO: add support for attn.scale when we move to Torch 2.1
629
+ hidden_states = F.scaled_dot_product_attention(
630
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
631
+ )
632
+
633
+ hidden_states = hidden_states.transpose(1, 2).reshape(
634
+ batch_size, -1, attn.heads * head_dim
635
+ )
636
+ hidden_states = hidden_states.to(query.dtype)
637
+
638
+ # linear proj
639
+ hidden_states = attn.to_out[0](hidden_states)
640
+ # dropout
641
+ hidden_states = attn.to_out[1](hidden_states)
642
+
643
+ if input_ndim == 4:
644
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
645
+ batch_size, channel, height, width
646
+ )
647
+
648
+ if attn.residual_connection:
649
+ hidden_states = hidden_states + residual
650
+
651
+ hidden_states = hidden_states / attn.rescale_output_factor
652
+
653
+ return hidden_states
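A self-contained sketch of the Attention module above in both self- and cross-attention modes; all sizes are illustrative. With the default cross_attention_dim=None, the context must share the query dimension:

import torch
from tsr.models.transformer.attention import Attention

attn = Attention(query_dim=512, heads=8, dim_head=64)
x = torch.randn(2, 1024, 512)    # (batch, seq_len, query_dim)
ctx = torch.randn(2, 77, 512)    # cross-attention context
y_self = attn(x)                                # (2, 1024, 512)
y_cross = attn(x, encoder_hidden_states=ctx)    # (2, 1024, 512)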
tsr/models/transformer/basic_transformer_block.py ADDED
@@ -0,0 +1,334 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # --------
16
+ #
17
+ # Modified 2024 by the Tripo AI and Stability AI Team.
18
+ #
19
+ # Copyright (c) 2024 Tripo AI & Stability AI
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+
39
+ from typing import Optional
40
+
41
+ import torch
42
+ import torch.nn.functional as F
43
+ from torch import nn
44
+
45
+ from .attention import Attention
46
+
47
+
48
+ class BasicTransformerBlock(nn.Module):
49
+ r"""
50
+ A basic Transformer block.
51
+
52
+ Parameters:
53
+ dim (`int`): The number of channels in the input and output.
54
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
55
+ attention_head_dim (`int`): The number of channels in each head.
56
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
57
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
58
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
59
+ attention_bias (`bool`, *optional*, defaults to `False`):
60
+ Configure if the attentions should contain a bias parameter.
61
+ only_cross_attention (`bool`, *optional*):
62
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
63
+ double_self_attention (`bool`, *optional*):
64
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
65
+ upcast_attention (`bool`, *optional*):
66
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
67
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
68
+ Whether to use learnable elementwise affine parameters for normalization.
69
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
70
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
71
+ final_dropout (`bool` *optional*, defaults to False):
72
+ Whether to apply a final dropout after the last feed-forward layer.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ dim: int,
78
+ num_attention_heads: int,
79
+ attention_head_dim: int,
80
+ dropout=0.0,
81
+ cross_attention_dim: Optional[int] = None,
82
+ activation_fn: str = "geglu",
83
+ attention_bias: bool = False,
84
+ only_cross_attention: bool = False,
85
+ double_self_attention: bool = False,
86
+ upcast_attention: bool = False,
87
+ norm_elementwise_affine: bool = True,
88
+ norm_type: str = "layer_norm",
89
+ final_dropout: bool = False,
90
+ ):
91
+ super().__init__()
92
+ self.only_cross_attention = only_cross_attention
93
+
94
+ assert norm_type == "layer_norm"
95
+
96
+ # Define 3 blocks. Each block has its own normalization layer.
97
+ # 1. Self-Attn
98
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
99
+ self.attn1 = Attention(
100
+ query_dim=dim,
101
+ heads=num_attention_heads,
102
+ dim_head=attention_head_dim,
103
+ dropout=dropout,
104
+ bias=attention_bias,
105
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
106
+ upcast_attention=upcast_attention,
107
+ )
108
+
109
+ # 2. Cross-Attn
110
+ if cross_attention_dim is not None or double_self_attention:
111
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
112
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
113
+ # the second cross attention block.
114
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
115
+
116
+ self.attn2 = Attention(
117
+ query_dim=dim,
118
+ cross_attention_dim=(
119
+ cross_attention_dim if not double_self_attention else None
120
+ ),
121
+ heads=num_attention_heads,
122
+ dim_head=attention_head_dim,
123
+ dropout=dropout,
124
+ bias=attention_bias,
125
+ upcast_attention=upcast_attention,
126
+ ) # is self-attn if encoder_hidden_states is none
127
+ else:
128
+ self.norm2 = None
129
+ self.attn2 = None
130
+
131
+ # 3. Feed-forward
132
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
133
+ self.ff = FeedForward(
134
+ dim,
135
+ dropout=dropout,
136
+ activation_fn=activation_fn,
137
+ final_dropout=final_dropout,
138
+ )
139
+
140
+ # let chunk size default to None
141
+ self._chunk_size = None
142
+ self._chunk_dim = 0
143
+
144
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
145
+ # Sets chunk feed-forward
146
+ self._chunk_size = chunk_size
147
+ self._chunk_dim = dim
148
+
149
+ def forward(
150
+ self,
151
+ hidden_states: torch.FloatTensor,
152
+ attention_mask: Optional[torch.FloatTensor] = None,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
155
+ ) -> torch.FloatTensor:
156
+ # Notice that normalization is always applied before the real computation in the following blocks.
157
+ # 0. Self-Attention
158
+ norm_hidden_states = self.norm1(hidden_states)
159
+
160
+ attn_output = self.attn1(
161
+ norm_hidden_states,
162
+ encoder_hidden_states=(
163
+ encoder_hidden_states if self.only_cross_attention else None
164
+ ),
165
+ attention_mask=attention_mask,
166
+ )
167
+
168
+ hidden_states = attn_output + hidden_states
169
+
170
+ # 3. Cross-Attention
171
+ if self.attn2 is not None:
172
+ norm_hidden_states = self.norm2(hidden_states)
173
+
174
+ attn_output = self.attn2(
175
+ norm_hidden_states,
176
+ encoder_hidden_states=encoder_hidden_states,
177
+ attention_mask=encoder_attention_mask,
178
+ )
179
+ hidden_states = attn_output + hidden_states
180
+
181
+ # 4. Feed-forward
182
+ norm_hidden_states = self.norm3(hidden_states)
183
+
184
+ if self._chunk_size is not None:
185
+ # "feed_forward_chunk_size" can be used to save memory
186
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
187
+ raise ValueError(
188
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
189
+ )
190
+
191
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
192
+ ff_output = torch.cat(
193
+ [
194
+ self.ff(hid_slice)
195
+ for hid_slice in norm_hidden_states.chunk(
196
+ num_chunks, dim=self._chunk_dim
197
+ )
198
+ ],
199
+ dim=self._chunk_dim,
200
+ )
201
+ else:
202
+ ff_output = self.ff(norm_hidden_states)
203
+
204
+ hidden_states = ff_output + hidden_states
205
+
206
+ return hidden_states
207
+
208
+
209
+ class FeedForward(nn.Module):
210
+ r"""
211
+ A feed-forward layer.
212
+
213
+ Parameters:
214
+ dim (`int`): The number of channels in the input.
215
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
216
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
217
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
218
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
219
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ dim: int,
225
+ dim_out: Optional[int] = None,
226
+ mult: int = 4,
227
+ dropout: float = 0.0,
228
+ activation_fn: str = "geglu",
229
+ final_dropout: bool = False,
230
+ ):
231
+ super().__init__()
232
+ inner_dim = int(dim * mult)
233
+ dim_out = dim_out if dim_out is not None else dim
234
+ linear_cls = nn.Linear
235
+
236
+ if activation_fn == "gelu":
237
+ act_fn = GELU(dim, inner_dim)
238
+ if activation_fn == "gelu-approximate":
239
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
240
+ elif activation_fn == "geglu":
241
+ act_fn = GEGLU(dim, inner_dim)
242
+ elif activation_fn == "geglu-approximate":
243
+ act_fn = ApproximateGELU(dim, inner_dim)
244
+
245
+ self.net = nn.ModuleList([])
246
+ # project in
247
+ self.net.append(act_fn)
248
+ # project dropout
249
+ self.net.append(nn.Dropout(dropout))
250
+ # project out
251
+ self.net.append(linear_cls(inner_dim, dim_out))
252
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
253
+ if final_dropout:
254
+ self.net.append(nn.Dropout(dropout))
255
+
256
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
257
+ for module in self.net:
258
+ hidden_states = module(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class GELU(nn.Module):
263
+ r"""
264
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
265
+
266
+ Parameters:
267
+ dim_in (`int`): The number of channels in the input.
268
+ dim_out (`int`): The number of channels in the output.
269
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
270
+ """
271
+
272
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
273
+ super().__init__()
274
+ self.proj = nn.Linear(dim_in, dim_out)
275
+ self.approximate = approximate
276
+
277
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
278
+ if gate.device.type != "mps":
279
+ return F.gelu(gate, approximate=self.approximate)
280
+ # mps: gelu is not implemented for float16
281
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
282
+ dtype=gate.dtype
283
+ )
284
+
285
+ def forward(self, hidden_states):
286
+ hidden_states = self.proj(hidden_states)
287
+ hidden_states = self.gelu(hidden_states)
288
+ return hidden_states
289
+
290
+
291
+ class GEGLU(nn.Module):
292
+ r"""
293
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
294
+
295
+ Parameters:
296
+ dim_in (`int`): The number of channels in the input.
297
+ dim_out (`int`): The number of channels in the output.
298
+ """
299
+
300
+ def __init__(self, dim_in: int, dim_out: int):
301
+ super().__init__()
302
+ linear_cls = nn.Linear
303
+
304
+ self.proj = linear_cls(dim_in, dim_out * 2)
305
+
306
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
307
+ if gate.device.type != "mps":
308
+ return F.gelu(gate)
309
+ # mps: gelu is not implemented for float16
310
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
311
+
312
+ def forward(self, hidden_states, scale: float = 1.0):
313
+ args = ()
314
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
315
+ return hidden_states * self.gelu(gate)
316
+
317
+
318
+ class ApproximateGELU(nn.Module):
319
+ r"""
320
+ The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
321
+ https://arxiv.org/abs/1606.08415.
322
+
323
+ Parameters:
324
+ dim_in (`int`): The number of channels in the input.
325
+ dim_out (`int`): The number of channels in the output.
326
+ """
327
+
328
+ def __init__(self, dim_in: int, dim_out: int):
329
+ super().__init__()
330
+ self.proj = nn.Linear(dim_in, dim_out)
331
+
332
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
333
+ x = self.proj(x)
334
+ return x * torch.sigmoid(1.702 * x)
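A sketch wiring BasicTransformerBlock with cross-attention conditioning (e.g. on image tokens, as in the TSR backbone); all sizes are illustrative:

import torch
from tsr.models.transformer.basic_transformer_block import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=512, num_attention_heads=8, attention_head_dim=64,
    cross_attention_dim=768,
)
x = torch.randn(2, 1024, 512)     # e.g. triplane tokens
ctx = torch.randn(2, 257, 768)    # e.g. DINO image tokens
y = block(x, encoder_hidden_states=ctx)    # (2, 1024, 512)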
tsr/models/transformer/transformer_1d.py ADDED
@@ -0,0 +1,219 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# --------
+#
+# Modified 2024 by the Tripo AI and Stability AI Team.
+#
+# Copyright (c) 2024 Tripo AI & Stability AI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ...utils import BaseModule
+from .basic_transformer_block import BasicTransformerBlock
+
+
+class Transformer1D(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        num_attention_heads: int = 16
+        attention_head_dim: int = 88
+        in_channels: Optional[int] = None
+        out_channels: Optional[int] = None
+        num_layers: int = 1
+        dropout: float = 0.0
+        norm_num_groups: int = 32
+        cross_attention_dim: Optional[int] = None
+        attention_bias: bool = False
+        activation_fn: str = "geglu"
+        only_cross_attention: bool = False
+        double_self_attention: bool = False
+        upcast_attention: bool = False
+        norm_type: str = "layer_norm"
+        norm_elementwise_affine: bool = True
+        gradient_checkpointing: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.num_attention_heads = self.cfg.num_attention_heads
+        self.attention_head_dim = self.cfg.attention_head_dim
+        inner_dim = self.num_attention_heads * self.attention_head_dim
+
+        linear_cls = nn.Linear
+
+        # 1. Define input layers
+        self.in_channels = self.cfg.in_channels
+
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.cfg.norm_num_groups,
+            num_channels=self.cfg.in_channels,
+            eps=1e-6,
+            affine=True,
+        )
+        self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
+
+        # 2. Define transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    self.num_attention_heads,
+                    self.attention_head_dim,
+                    dropout=self.cfg.dropout,
+                    cross_attention_dim=self.cfg.cross_attention_dim,
+                    activation_fn=self.cfg.activation_fn,
+                    attention_bias=self.cfg.attention_bias,
+                    only_cross_attention=self.cfg.only_cross_attention,
+                    double_self_attention=self.cfg.double_self_attention,
+                    upcast_attention=self.cfg.upcast_attention,
+                    norm_type=self.cfg.norm_type,
+                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
+                )
+                for _ in range(self.cfg.num_layers)
+            ]
+        )
+
+        # 3. Define output layers
+        self.out_channels = (
+            self.cfg.in_channels
+            if self.cfg.out_channels is None
+            else self.cfg.out_channels
+        )
+        # the residual connection in forward() requires out_channels == in_channels
+        self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
+
+        self.gradient_checkpointing = self.cfg.gradient_checkpointing
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        The [`Transformer1D`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, sequence length)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
+                self-attention.
+            attention_mask (`torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` applied to `hidden_states`. If `1` the mask
+                is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
+                negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+
+        Returns:
+            `torch.FloatTensor`
+        """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (
+                1 - encoder_attention_mask.to(hidden_states.dtype)
+            ) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch, _, seq_len = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 1).reshape(
+            batch, seq_len, inner_dim
+        )
+        hidden_states = self.proj_in(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+
+        # 3. Output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = (
+            hidden_states.reshape(batch, seq_len, inner_dim)
+            .permute(0, 2, 1)
+            .contiguous()
+        )
+
+        output = hidden_states + residual
+
+        return output
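
The only subtle step in `forward` above is the mask handling: a 0/1 keep mask is turned into an additive bias before it reaches the attention blocks. A standalone sketch of that conversion, mirroring the two lines in `forward` (values here are illustrative):

    import torch

    mask = torch.tensor([[1, 1, 0]], dtype=torch.float32)  # 1 = keep, 0 = discard
    bias = (1 - mask) * -10000.0  # keep -> 0.0, discard -> -10000.0
    bias = bias.unsqueeze(1)      # (batch, 1, key_tokens), broadcasts over query tokens
    assert bias.shape == (1, 1, 3)
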
tsr/system.py ADDED
@@ -0,0 +1,205 @@
+import os
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import trimesh
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+from PIL import Image
+
+from .models.isosurface import MarchingCubeHelper
+from .utils import (
+    BaseModule,
+    ImagePreprocessor,
+    find_class,
+    get_spherical_cameras,
+    scale_tensor,
+)
+
+
+class TSR(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        cond_image_size: int
+
+        image_tokenizer_cls: str
+        image_tokenizer: dict
+
+        tokenizer_cls: str
+        tokenizer: dict
+
+        backbone_cls: str
+        backbone: dict
+
+        post_processor_cls: str
+        post_processor: dict
+
+        decoder_cls: str
+        decoder: dict
+
+        renderer_cls: str
+        renderer: dict
+
+    cfg: Config
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
+    ):
+        # accept either a local directory or a Hugging Face Hub repo id
+        if os.path.isdir(pretrained_model_name_or_path):
+            config_path = os.path.join(pretrained_model_name_or_path, config_name)
+            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
+        else:
+            config_path = hf_hub_download(
+                repo_id=pretrained_model_name_or_path, filename=config_name
+            )
+            weight_path = hf_hub_download(
+                repo_id=pretrained_model_name_or_path, filename=weight_name
+            )
+
+        cfg = OmegaConf.load(config_path)
+        OmegaConf.resolve(cfg)
+        model = cls(cfg)
+        ckpt = torch.load(weight_path, map_location="cpu")
+        model.load_state_dict(ckpt)
+        return model
+
+    def configure(self):
+        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
+            self.cfg.image_tokenizer
+        )
+        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
+        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
+        self.post_processor = find_class(self.cfg.post_processor_cls)(
+            self.cfg.post_processor
+        )
+        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
+        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
+        self.image_processor = ImagePreprocessor()
+        self.isosurface_helper = None
+
+    def forward(
+        self,
+        image: Union[
+            PIL.Image.Image,
+            np.ndarray,
+            torch.FloatTensor,
+            List[PIL.Image.Image],
+            List[np.ndarray],
+            List[torch.FloatTensor],
+        ],
+        device: str,
+    ) -> torch.FloatTensor:
+        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
+            device
+        )
+        batch_size = rgb_cond.shape[0]
+
+        input_image_tokens: torch.Tensor = self.image_tokenizer(
+            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
+        )
+
+        input_image_tokens = rearrange(
+            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
+        )
+
+        tokens: torch.Tensor = self.tokenizer(batch_size)
+
+        tokens = self.backbone(
+            tokens,
+            encoder_hidden_states=input_image_tokens,
+        )
+
+        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
+        return scene_codes
+
+    def render(
+        self,
+        scene_codes,
+        n_views: int,
+        elevation_deg: float = 0.0,
+        camera_distance: float = 1.9,
+        fovy_deg: float = 40.0,
+        height: int = 256,
+        width: int = 256,
+        return_type: str = "pil",
+    ):
+        rays_o, rays_d = get_spherical_cameras(
+            n_views, elevation_deg, camera_distance, fovy_deg, height, width
+        )
+        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
+
+        def process_output(image: torch.FloatTensor):
+            if return_type == "pt":
+                return image
+            elif return_type == "np":
+                return image.detach().cpu().numpy()
+            elif return_type == "pil":
+                return Image.fromarray(
+                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
+                )
+            else:
+                raise NotImplementedError
+
+        images = []
+        for scene_code in scene_codes:
+            images_ = []
+            for i in range(n_views):
+                with torch.no_grad():
+                    image = self.renderer(
+                        self.decoder, scene_code, rays_o[i], rays_d[i]
+                    )
+                images_.append(process_output(image))
+            images.append(images_)
+
+        return images
+
+    def set_marching_cubes_resolution(self, resolution: int):
+        if (
+            self.isosurface_helper is not None
+            and self.isosurface_helper.resolution == resolution
+        ):
+            return
+        self.isosurface_helper = MarchingCubeHelper(resolution)
+
+    def extract_mesh(self, scene_codes, has_vertex_color, resolution: int = 256, threshold: float = 25.0):
+        self.set_marching_cubes_resolution(resolution)
+        meshes = []
+        for scene_code in scene_codes:
+            with torch.no_grad():
+                density = self.renderer.query_triplane(
+                    self.decoder,
+                    scale_tensor(
+                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
+                        self.isosurface_helper.points_range,
+                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
+                    ),
+                    scene_code,
+                )["density_act"]
+            # negate (density - threshold) so the zero level set sits at density == threshold
+            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
+            v_pos = scale_tensor(
+                v_pos,
+                self.isosurface_helper.points_range,
+                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
+            )
+            color = None
+            if has_vertex_color:
+                with torch.no_grad():
+                    color = self.renderer.query_triplane(
+                        self.decoder,
+                        v_pos,
+                        scene_code,
+                    )["color"]
+            mesh = trimesh.Trimesh(
+                vertices=v_pos.cpu().numpy(),
+                faces=t_pos_idx.cpu().numpy(),
+                vertex_colors=color.cpu().numpy() if has_vertex_color else None,
+            )
+            meshes.append(mesh)
+        return meshes
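
Typical end-to-end usage of `TSR` looks like the sketch below. This is a minimal sketch, not part of the upload: it assumes the upstream `stabilityai/TripoSR` checkpoint layout (`config.yaml` plus `model.ckpt`), and `input.png`/`output.obj` are illustrative paths:

    import torch
    from PIL import Image

    from tsr.system import TSR

    model = TSR.from_pretrained(
        "stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    image = Image.open("input.png").convert("RGB")
    with torch.no_grad():
        scene_codes = model(image, device=device)  # forward() -> scene codes
    meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=256)
    meshes[0].export("output.obj")                 # trimesh handles the export
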
tsr/utils.py ADDED
@@ -0,0 +1,510 @@
+import importlib
+import math
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import imageio
+import numpy as np
+import PIL.Image
+import rembg
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import trimesh
+from omegaconf import DictConfig, OmegaConf
+from PIL import Image
+
+
+def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
+    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
+    return scfg
+
+
+def find_class(cls_string):
+    module_string = ".".join(cls_string.split(".")[:-1])
+    cls_name = cls_string.split(".")[-1]
+    module = importlib.import_module(module_string, package=None)
+    cls = getattr(module, cls_name)
+    return cls
+
+
+def get_intrinsic_from_fov(fov, H, W, bs=-1):
+    focal_length = 0.5 * H / np.tan(0.5 * fov)
+    intrinsic = np.identity(3, dtype=np.float32)
+    intrinsic[0, 0] = focal_length
+    intrinsic[1, 1] = focal_length
+    intrinsic[0, 2] = W / 2.0
+    intrinsic[1, 2] = H / 2.0
+
+    if bs > 0:
+        intrinsic = intrinsic[None].repeat(bs, axis=0)
+
+    return torch.from_numpy(intrinsic)
+
+
+class BaseModule(nn.Module):
+    @dataclass
+    class Config:
+        pass
+
+    cfg: Config  # add this to every subclass of BaseModule to enable static type checking
+
+    def __init__(
+        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
+    ) -> None:
+        super().__init__()
+        self.cfg = parse_structured(self.Config, cfg)
+        self.configure(*args, **kwargs)
+
+    def configure(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+
+class ImagePreprocessor:
+    def convert_and_resize(
+        self,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+        size: int,
+    ):
+        if isinstance(image, PIL.Image.Image):
+            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
+        elif isinstance(image, np.ndarray):
+            if image.dtype == np.uint8:
+                image = torch.from_numpy(image.astype(np.float32) / 255.0)
+            else:
+                image = torch.from_numpy(image)
+        elif isinstance(image, torch.Tensor):
+            pass
+
+        batched = image.ndim == 4
+
+        if not batched:
+            image = image[None, ...]
+        image = F.interpolate(
+            image.permute(0, 3, 1, 2),
+            (size, size),
+            mode="bilinear",
+            align_corners=False,
+            antialias=True,
+        ).permute(0, 2, 3, 1)
+        if not batched:
+            image = image[0]
+        return image
+
+    def __call__(
+        self,
+        image: Union[
+            PIL.Image.Image,
+            np.ndarray,
+            torch.FloatTensor,
+            List[PIL.Image.Image],
+            List[np.ndarray],
+            List[torch.FloatTensor],
+        ],
+        size: int,
+    ) -> Any:
+        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
+            image = self.convert_and_resize(image, size)
+        else:
+            if not isinstance(image, list):
+                image = [image]
+            image = [self.convert_and_resize(im, size) for im in image]
+            image = torch.stack(image, dim=0)
+        return image
+
+
+def rays_intersect_bbox(
+    rays_o: torch.Tensor,
+    rays_d: torch.Tensor,
+    radius: float,
+    near: float = 0.0,
+    valid_thresh: float = 0.01,
+):
+    input_shape = rays_o.shape[:-1]
+    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
+    rays_d_valid = torch.where(
+        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
+    )
+    if type(radius) in [int, float]:
+        radius = torch.FloatTensor(
+            [[-radius, radius], [-radius, radius], [-radius, radius]]
+        ).to(rays_o.device)
+    radius = (
+        1.0 - 1.0e-3
+    ) * radius  # tighten the radius to make sure the intersection point lies in the bounding box
+    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
+    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
+    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
+    t_far = torch.maximum(interx0, interx1).amin(dim=-1)
+
+    # check whether a ray intersects the bbox or not
+    rays_valid = t_far - t_near > valid_thresh
+
+    t_near[torch.where(~rays_valid)] = 0.0
+    t_far[torch.where(~rays_valid)] = 0.0
+
+    t_near = t_near.view(*input_shape, 1)
+    t_far = t_far.view(*input_shape, 1)
+    rays_valid = rays_valid.view(*input_shape)
+
+    return t_near, t_far, rays_valid
+
+
+def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
+    if chunk_size <= 0:
+        return func(*args, **kwargs)
+    B = None
+    for arg in list(args) + list(kwargs.values()):
+        if isinstance(arg, torch.Tensor):
+            B = arg.shape[0]
+            break
+    assert (
+        B is not None
+    ), "No tensor found in args or kwargs, cannot determine batch size."
+    out = defaultdict(list)
+    out_type = None
+    # max(1, B) to support B == 0
+    for i in range(0, max(1, B), chunk_size):
+        out_chunk = func(
+            *[
+                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+                for arg in args
+            ],
+            **{
+                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+                for k, arg in kwargs.items()
+            },
+        )
+        if out_chunk is None:
+            continue
+        out_type = type(out_chunk)
+        if isinstance(out_chunk, torch.Tensor):
+            out_chunk = {0: out_chunk}
+        elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
+            chunk_length = len(out_chunk)
+            out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
+        elif isinstance(out_chunk, dict):
+            pass
+        else:
+            raise TypeError(
+                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], got {type(out_chunk)}."
+            )
+        for k, v in out_chunk.items():
+            v = v if torch.is_grad_enabled() else v.detach()
+            out[k].append(v)
+
+    if out_type is None:
+        return None
+
+    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
+    for k, v in out.items():
+        if all([vv is None for vv in v]):
+            # allow None in return value
+            out_merged[k] = None
+        elif all([isinstance(vv, torch.Tensor) for vv in v]):
+            out_merged[k] = torch.cat(v, dim=0)
+        else:
+            raise TypeError(
+                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
+            )
+
+    if out_type is torch.Tensor:
+        return out_merged[0]
+    elif out_type in [tuple, list]:
+        return out_type([out_merged[i] for i in range(chunk_length)])
+    elif out_type is dict:
+        return out_merged
+
+
+ValidScale = Union[Tuple[float, float], torch.FloatTensor]
+
+
+def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
+    if inp_scale is None:
+        inp_scale = (0, 1)
+    if tgt_scale is None:
+        tgt_scale = (0, 1)
+    if isinstance(tgt_scale, torch.FloatTensor):
+        assert dat.shape[-1] == tgt_scale.shape[-1]
+    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
+    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
+    return dat
+
+
+def get_activation(name) -> Callable:
+    if name is None:
+        return lambda x: x
+    name = name.lower()
+    if name == "none":
+        return lambda x: x
+    elif name == "exp":
+        return lambda x: torch.exp(x)
+    elif name == "sigmoid":
+        return lambda x: torch.sigmoid(x)
+    elif name == "tanh":
+        return lambda x: torch.tanh(x)
+    elif name == "softplus":
+        return lambda x: F.softplus(x)
+    else:
+        try:
+            return getattr(F, name)
+        except AttributeError:
+            raise ValueError(f"Unknown activation function: {name}")
+
+
+def get_ray_directions(
+    H: int,
+    W: int,
+    focal: Union[float, Tuple[float, float]],
+    principal: Optional[Tuple[float, float]] = None,
+    use_pixel_centers: bool = True,
+    normalize: bool = True,
+) -> torch.FloatTensor:
+    """
+    Get ray directions for all pixels in camera coordinates.
+    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+               ray-tracing-generating-camera-rays/standard-coordinate-systems
+
+    Inputs:
+        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point, and whether to use pixel centers
+    Outputs:
+        directions: (H, W, 3), the direction of the rays in camera coordinates
+    """
+    pixel_center = 0.5 if use_pixel_centers else 0
+
+    if isinstance(focal, float):
+        fx, fy = focal, focal
+        cx, cy = W / 2, H / 2
+    else:
+        fx, fy = focal
+        assert principal is not None
+        cx, cy = principal
+
+    i, j = torch.meshgrid(
+        torch.arange(W, dtype=torch.float32) + pixel_center,
+        torch.arange(H, dtype=torch.float32) + pixel_center,
+        indexing="xy",
+    )
+
+    directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
+
+    if normalize:
+        directions = F.normalize(directions, dim=-1)
+
+    return directions
+
+
+def get_rays(
+    directions,
+    c2w,
+    keepdim=False,
+    normalize=False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    # Rotate ray directions from camera coordinates to world coordinates
+    assert directions.shape[-1] == 3
+
+    if directions.ndim == 2:  # (N_rays, 3)
+        if c2w.ndim == 2:  # (4, 4)
+            c2w = c2w[None, :, :]
+        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
+        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
+        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
+    elif directions.ndim == 3:  # (H, W, 3)
+        assert c2w.ndim in [2, 3]
+        if c2w.ndim == 2:  # (4, 4)
+            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
+                -1
+            )  # (H, W, 3)
+            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
+        elif c2w.ndim == 3:  # (B, 4, 4)
+            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+                -1
+            )  # (B, H, W, 3)
+            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+    elif directions.ndim == 4:  # (B, H, W, 3)
+        assert c2w.ndim == 3  # (B, 4, 4)
+        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+            -1
+        )  # (B, H, W, 3)
+        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+
+    if normalize:
+        rays_d = F.normalize(rays_d, dim=-1)
+    if not keepdim:
+        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
+
+    return rays_o, rays_d
+
+
+def get_spherical_cameras(
+    n_views: int,
+    elevation_deg: float,
+    camera_distance: float,
+    fovy_deg: float,
+    height: int,
+    width: int,
+):
+    # Use 0 to 360*(n_views-1)/n_views so the first and last positions do not
+    # coincide, while still covering the full 360 degrees
+    azimuth_deg = torch.linspace(0, 360.0 * (n_views - 1) / n_views, n_views)
+    elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
+    camera_distances = torch.full_like(elevation_deg, camera_distance)
+
+    elevation = elevation_deg * math.pi / 180
+    azimuth = azimuth_deg * math.pi / 180
+
+    # convert spherical coordinates to cartesian coordinates
+    # right hand coordinate system, x back, y right, z up
+    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
+    camera_positions = torch.stack(
+        [
+            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
+            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
+            camera_distances * torch.sin(elevation),
+        ],
+        dim=-1,
+    )
+
+    # default scene center at origin
+    center = torch.zeros_like(camera_positions)
+    # default camera up direction as +z
+    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
+
+    fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
+
+    lookat = F.normalize(center - camera_positions, dim=-1)
+    # pass dim=-1 explicitly: torch.cross would otherwise pick the first size-3 dim
+    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
+    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
+    c2w3x4 = torch.cat(
+        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
+        dim=-1,
+    )
+    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
+    c2w[:, 3, 3] = 1.0
+
+    # get directions by dividing directions_unit_focal by focal length
+    focal_length = 0.5 * height / torch.tan(0.5 * fovy)
+    directions_unit_focal = get_ray_directions(
+        H=height,
+        W=width,
+        focal=1.0,
+    )
+    directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
+    directions[:, :, :, :2] = (
+        directions[:, :, :, :2] / focal_length[:, None, None, None]
+    )
+    # must use normalize=True to normalize directions here
+    rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
+
+    return rays_o, rays_d
+
+
+def remove_background(
+    image: PIL.Image.Image,
+    rembg_session: Any = None,
+    force: bool = False,
+    **rembg_kwargs,
+) -> PIL.Image.Image:
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        # the image already has an alpha channel with transparency; skip removal
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image
+
+
+def resize_foreground(
+    image: PIL.Image.Image,
+    ratio: float,
+) -> PIL.Image.Image:
+    image = np.array(image)
+    assert image.shape[-1] == 4
+    alpha = np.where(image[..., 3] > 0)
+    y1, y2, x1, x2 = (
+        alpha[0].min(),
+        alpha[0].max(),
+        alpha[1].min(),
+        alpha[1].max(),
+    )
+    # crop the foreground (bounds are inclusive, hence the +1)
+    fg = image[y1 : y2 + 1, x1 : x2 + 1]
+    # pad to square
+    size = max(fg.shape[0], fg.shape[1])
+    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+    new_image = np.pad(
+        fg,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+
+    # compute padding according to the ratio
+    new_size = int(new_image.shape[0] / ratio)
+    # pad to new_size, evenly on both sides
+    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+    new_image = np.pad(
+        new_image,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+    new_image = PIL.Image.fromarray(new_image)
+    return new_image
+
+
+def save_video(
+    frames: List[PIL.Image.Image],
+    output_path: str,
+    fps: int = 30,
+):
+    # use imageio to save video
+    frames = [np.array(frame) for frame in frames]
+    writer = imageio.get_writer(output_path, fps=fps)
+    for frame in frames:
+        writer.append_data(frame)
+    writer.close()
+
+
+def to_gradio_3d_orientation(mesh):
+    # rotate so the mesh displays upright in the Gradio 3D viewer
+    mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
+    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi / 2, [0, 1, 0]))
+    return mesh
+
+
+def to_standard_3d_orientation(mesh):
+    """
+    Convert mesh to the standard 3D viewer orientation (Y-up, Z-forward),
+    which works well with most 3D viewers.
+    """
+    # Rotate -90 degrees around the X axis (to make Y up instead of Z)
+    mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
+    return mesh
+
+
+def apply_mesh_orientation(mesh, orientation="standard"):
+    """
+    Apply an orientation transformation to a mesh.
+
+    Args:
+        mesh: Trimesh mesh object
+        orientation: Orientation type
+            - "standard": Standard 3D viewer orientation (Y-up, Z-forward)
+            - "gradio": Gradio 3D viewer orientation
+            - "none": No transformation (original orientation)
+
+    Returns:
+        Transformed mesh
+    """
+    if orientation == "standard":
+        return to_standard_3d_orientation(mesh)
+    elif orientation == "gradio":
+        return to_gradio_3d_orientation(mesh)
+    elif orientation == "none":
+        return mesh
+    else:
+        raise ValueError(f"Unknown orientation: {orientation}. Must be 'standard', 'gradio', or 'none'")