Upload 41 files
- .dockerignore +55 -0
- .gitattributes +2 -35
- .gitignore +164 -0
- API_README.md +209 -0
- DEPLOYMENT_GUIDE.md +271 -0
- Dockerfile +56 -0
- LICENSE +21 -0
- REACT_INTEGRATION.md +549 -0
- README.md +143 -11
- README_DOCKER.md +158 -0
- README_HF_SPACES.md +143 -0
- api_example.html +194 -0
- api_server.py +403 -0
- app.py +171 -0
- deploy_colab.ipynb +268 -0
- docker-compose.yml +31 -0
- gradio_app.py +187 -0
- requirements.txt +16 -0
- run.py +197 -0
- tsr/__pycache__/bake_texture.cpython-313.pyc +0 -0
- tsr/__pycache__/system.cpython-313.pyc +0 -0
- tsr/__pycache__/utils.cpython-313.pyc +0 -0
- tsr/bake_texture.py +191 -0
- tsr/models/__pycache__/isosurface.cpython-313.pyc +0 -0
- tsr/models/__pycache__/nerf_renderer.cpython-313.pyc +0 -0
- tsr/models/__pycache__/network_utils.cpython-313.pyc +0 -0
- tsr/models/isosurface.py +64 -0
- tsr/models/nerf_renderer.py +180 -0
- tsr/models/network_utils.py +124 -0
- tsr/models/tokenizers/__pycache__/image.cpython-313.pyc +0 -0
- tsr/models/tokenizers/__pycache__/triplane.cpython-313.pyc +0 -0
- tsr/models/tokenizers/image.py +66 -0
- tsr/models/tokenizers/triplane.py +45 -0
- tsr/models/transformer/__pycache__/attention.cpython-313.pyc +0 -0
- tsr/models/transformer/__pycache__/basic_transformer_block.cpython-313.pyc +0 -0
- tsr/models/transformer/__pycache__/transformer_1d.cpython-313.pyc +0 -0
- tsr/models/transformer/attention.py +653 -0
- tsr/models/transformer/basic_transformer_block.py +334 -0
- tsr/models/transformer/transformer_1d.py +219 -0
- tsr/system.py +205 -0
- tsr/utils.py +510 -0
.dockerignore
ADDED
@@ -0,0 +1,55 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/
*.egg

# Virtual environments
venv/
env/
ENV/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Git
.git/
.gitignore

# Docker
Dockerfile
.dockerignore

# Output files
output/
*.obj
*.glb
*.zip
*.png
*.jpg
*.jpeg

# Logs
*.log

# OS
.DS_Store
Thumbs.db

# Model cache (will be downloaded at runtime)
.cache/

# Temporary files
*.tmp
temp/
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
-
-
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# Auto detect text files and perform LF normalization
+* text=auto
.gitignore
ADDED
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# default output directory
output/
outputs/
API_README.md
ADDED
@@ -0,0 +1,209 @@
# TripoSR REST API

A FastAPI-based REST API server for TripoSR 3D mesh generation.

## Starting the Server

```bash
python api_server.py
```

Or with uvicorn directly:

```bash
uvicorn api_server:app --host 0.0.0.0 --port 8000
```

The API will be available at `http://localhost:8000`

## API Endpoints

### 1. Health Check

**GET** `/health`

Check if the API is running and get device information.

**Response:**
```json
{
  "status": "healthy",
  "device": "cuda:0",
  "cuda_available": true
}
```

### 2. Generate Mesh (File Download)

**POST** `/generate`

Generate a 3D mesh from an uploaded image and download the mesh file.

**Parameters:**
- `image` (file, required): Image file (PNG, JPG, JPEG)
- `do_remove_background` (boolean, default: true): Whether to remove background
- `foreground_ratio` (float, default: 0.85): Ratio of foreground size (0.5-1.0)
- `mc_resolution` (int, default: 256): Marching cubes resolution (128, 160, 192, 224, 256, 288, 320)
- `format` (string, default: "obj"): Output format - "obj" or "glb"

**Response:** Mesh file download

**Example (cURL):**
```bash
curl -X POST "http://localhost:8000/generate" \
  -F "image=@chair.png" \
  -F "do_remove_background=true" \
  -F "foreground_ratio=0.85" \
  -F "mc_resolution=256" \
  -F "format=obj" \
  --output mesh.obj
```

### 3. Generate Mesh (Base64)

**POST** `/generate-base64`

Generate a 3D mesh and return as base64 encoded string.

**Parameters:** Same as `/generate`

**Response:**
```json
{
  "success": true,
  "format": "obj",
  "mesh": "base64_encoded_mesh_data...",
  "size": 1234567
}
```

## Frontend Integration Examples

### JavaScript/Fetch

```javascript
const formData = new FormData();
formData.append('image', imageFile);
formData.append('do_remove_background', true);
formData.append('foreground_ratio', 0.85);
formData.append('mc_resolution', 256);
formData.append('format', 'obj');

const response = await fetch('http://localhost:8000/generate', {
  method: 'POST',
  body: formData
});

if (response.ok) {
  const blob = await response.blob();
  const url = window.URL.createObjectURL(blob);
  // Download or use the mesh file
  const a = document.createElement('a');
  a.href = url;
  a.download = 'mesh.obj';
  a.click();
}
```

### React Example

```jsx
import { useState } from 'react';

function MeshGenerator() {
  const [loading, setLoading] = useState(false);

  const generateMesh = async (imageFile) => {
    setLoading(true);
    const formData = new FormData();
    formData.append('image', imageFile);
    formData.append('do_remove_background', true);
    formData.append('foreground_ratio', 0.85);
    formData.append('mc_resolution', 256);
    formData.append('format', 'obj');

    try {
      const response = await fetch('http://localhost:8000/generate', {
        method: 'POST',
        body: formData
      });

      if (response.ok) {
        const blob = await response.blob();
        // Handle the mesh file
        const url = window.URL.createObjectURL(blob);
        // Download or display
      }
    } catch (error) {
      console.error('Error:', error);
    } finally {
      setLoading(false);
    }
  };

  return (
    <div>
      <input
        type="file"
        accept="image/*"
        onChange={(e) => generateMesh(e.target.files[0])}
      />
      {loading && <p>Generating mesh...</p>}
    </div>
  );
}
```

### Python Client Example

```python
import requests

url = "http://localhost:8000/generate"

with open("chair.png", "rb") as f:
    files = {"image": f}
    data = {
        "do_remove_background": True,
        "foreground_ratio": 0.85,
        "mc_resolution": 256,
        "format": "obj"
    }
    response = requests.post(url, files=files, data=data)

if response.status_code == 200:
    with open("mesh.obj", "wb") as out:
        out.write(response.content)
    print("Mesh saved to mesh.obj")
```

## CORS Configuration

The API is configured to allow CORS from all origins by default. For production, update the `allow_origins` in `api_server.py`:

```python
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://your-frontend-domain.com"],  # Your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
```

## Performance Notes

- Model initialization takes ~18-20 seconds on first request
- Mesh generation typically takes 30-60 seconds depending on resolution
- GPU (CUDA) is recommended for faster processing
- Consider implementing request queuing for production use

## Error Handling

All endpoints return appropriate HTTP status codes:
- `200`: Success
- `400`: Bad request (invalid parameters)
- `500`: Server error (model processing failed)

Error responses include a JSON body with a `detail` field describing the error.
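The API_README above documents a Python client for `/generate`; as a companion, here is a minimal sketch of a client for `/generate-base64`, assuming the response shape shown in that file (`success`, `format`, `mesh`, `size`). It is an illustration, not one of the committed files.

```python
import base64

import requests

# Call /generate-base64 (documented above) and decode the returned mesh.
# Field names follow the API_README example response.
url = "http://localhost:8000/generate-base64"

with open("chair.png", "rb") as f:
    response = requests.post(
        url,
        files={"image": f},
        data={"do_remove_background": True, "mc_resolution": 256, "format": "obj"},
    )

response.raise_for_status()
payload = response.json()

if payload.get("success"):
    mesh_bytes = base64.b64decode(payload["mesh"])  # base64 string -> raw mesh bytes
    out_path = f"mesh.{payload['format']}"
    with open(out_path, "wb") as out:
        out.write(mesh_bytes)
    print(f"Saved {len(mesh_bytes)} bytes to {out_path}")
```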
DEPLOYMENT_GUIDE.md
ADDED
@@ -0,0 +1,271 @@
# 🚀 Free Cloud Deployment Guide for TripoSR API

This guide covers multiple **FREE** options to deploy your TripoSR API in the cloud.

## 📋 Table of Contents
1. [Hugging Face Spaces (Recommended)](#1-hugging-face-spaces-recommended)
2. [Google Colab](#2-google-colab)
3. [Render.com (CPU Only)](#3-rendercom-cpu-only)
4. [Railway.app](#4-railwayapp)

---

## 1. Hugging Face Spaces (Recommended) ⭐

**Best option** for this project - offers free GPU access and is designed for ML models.

### Features:
- ✅ Free GPU (T4) available
- ✅ Persistent deployment
- ✅ Built-in CI/CD
- ✅ Public API endpoint
- ✅ Great for ML models

### Setup Steps:

1. **Create a Hugging Face account** at [huggingface.co](https://huggingface.co)

2. **Create a new Space:**
   - Go to https://huggingface.co/spaces
   - Click "Create new Space"
   - Name: `triposr-api`
   - License: MIT
   - Select SDK: **Docker**
   - Hardware: **CPU basic** (start here, upgrade to GPU if needed)

3. **Clone your Space repository:**
   ```bash
   git clone https://huggingface.co/spaces/YOUR_USERNAME/triposr-api
   cd triposr-api
   ```

4. **Copy these files to the Space:**
   - Copy `Dockerfile.huggingface` as `Dockerfile`
   - Copy `api_server.py`
   - Copy `requirements.txt`
   - Copy the entire `tsr/` directory
   - Copy `README.md`

5. **Create a `README.md` header** (required by HF Spaces):
   ```markdown
   ---
   title: TripoSR API
   emoji: 🎨
   colorFrom: blue
   colorTo: purple
   sdk: docker
   pinned: false
   license: mit
   ---

   # TripoSR API

   Fast 3D reconstruction from a single image.
   ```

6. **Push to Hugging Face:**
   ```bash
   git add .
   git commit -m "Initial deployment"
   git push
   ```

7. **Your API will be available at:**
   ```
   https://YOUR_USERNAME-triposr-api.hf.space
   ```

### Upgrade to GPU (if needed):
- Go to your Space settings
- Under "Hardware", select **T4 small** (free tier)
- The Space will rebuild automatically

---

## 2. Google Colab

**Good for testing** - Free GPU but sessions expire after inactivity.

### Features:
- ✅ Free GPU (T4/K80)
- ❌ Sessions expire (12-hour limit)
- ❌ Not persistent
- ✅ Good for testing/demos

### Setup Steps:

1. **Upload the Colab notebook:**
   - Open [Google Colab](https://colab.research.google.com)
   - Upload `deploy_colab.ipynb` (provided in this repo)

2. **Run the notebook:**
   - Enable GPU: Runtime → Change runtime type → GPU
   - Run all cells
   - The notebook will install dependencies and start the server

3. **Access via ngrok tunnel:**
   - The notebook creates a public URL using ngrok
   - URL will be displayed in the output
   - Example: `https://abc123.ngrok.io`

**Note:** The URL changes every time you restart the notebook.

---

## 3. Render.com (CPU Only)

**Limited functionality** - Free tier is CPU only, so 3D generation will be VERY slow.

### Features:
- ✅ Free tier available
- ✅ Persistent deployment
- ❌ CPU only (slow inference)
- ✅ Auto-deploy from GitHub

### Setup Steps:

1. **Push code to GitHub:**
   ```bash
   git init
   git add .
   git commit -m "Initial commit"
   git remote add origin YOUR_GITHUB_REPO_URL
   git push -u origin main
   ```

2. **Create Render account** at [render.com](https://render.com)

3. **Create a new Web Service:**
   - Click "New +" → "Web Service"
   - Connect your GitHub repository
   - Name: `triposr-api`
   - Environment: **Docker**
   - Plan: **Free**

4. **Configure:**
   - Build Command: (leave empty - using Dockerfile)
   - Start Command: (leave empty - using Dockerfile CMD)

5. **Deploy:**
   - Click "Create Web Service"
   - Wait for build to complete (~10-15 minutes)

6. **Your API will be available at:**
   ```
   https://triposr-api.onrender.com
   ```

**⚠️ Warning:** CPU-only inference will be 10-50x slower than GPU!

---

## 4. Railway.app

**Trial credits** - $5 free credits, then paid.

### Features:
- ✅ $5 free trial credits
- ✅ Easy deployment
- ❌ No free GPU
- ✅ Good for CPU testing

### Setup Steps:

1. **Create Railway account** at [railway.app](https://railway.app)

2. **Create new project:**
   - Click "New Project"
   - Select "Deploy from GitHub repo"
   - Connect your repository

3. **Configure:**
   - Railway auto-detects Dockerfile
   - Add environment variables if needed

4. **Deploy:**
   - Railway builds and deploys automatically
   - Get your public URL from the dashboard

---

## 🎯 Recommendation

**For production use:** Hugging Face Spaces with GPU
- Best balance of features, performance, and cost
- Designed for ML models
- Free GPU tier available

**For testing/demos:** Google Colab
- Quick setup
- Free GPU
- Good for temporary demos

**For CPU-only:** Render.com
- Persistent deployment
- But very slow for 3D generation

---

## 📊 Comparison Table

| Platform | GPU | Persistent | Free Tier | Best For |
|----------|-----|------------|-----------|----------|
| **Hugging Face Spaces** | ✅ T4 | ✅ | ✅ | Production |
| **Google Colab** | ✅ T4/K80 | ❌ | ✅ | Testing |
| **Render.com** | ❌ | ✅ | ✅ | CPU demos |
| **Railway.app** | ❌ | ✅ | $5 credits | Trial |

---

## 🔧 Testing Your Deployment

Once deployed, test your API:

```bash
# Health check
curl https://YOUR_DEPLOYMENT_URL/health

# Generate 3D model
curl -X POST https://YOUR_DEPLOYMENT_URL/generate \
  -F "image=@test_image.png" \
  -F "format=obj" \
  -F "bake_texture_flag=true" \
  -o output.zip
```

---

## 📝 Notes

- **Model size:** The TripoSR model is ~1.5GB and will be downloaded on first run
- **Memory requirements:** Minimum 8GB RAM, 6GB VRAM for GPU
- **Cold starts:** First request may take 30-60 seconds to load the model
- **Rate limits:** Free tiers may have rate limits or usage quotas

---

## 🆘 Troubleshooting

### "Out of memory" errors
- Reduce `mc_resolution` parameter (default: 256)
- Use smaller images
- Upgrade to larger instance

### Slow generation
- Ensure GPU is enabled
- Check if running on CPU (much slower)
- Monitor instance resources

### Build failures
- Check Docker logs
- Ensure all dependencies are in `requirements.txt`
- Verify CUDA compatibility

---

## 📚 Additional Resources

- [TripoSR Paper](https://arxiv.org/abs/2403.02151)
- [Hugging Face Spaces Docs](https://huggingface.co/docs/hub/spaces)
- [Render Docs](https://render.com/docs)
- [Railway Docs](https://docs.railway.app)
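The Colab option in the guide above exposes the notebook's server through an ngrok tunnel. A minimal sketch of that flow follows, assuming `pyngrok` is installed and an ngrok auth token is configured; the actual `deploy_colab.ipynb` may differ in detail.

```python
import threading

import uvicorn
from pyngrok import ngrok

from api_server import app  # FastAPI app from this repository


def run_server() -> None:
    # Serve the API on the port the tunnel will forward to.
    uvicorn.run(app, host="0.0.0.0", port=8000)


threading.Thread(target=run_server, daemon=True).start()

# Open a public HTTP tunnel to the local server (URL changes on every restart).
tunnel = ngrok.connect(8000, "http")
print("TripoSR API available at:", tunnel.public_url)
```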
Dockerfile
ADDED
@@ -0,0 +1,56 @@
# Hugging Face Spaces optimized Dockerfile
# This uses a lighter base image suitable for HF Spaces

FROM python:3.10-slim

# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV TRANSFORMERS_CACHE=/tmp/transformers_cache
ENV HF_HOME=/tmp/hf_home

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    gcc \
    g++ \
    git \
    ffmpeg \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgomp1 \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Set up the application directory
WORKDIR /app

# Copy requirements file first for better Docker layer caching
COPY requirements.txt /app/

# Upgrade pip and setuptools
RUN pip install --no-cache-dir --upgrade pip setuptools wheel

# Install PyTorch CPU version (HF Spaces will use CPU by default, GPU if upgraded)
# Using CPU version to reduce image size - will auto-detect GPU if available
RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu

# Install other dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code
COPY . /app

# Create output directory for temporary files
RUN mkdir -p /app/output

# Expose the port (HF Spaces uses 7860 by default)
EXPOSE 7860

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Run the server on port 7860 (required by HF Spaces)
CMD ["python", "-m", "uvicorn", "api_server:app", "--host", "0.0.0.0", "--port", "7860"]
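The Dockerfile above installs the CPU build of PyTorch and notes that a GPU is used when one is available. The standard check behind that kind of device auto-detection is sketched below; the exact logic in `api_server.py` may differ, and the field names simply mirror the `/health` response shown in API_README.md.

```python
import torch

# Pick a device at startup: CUDA when a GPU (and a CUDA-enabled torch build)
# is present, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

health = {
    "status": "healthy",
    "device": device,
    "cuda_available": torch.cuda.is_available(),
}
print(health)
```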
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Tripo AI & Stability AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
REACT_INTEGRATION.md
ADDED
@@ -0,0 +1,549 @@
# TripoSR API - React Integration Guide

This guide shows how to integrate TripoSR API into your React + Supabase project.

## 1. Update CORS in API Server

First, update `api_server.py` to allow your React app's origin:

```python
# In api_server.py, update the CORS middleware:
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",  # React dev server
        "http://localhost:5173",  # Vite dev server
        "http://localhost:8080",  # Other common ports
        # Add your production domain here
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
```

## 2. React Hook for TripoSR API

Create a custom hook: `src/hooks/useTripoSR.js` or `src/hooks/useTripoSR.ts`

### TypeScript Version:

```typescript
import { useState } from 'react';

interface GenerateMeshParams {
  doRemoveBackground?: boolean;
  foregroundRatio?: number;
  mcResolution?: number;
  format?: 'obj' | 'glb';
}

interface UseTripoSRReturn {
  generateMesh: (imageFile: File, params?: GenerateMeshParams) => Promise<Blob | null>;
  generateMeshBase64: (imageFile: File, params?: GenerateMeshParams) => Promise<string | null>;
  loading: boolean;
  error: string | null;
}

const API_URL = process.env.REACT_APP_TRIPOSR_API_URL || 'http://localhost:8000';

export const useTripoSR = (): UseTripoSRReturn => {
  const [loading, setLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const generateMesh = async (
    imageFile: File,
    params: GenerateMeshParams = {}
  ): Promise<Blob | null> => {
    setLoading(true);
    setError(null);

    try {
      const formData = new FormData();
      formData.append('image', imageFile);
      formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
      formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
      formData.append('mc_resolution', String(params.mcResolution ?? 256));
      formData.append('format', params.format ?? 'obj');

      const response = await fetch(`${API_URL}/generate`, {
        method: 'POST',
        body: formData,
      });

      if (!response.ok) {
        const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
        throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
      }

      const blob = await response.blob();
      return blob;
    } catch (err) {
      const errorMessage = err instanceof Error ? err.message : 'Failed to generate mesh';
      setError(errorMessage);
      console.error('TripoSR API error:', err);
      return null;
    } finally {
      setLoading(false);
    }
  };

  const generateMeshBase64 = async (
    imageFile: File,
    params: GenerateMeshParams = {}
  ): Promise<string | null> => {
    setLoading(true);
    setError(null);

    try {
      const formData = new FormData();
      formData.append('image', imageFile);
      formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
      formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
      formData.append('mc_resolution', String(params.mcResolution ?? 256));
      formData.append('format', params.format ?? 'obj');

      const response = await fetch(`${API_URL}/generate-base64`, {
        method: 'POST',
        body: formData,
      });

      if (!response.ok) {
        const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
        throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
      }

      const data = await response.json();
      return data.mesh; // Base64 encoded mesh
    } catch (err) {
      const errorMessage = err instanceof Error ? err.message : 'Failed to generate mesh';
      setError(errorMessage);
      console.error('TripoSR API error:', err);
      return null;
    } finally {
      setLoading(false);
    }
  };

  return { generateMesh, generateMeshBase64, loading, error };
};
```

### JavaScript Version:

```javascript
import { useState } from 'react';

const API_URL = process.env.REACT_APP_TRIPOSR_API_URL || 'http://localhost:8000';

export const useTripoSR = () => {
  const [loading, setLoading] = useState(false);
  const [error, setError] = useState(null);

  const generateMesh = async (imageFile, params = {}) => {
    setLoading(true);
    setError(null);

    try {
      const formData = new FormData();
      formData.append('image', imageFile);
      formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
      formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
      formData.append('mc_resolution', String(params.mcResolution ?? 256));
      formData.append('format', params.format ?? 'obj');

      const response = await fetch(`${API_URL}/generate`, {
        method: 'POST',
        body: formData,
      });

      if (!response.ok) {
        const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
        throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
      }

      const blob = await response.blob();
      return blob;
    } catch (err) {
      const errorMessage = err.message || 'Failed to generate mesh';
      setError(errorMessage);
      console.error('TripoSR API error:', err);
      return null;
    } finally {
      setLoading(false);
    }
  };

  const generateMeshBase64 = async (imageFile, params = {}) => {
    setLoading(true);
    setError(null);

    try {
      const formData = new FormData();
      formData.append('image', imageFile);
      formData.append('do_remove_background', String(params.doRemoveBackground ?? true));
      formData.append('foreground_ratio', String(params.foregroundRatio ?? 0.85));
      formData.append('mc_resolution', String(params.mcResolution ?? 256));
      formData.append('format', params.format ?? 'obj');

      const response = await fetch(`${API_URL}/generate-base64`, {
        method: 'POST',
        body: formData,
      });

      if (!response.ok) {
        const errorData = await response.json().catch(() => ({ detail: 'Unknown error' }));
        throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
      }

      const data = await response.json();
      return data.mesh;
    } catch (err) {
      const errorMessage = err.message || 'Failed to generate mesh';
      setError(errorMessage);
      console.error('TripoSR API error:', err);
      return null;
    } finally {
      setLoading(false);
    }
  };

  return { generateMesh, generateMeshBase64, loading, error };
};
```

## 3. React Component Example

Create `src/components/MeshGenerator.tsx` or `.jsx`:

```tsx
import React, { useState } from 'react';
import { useTripoSR } from '../hooks/useTripoSR';

const MeshGenerator: React.FC = () => {
  const [selectedFile, setSelectedFile] = useState<File | null>(null);
  const [preview, setPreview] = useState<string | null>(null);
  const [meshUrl, setMeshUrl] = useState<string | null>(null);
  const { generateMesh, loading, error } = useTripoSR();

  const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const file = e.target.files?.[0];
    if (file) {
      setSelectedFile(file);
      const reader = new FileReader();
      reader.onloadend = () => {
        setPreview(reader.result as string);
      };
      reader.readAsDataURL(file);
    }
  };

  const handleGenerate = async () => {
    if (!selectedFile) return;

    const blob = await generateMesh(selectedFile, {
      doRemoveBackground: true,
      foregroundRatio: 0.85,
      mcResolution: 256,
      format: 'obj',
    });

    if (blob) {
      const url = window.URL.createObjectURL(blob);
      setMeshUrl(url);
    }
  };

  const handleDownload = () => {
    if (meshUrl) {
      const a = document.createElement('a');
      a.href = meshUrl;
      a.download = 'mesh.obj';
      document.body.appendChild(a);
      a.click();
      document.body.removeChild(a);
    }
  };

  return (
    <div className="mesh-generator">
      <h2>3D Mesh Generator</h2>

      <div className="upload-section">
        <input
          type="file"
          accept="image/*"
          onChange={handleFileChange}
          disabled={loading}
        />
        {preview && (
          <img src={preview} alt="Preview" style={{ maxWidth: '300px', marginTop: '10px' }} />
        )}
      </div>

      <button onClick={handleGenerate} disabled={!selectedFile || loading}>
        {loading ? 'Generating...' : 'Generate Mesh'}
      </button>

      {error && <div className="error">Error: {error}</div>}

      {meshUrl && (
        <div className="result-section">
          <p>Mesh generated successfully!</p>
          <button onClick={handleDownload}>Download Mesh</button>
          {/* You can also use a 3D viewer here */}
        </div>
      )}
    </div>
  );
};

export default MeshGenerator;
```

## 4. Integration with Supabase Storage

If you want to store mesh files in Supabase:

```typescript
import { useTripoSR } from '../hooks/useTripoSR';
import { supabase } from '../lib/supabase'; // Your Supabase client

const MeshGeneratorWithSupabase: React.FC = () => {
  const { generateMesh, loading, error } = useTripoSR();
  const [uploading, setUploading] = useState(false);

  const handleGenerateAndUpload = async (imageFile: File, userId: string) => {
    // Generate mesh
    const blob = await generateMesh(imageFile);
    if (!blob) return;

    // Upload to Supabase Storage
    setUploading(true);
    try {
      const fileName = `${userId}/${Date.now()}_mesh.obj`;
      const { data, error: uploadError } = await supabase.storage
        .from('meshes') // Your storage bucket name
        .upload(fileName, blob, {
          contentType: 'application/octet-stream',
          upsert: false,
        });

      if (uploadError) throw uploadError;

      // Get public URL
      const { data: urlData } = supabase.storage
        .from('meshes')
        .getPublicUrl(fileName);

      console.log('Mesh uploaded:', urlData.publicUrl);
      return urlData.publicUrl;
    } catch (err) {
      console.error('Upload error:', err);
    } finally {
      setUploading(false);
    }
  };

  // ... rest of component
};
```

## 5. Environment Variables

Create `.env.local` in your React project:

```env
REACT_APP_TRIPOSR_API_URL=http://localhost:8000
```

For production, update to your deployed API URL.

## 6. 3D Mesh Viewer Integration

To display the mesh in your React app, you can use libraries like:

- **react-three-fiber** + **drei**
- **@react-three/viewer**
- **model-viewer** (web component)

Example with `model-viewer`:

```tsx
import '@google/model-viewer';

const MeshViewer: React.FC<{ meshUrl: string }> = ({ meshUrl }) => {
  return (
    <model-viewer
      src={meshUrl}
      alt="3D Mesh"
      auto-rotate
      camera-controls
      style={{ width: '100%', height: '500px' }}
    />
  );
};
```

## 7. Complete Example with All Features

```tsx
import React, { useState } from 'react';
import { useTripoSR } from '../hooks/useTripoSR';
import { supabase } from '../lib/supabase';

const CompleteMeshGenerator: React.FC = () => {
  const [file, setFile] = useState<File | null>(null);
  const [preview, setPreview] = useState<string | null>(null);
  const [meshUrl, setMeshUrl] = useState<string | null>(null);
  const [foregroundRatio, setForegroundRatio] = useState(0.85);
  const [resolution, setResolution] = useState(256);
  const [format, setFormat] = useState<'obj' | 'glb'>('obj');

  const { generateMesh, loading, error } = useTripoSR();

  const handleFileSelect = (e: React.ChangeEvent<HTMLInputElement>) => {
    const selected = e.target.files?.[0];
    if (selected) {
      setFile(selected);
      const reader = new FileReader();
      reader.onload = () => setPreview(reader.result as string);
      reader.readAsDataURL(selected);
    }
  };

  const handleGenerate = async () => {
    if (!file) return;

    const blob = await generateMesh(file, {
      doRemoveBackground: true,
      foregroundRatio,
      mcResolution: resolution,
      format,
    });

    if (blob) {
      const url = window.URL.createObjectURL(blob);
      setMeshUrl(url);
    }
  };

  const handleSaveToSupabase = async () => {
    if (!meshUrl || !file) return;

    const { data: { user } } = await supabase.auth.getUser();
    if (!user) {
      alert('Please log in to save meshes');
      return;
    }

    const response = await fetch(meshUrl);
    const blob = await response.blob();
    const fileName = `${user.id}/${Date.now()}_mesh.${format}`;

    const { error } = await supabase.storage
      .from('meshes')
      .upload(fileName, blob);

    if (error) {
      console.error('Upload error:', error);
    } else {
      alert('Mesh saved to Supabase!');
    }
  };

  return (
    <div style={{ padding: '20px' }}>
      <h1>3D Mesh Generator</h1>

      <div>
        <input type="file" accept="image/*" onChange={handleFileSelect} />
        {preview && <img src={preview} alt="Preview" style={{ maxWidth: '300px' }} />}
      </div>

      <div style={{ margin: '20px 0' }}>
        <label>
          Foreground Ratio: {foregroundRatio}
          <input
            type="range"
            min="0.5"
            max="1.0"
            step="0.05"
            value={foregroundRatio}
            onChange={(e) => setForegroundRatio(parseFloat(e.target.value))}
          />
        </label>
      </div>

      <div style={{ margin: '20px 0' }}>
        <label>
          Resolution: {resolution}
          <input
            type="range"
            min="128"
            max="320"
            step="32"
            value={resolution}
            onChange={(e) => setResolution(parseInt(e.target.value))}
          />
        </label>
      </div>

      <div style={{ margin: '20px 0' }}>
        <label>
          Format:
          <select value={format} onChange={(e) => setFormat(e.target.value as 'obj' | 'glb')}>
            <option value="obj">OBJ</option>
            <option value="glb">GLB</option>
          </select>
        </label>
      </div>

      <button onClick={handleGenerate} disabled={!file || loading}>
        {loading ? 'Generating...' : 'Generate Mesh'}
      </button>

      {error && <div style={{ color: 'red' }}>Error: {error}</div>}

      {meshUrl && (
        <div>
          <p>Mesh generated!</p>
          <a href={meshUrl} download={`mesh.${format}`}>
            <button>Download</button>
          </a>
          <button onClick={handleSaveToSupabase}>Save to Supabase</button>
        </div>
      )}
    </div>
  );
};

export default CompleteMeshGenerator;
```

## 8. API Health Check

Add a health check on app load:

```typescript
useEffect(() => {
  const checkAPI = async () => {
    try {
      const response = await fetch(`${API_URL}/health`);
      const data = await response.json();
      console.log('TripoSR API status:', data);
    } catch (err) {
      console.error('TripoSR API is not available:', err);
    }
  };
  checkAPI();
}, []);
```

## Notes

- Make sure your TripoSR API server is running before using the React app
- The API takes 30-60 seconds to generate a mesh, so show appropriate loading states
- Consider implementing request cancellation for better UX
- For production, deploy the API server and update the `REACT_APP_TRIPOSR_API_URL` environment variable
README.md
CHANGED
@@ -1,11 +1,143 @@
----
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk: docker
-pinned: false
-license: mit
----
-
-
---
title: TripoSR API
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: docker
pinned: false
license: mit
---

# TripoSR API 🎨

Fast 3D reconstruction from a single image using TripoSR.

## 🚀 Features

- **Fast 3D Generation**: Convert images to 3D models in seconds
- **Multiple Formats**: Export as OBJ or GLB
- **Texture Baking**: Optional texture atlas generation
- **Background Removal**: Automatic background removal
- **REST API**: Easy integration with any frontend

## 📡 API Endpoints

### Health Check
```bash
GET /health
```

### Generate 3D Model
```bash
POST /generate
```

**Parameters:**
- `image`: Image file (PNG, JPG, JPEG)
- `do_remove_background`: Remove background (default: true)
- `foreground_ratio`: Foreground size ratio (default: 0.85)
- `mc_resolution`: Mesh resolution (default: 256)
- `format`: Output format - "obj" or "glb" (default: "obj")
- `bake_texture_flag`: Bake texture (default: true)
- `texture_resolution`: Texture resolution (default: 2048)
- `orientation`: Mesh orientation - "standard", "gradio", or "none" (default: "standard")

**Returns:**
- ZIP file containing mesh and texture (if texture baking enabled)
- Or mesh file only (if texture baking disabled)

### Generate 3D Model (Base64)
```bash
POST /generate-base64
```

Same parameters as `/generate`, but returns JSON with base64-encoded mesh and texture.
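For reference, a minimal Python sketch of calling this endpoint and decoding its response. The JSON field names (`format`, `mesh`, `texture`, `has_texture`) follow the response built in `api_server.py`; the URL placeholder and the 300-second timeout are assumptions you should adjust:

```python
import base64
import requests

url = "https://YOUR-SPACE-URL/generate-base64"
files = {"image": open("your_image.png", "rb")}
data = {"format": "obj", "bake_texture_flag": True}

response = requests.post(url, files=files, data=data, timeout=300)
response.raise_for_status()
payload = response.json()

# The mesh (and the texture, when baking succeeded) come back as base64 strings.
with open(f"mesh.{payload['format']}", "wb") as f:
    f.write(base64.b64decode(payload["mesh"]))

if payload.get("has_texture"):
    with open("texture.png", "wb") as f:
        f.write(base64.b64decode(payload["texture"]))
```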
## 🧪 Example Usage

### cURL
```bash
curl -X POST https://YOUR-SPACE-URL/generate \
  -F "image=@your_image.png" \
  -F "format=obj" \
  -F "bake_texture_flag=true" \
  -o output.zip
```

### Python
```python
import requests

url = "https://YOUR-SPACE-URL/generate"
files = {"image": open("your_image.png", "rb")}
data = {
    "format": "obj",
    "bake_texture_flag": True,
    "mc_resolution": 256
}

response = requests.post(url, files=files, data=data)
with open("output.zip", "wb") as f:
    f.write(response.content)
```
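When texture baking is enabled, the `output.zip` written above contains the mesh (`mesh.obj` or `mesh.glb`) plus `texture.png`, so a short follow-up sketch for unpacking it:

```python
import zipfile

# Unpack the archive saved by the example above.
with zipfile.ZipFile("output.zip") as zf:
    print(zf.namelist())  # e.g. ['mesh.obj', 'texture.png']
    zf.extractall("output")
```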
### JavaScript
```javascript
const formData = new FormData();
formData.append('image', fileInput.files[0]);
formData.append('format', 'obj');
formData.append('bake_texture_flag', 'true');

const response = await fetch('https://YOUR-SPACE-URL/generate', {
  method: 'POST',
  body: formData
});

const blob = await response.blob();
// Download or process the blob
```

## ⚙️ Configuration

### GPU Support
This Space can run on CPU or GPU:
- **CPU**: Slower but free
- **GPU (T4)**: Much faster, may require upgrade in Space settings

To upgrade to GPU:
1. Go to Space settings
2. Select "T4 small" under Hardware
3. Space will rebuild automatically

### Performance
- **CPU**: ~30-60 seconds per image
- **GPU (T4)**: ~5-10 seconds per image
- **GPU (A100)**: ~1-2 seconds per image

## 📚 Documentation

- [TripoSR Paper](https://arxiv.org/abs/2403.02151)
- [API Documentation](API_README.md)
- [React Integration Guide](REACT_INTEGRATION.md)

## 🔗 Links

- [GitHub Repository](https://github.com/Ahmedbelaid/TripoSR-api)
- [Original TripoSR](https://github.com/VAST-AI-Research/TripoSR)
- [Stability AI](https://stability.ai/)
- [Tripo AI](https://www.tripo3d.ai/)

## 📄 License

MIT License - see [LICENSE](LICENSE) file for details.

## 🙏 Credits

- **TripoSR Model**: [Stability AI](https://stability.ai/) and [Tripo AI](https://www.tripo3d.ai/)
- **API Implementation**: Community contribution

## 🆘 Support

For issues or questions:
- Open an issue on [GitHub](https://github.com/Ahmedbelaid/TripoSR-api/issues)
- Join the [Discord](https://discord.gg/mvS9mCfMnQ)
README_DOCKER.md
ADDED
@@ -0,0 +1,158 @@
# Docker Setup for TripoSR API

This guide explains how to build and run the TripoSR API using Docker.

## Prerequisites

- Docker installed on your system
- (Optional) NVIDIA Docker runtime (nvidia-docker2) for GPU support
- At least 8GB of available RAM
- (For GPU) NVIDIA GPU with CUDA support

**Note:** This Dockerfile uses the `-devel` PyTorch image (instead of `-runtime`) because `torchmcubes` requires the CUDA development toolkit to compile with CUDA support. This results in a larger image size (~8-10GB) but is necessary for proper compilation.

## Building the Docker Image

### Basic Build

```bash
docker build -t triposr-api:latest .
```

### Build with Specific Tag

```bash
docker build -t triposr-api:v1.0 .
```

## Running the Container

### CPU Mode

```bash
docker run -d \
  --name triposr-api \
  -p 8000:8000 \
  triposr-api:latest
```

### GPU Mode (NVIDIA)

First, ensure you have `nvidia-docker2` installed. Then run:

```bash
docker run -d \
  --name triposr-api \
  --gpus all \
  -p 8000:8000 \
  -e CUDA_VISIBLE_DEVICES=0 \
  triposr-api:latest
```

### Using Docker Compose

For easier management, use Docker Compose:

```bash
# Start the service
docker-compose up -d

# View logs
docker-compose logs -f

# Stop the service
docker-compose down
```

For GPU support with Docker Compose, uncomment the GPU-related lines in `docker-compose.yml`.

## Verifying the Installation

Once the container is running, check the health endpoint:

```bash
curl http://localhost:8000/health
```

You should see a response like:
```json
{
  "status": "healthy",
  "device": "cuda:0",
  "cuda_available": true
}
```
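If you script the container start-up (for example in CI), a small Python sketch like the following can poll the health endpoint until the model has finished loading; the 120-second budget and 3-second retry interval are arbitrary choices:

```python
import time
import requests

def wait_for_api(url="http://localhost:8000/health", timeout=120):
    """Poll /health until the API reports healthy or the timeout expires."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            r = requests.get(url, timeout=5)
            if r.ok and r.json().get("status") == "healthy":
                return r.json()
        except requests.RequestException:
            pass  # container is still starting
        time.sleep(3)
    raise TimeoutError("TripoSR API did not become healthy in time")

print(wait_for_api())
```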
## API Usage

The API will be available at `http://localhost:8000`. See `API_README.md` for detailed API documentation.

### Example Request

```bash
curl -X POST "http://localhost:8000/generate" \
  -F "image=@your_image.jpg" \
  -F "orientation=standard" \
  -F "format=obj" \
  --output mesh.zip
```

## Troubleshooting

### Container Fails to Start

1. Check logs: `docker logs triposr-api`
2. Ensure port 8000 is not already in use
3. Verify you have enough memory (at least 8GB recommended)

### CUDA/GPU Issues

1. Verify NVIDIA Docker runtime: `docker run --rm --gpus all nvidia/cuda:11.7.0-base-ubuntu20.04 nvidia-smi`
2. Check CUDA availability in container: `docker exec triposr-api python -c "import torch; print(torch.cuda.is_available())"`

### Out of Memory

If you encounter OOM errors (see the request sketch after this list):
- Reduce `mc_resolution` parameter (default: 256, try 128 or 64)
- Reduce `texture_resolution` parameter (default: 2048, try 1024)
- Use CPU mode if GPU memory is limited
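For reference, a minimal Python sketch of a request that passes the reduced settings to `/generate` (parameter names as documented above; the local URL and the long timeout are assumptions to accommodate a slow CPU run):

```python
import requests

# Lower mc_resolution and texture_resolution to reduce memory pressure.
with open("your_image.jpg", "rb") as img:
    response = requests.post(
        "http://localhost:8000/generate",
        files={"image": img},
        data={
            "mc_resolution": 128,        # default 256
            "texture_resolution": 1024,  # default 2048
            "format": "obj",
        },
        timeout=600,
    )

response.raise_for_status()
with open("mesh.zip", "wb") as f:
    f.write(response.content)
```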
### Build Errors

If the build fails:
- Ensure you have a stable internet connection (model will be downloaded)
- Check that all dependencies in `requirements.txt` are valid
- Try building with `--no-cache`: `docker build --no-cache -t triposr-api:latest .`

## Environment Variables

- `CUDA_VISIBLE_DEVICES`: Set to specific GPU ID (e.g., "0") or empty string for CPU mode

## Volumes

The container can mount volumes for:
- `/app/output`: Output directory for generated meshes (optional)

Example:
```bash
docker run -d \
  --name triposr-api \
  -p 8000:8000 \
  -v $(pwd)/output:/app/output \
  triposr-api:latest
```

## Stopping and Removing

```bash
# Stop the container
docker stop triposr-api

# Remove the container
docker rm triposr-api

# Remove the image
docker rmi triposr-api:latest
```
README_HF_SPACES.md
ADDED
@@ -0,0 +1,143 @@
---
title: TripoSR API
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: docker
pinned: false
license: mit
---

# TripoSR API 🎨

Fast 3D reconstruction from a single image using TripoSR.

## 🚀 Features

- **Fast 3D Generation**: Convert images to 3D models in seconds
- **Multiple Formats**: Export as OBJ or GLB
- **Texture Baking**: Optional texture atlas generation
- **Background Removal**: Automatic background removal
- **REST API**: Easy integration with any frontend

## 📡 API Endpoints

### Health Check
```bash
GET /health
```

### Generate 3D Model
```bash
POST /generate
```

**Parameters:**
- `image`: Image file (PNG, JPG, JPEG)
- `do_remove_background`: Remove background (default: true)
- `foreground_ratio`: Foreground size ratio (default: 0.85)
- `mc_resolution`: Mesh resolution (default: 256)
- `format`: Output format - "obj" or "glb" (default: "obj")
- `bake_texture_flag`: Bake texture (default: true)
- `texture_resolution`: Texture resolution (default: 2048)
- `orientation`: Mesh orientation - "standard", "gradio", or "none" (default: "standard")

**Returns:**
- ZIP file containing mesh and texture (if texture baking enabled)
- Or mesh file only (if texture baking disabled)

### Generate 3D Model (Base64)
```bash
POST /generate-base64
```

Same parameters as `/generate`, but returns JSON with base64-encoded mesh and texture.

## 🧪 Example Usage

### cURL
```bash
curl -X POST https://YOUR-SPACE-URL/generate \
  -F "image=@your_image.png" \
  -F "format=obj" \
  -F "bake_texture_flag=true" \
  -o output.zip
```

### Python
```python
import requests

url = "https://YOUR-SPACE-URL/generate"
files = {"image": open("your_image.png", "rb")}
data = {
    "format": "obj",
    "bake_texture_flag": True,
    "mc_resolution": 256
}

response = requests.post(url, files=files, data=data)
with open("output.zip", "wb") as f:
    f.write(response.content)
```

### JavaScript
```javascript
const formData = new FormData();
formData.append('image', fileInput.files[0]);
formData.append('format', 'obj');
formData.append('bake_texture_flag', 'true');

const response = await fetch('https://YOUR-SPACE-URL/generate', {
  method: 'POST',
  body: formData
});

const blob = await response.blob();
// Download or process the blob
```

## ⚙️ Configuration

### GPU Support
This Space can run on CPU or GPU:
- **CPU**: Slower but free
- **GPU (T4)**: Much faster, may require upgrade in Space settings

To upgrade to GPU:
1. Go to Space settings
2. Select "T4 small" under Hardware
3. Space will rebuild automatically

### Performance
- **CPU**: ~30-60 seconds per image
- **GPU (T4)**: ~5-10 seconds per image
- **GPU (A100)**: ~1-2 seconds per image

## 📚 Documentation

- [TripoSR Paper](https://arxiv.org/abs/2403.02151)
- [API Documentation](API_README.md)
- [React Integration Guide](REACT_INTEGRATION.md)

## 🔗 Links

- [GitHub Repository](https://github.com/Ahmedbelaid/TripoSR-api)
- [Original TripoSR](https://github.com/VAST-AI-Research/TripoSR)
- [Stability AI](https://stability.ai/)
- [Tripo AI](https://www.tripo3d.ai/)

## 📄 License

MIT License - see [LICENSE](LICENSE) file for details.

## 🙏 Credits

- **TripoSR Model**: [Stability AI](https://stability.ai/) and [Tripo AI](https://www.tripo3d.ai/)
- **API Implementation**: Community contribution

## 🆘 Support

For issues or questions:
- Open an issue on [GitHub](https://github.com/Ahmedbelaid/TripoSR-api/issues)
- Join the [Discord](https://discord.gg/mvS9mCfMnQ)
api_example.html
ADDED
@@ -0,0 +1,194 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>TripoSR API Example</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 50px auto;
            padding: 20px;
        }
        .container {
            border: 1px solid #ddd;
            padding: 20px;
            border-radius: 8px;
        }
        input[type="file"] {
            margin: 10px 0;
        }
        button {
            background-color: #4CAF50;
            color: white;
            padding: 10px 20px;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background-color: #45a049;
        }
        button:disabled {
            background-color: #cccccc;
            cursor: not-allowed;
        }
        .preview {
            margin: 20px 0;
        }
        .preview img {
            max-width: 100%;
            border: 1px solid #ddd;
            border-radius: 4px;
        }
        .result {
            margin-top: 20px;
            padding: 10px;
            background-color: #f0f0f0;
            border-radius: 4px;
        }
        .error {
            color: red;
            margin-top: 10px;
        }
        .loading {
            color: #666;
            font-style: italic;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>TripoSR API Example</h1>
        <p>Upload an image to generate a 3D mesh</p>

        <form id="meshForm">
            <div>
                <label for="imageInput">Select Image:</label><br>
                <input type="file" id="imageInput" accept="image/*" required>
            </div>

            <div style="margin: 15px 0;">
                <label>
                    <input type="checkbox" id="removeBg" checked>
                    Remove Background
                </label>
            </div>

            <div style="margin: 15px 0;">
                <label for="foregroundRatio">Foreground Ratio: <span id="ratioValue">0.85</span></label><br>
                <input type="range" id="foregroundRatio" min="0.5" max="1.0" step="0.05" value="0.85">
            </div>

            <div style="margin: 15px 0;">
                <label for="resolution">Resolution: <span id="resValue">256</span></label><br>
                <input type="range" id="resolution" min="128" max="320" step="32" value="256">
            </div>

            <div style="margin: 15px 0;">
                <label for="format">Format:</label>
                <select id="format">
                    <option value="obj">OBJ</option>
                    <option value="glb">GLB</option>
                </select>
            </div>

            <button type="submit" id="generateBtn">Generate Mesh</button>
        </form>

        <div class="preview" id="preview"></div>
        <div id="result"></div>
    </div>

    <script>
        const API_URL = 'http://localhost:8000';

        // Update slider values
        document.getElementById('foregroundRatio').addEventListener('input', (e) => {
            document.getElementById('ratioValue').textContent = e.target.value;
        });

        document.getElementById('resolution').addEventListener('input', (e) => {
            document.getElementById('resValue').textContent = e.target.value;
        });

        // Preview image
        document.getElementById('imageInput').addEventListener('change', (e) => {
            const file = e.target.files[0];
            if (file) {
                const reader = new FileReader();
                reader.onload = (e) => {
                    const preview = document.getElementById('preview');
                    preview.innerHTML = `<img src="${e.target.result}" alt="Preview">`;
                };
                reader.readAsDataURL(file);
            }
        });

        // Handle form submission
        document.getElementById('meshForm').addEventListener('submit', async (e) => {
            e.preventDefault();

            const formData = new FormData();
            const imageInput = document.getElementById('imageInput');
            const removeBg = document.getElementById('removeBg').checked;
            const foregroundRatio = parseFloat(document.getElementById('foregroundRatio').value);
            const resolution = parseInt(document.getElementById('resolution').value);
            const format = document.getElementById('format').value;

            if (!imageInput.files[0]) {
                alert('Please select an image');
                return;
            }

            formData.append('image', imageInput.files[0]);
            formData.append('do_remove_background', removeBg);
            formData.append('foreground_ratio', foregroundRatio);
            formData.append('mc_resolution', resolution);
            formData.append('format', format);

            const generateBtn = document.getElementById('generateBtn');
            const resultDiv = document.getElementById('result');

            generateBtn.disabled = true;
            generateBtn.textContent = 'Generating...';
            resultDiv.innerHTML = '<div class="loading">Processing image and generating mesh. This may take 30-60 seconds...</div>';

            try {
                const response = await fetch(`${API_URL}/generate`, {
                    method: 'POST',
                    body: formData
                });

                if (!response.ok) {
                    const error = await response.json();
                    throw new Error(error.detail || 'Failed to generate mesh');
                }

                // Get the mesh file
                const blob = await response.blob();
                const url = window.URL.createObjectURL(blob);

                resultDiv.innerHTML = `
                    <div class="result">
                        <h3>Success! Mesh generated</h3>
                        <p>Format: ${format.toUpperCase()}</p>
                        <a href="${url}" download="mesh.${format}">
                            <button>Download Mesh</button>
                        </a>
                    </div>
                `;

            } catch (error) {
                resultDiv.innerHTML = `<div class="error">Error: ${error.message}</div>`;
            } finally {
                generateBtn.disabled = false;
                generateBtn.textContent = 'Generate Mesh';
            }
        });
    </script>
</body>
</html>
api_server.py
ADDED
@@ -0,0 +1,403 @@
import logging
import os
import tempfile
import base64
import zipfile
from io import BytesIO
from typing import Optional

import numpy as np
import rembg
import torch
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, apply_mesh_orientation
from tsr.bake_texture import bake_texture

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)

# Initialize FastAPI app
app = FastAPI(title="TripoSR API", version="1.0.0")

# Enable CORS for frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",  # React default port
        "http://localhost:5173",  # Vite default port
        "http://localhost:8080",  # Vue default port
        "http://127.0.0.1:3000",
        "http://127.0.0.1:5173",
        "http://127.0.0.1:8080",
        "https://huggingface.co",
        "https://*.hf.space",  # Add this
        # Add your production frontend URL here
        # "https://your-frontend-domain.com",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")

logging.info("Loading TripoSR model...")
model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
model.renderer.set_chunk_size(8192)
model.to(device)
logging.info("Model loaded successfully!")

rembg_session = rembg.new_session()


class GenerateRequest(BaseModel):
    do_remove_background: bool = True
    foreground_ratio: float = 0.85
    mc_resolution: int = 256
    format: str = "obj"  # obj or glb


def preprocess_image(image: Image.Image, do_remove_background: bool, foreground_ratio: float) -> Image.Image:
    """Preprocess the input image."""
    def fill_background(img):
        img = np.array(img).astype(np.float32) / 255.0
        img = img[:, :, :3] * img[:, :, 3:4] + (1 - img[:, :, 3:4]) * 0.5
        return Image.fromarray((img * 255.0).astype(np.uint8))

    if do_remove_background:
        image = image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        if image.mode == "RGBA":
            image = fill_background(image)

    return image


@app.get("/")
async def root():
    return {"message": "TripoSR API is running", "device": device}


@app.get("/health")
async def health():
    return {"status": "healthy", "device": device, "cuda_available": torch.cuda.is_available()}


@app.post("/generate")
async def generate_mesh(
    image: UploadFile = File(...),
    do_remove_background: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
    format: str = Form("obj"),
    bake_texture_flag: bool = Form(True),
    texture_resolution: int = Form(2048),
    orientation: str = Form("standard")
):
    """
    Generate a 3D mesh from an uploaded image with optional texture baking.

    Parameters:
    - image: Image file (PNG, JPG, JPEG)
    - do_remove_background: Whether to remove background (default: True)
    - foreground_ratio: Ratio of foreground size (default: 0.85)
    - mc_resolution: Marching cubes resolution (default: 256)
    - format: Output format - "obj" or "glb" (default: "obj")
    - bake_texture_flag: Whether to bake texture (default: True)
    - texture_resolution: Texture atlas resolution (default: 2048)
    - orientation: Mesh orientation - "standard" (Y-up, Z-forward), "gradio" (Gradio viewer), or "none" (original) (default: "standard")

    Returns:
    - If bake_texture=True: ZIP file containing mesh and texture.png
    - If bake_texture=False: Mesh file only
    """
    try:
        # Validate format
        if format not in ["obj", "glb"]:
            raise HTTPException(status_code=400, detail="Format must be 'obj' or 'glb'")

        # Read and validate image
        image_data = await image.read()
        input_image = Image.open(BytesIO(image_data))

        if input_image.mode not in ["RGB", "RGBA"]:
            input_image = input_image.convert("RGB")

        logging.info(f"Processing image: {image.filename}, size: {input_image.size}")

        # Preprocess image
        processed_image = preprocess_image(input_image, do_remove_background, foreground_ratio)

        # Generate mesh
        logging.info("Running model...")
        with torch.no_grad():
            scene_codes = model([processed_image], device=device)

        # Check if xatlas is available for texture baking
        xatlas_available = False
        if bake_texture_flag:
            try:
                import xatlas
                xatlas_available = True
                logging.info("xatlas found, texture baking enabled")
            except ImportError:
                logging.warning("xatlas not available - texture baking requires xatlas. Using vertex colors instead.")
                logging.warning("To enable texture baking, install: pip install xatlas==0.0.9 moderngl==5.10.0")
                bake_texture_flag = False

        # ALWAYS extract mesh with vertex colors (colors from the model trained on the image)
        # has_vertex_color=True extracts colors that the model learned from your input image
        # These colors represent the true colors from your image as interpreted by the trained model
        logging.info("Extracting mesh with vertex colors (true colors from image)...")
        meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=mc_resolution)
        mesh = meshes[0]

        # Apply orientation transformation for better 3D viewing
        # Options: "standard" (Y-up, Z-forward), "gradio" (Gradio viewer), "none" (original)
        if orientation not in ["standard", "gradio", "none"]:
            orientation = "standard"
        logging.info(f"Applying mesh orientation: {orientation}")
        mesh = apply_mesh_orientation(mesh, orientation=orientation)

        if bake_texture_flag and xatlas_available:
            # Bake texture
            logging.info("Baking texture...")
            bake_output = bake_texture(mesh, model, scene_codes[0], texture_resolution)

            # Save mesh with UV mapping
            mesh_temp = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
            xatlas.export(
                mesh_temp.name,
                mesh.vertices[bake_output["vmapping"]],
                bake_output["indices"],
                bake_output["uvs"],
                mesh.vertex_normals[bake_output["vmapping"]]
            )
            mesh_temp.close()

            # Save texture
            texture_temp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            Image.fromarray(
                (bake_output["colors"] * 255.0).astype(np.uint8)
            ).transpose(Image.FLIP_TOP_BOTTOM).save(texture_temp.name)
            texture_temp.close()

            # Create ZIP file with both mesh and texture
            zip_temp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
            with zipfile.ZipFile(zip_temp.name, 'w', zipfile.ZIP_DEFLATED) as zipf:
                zipf.write(mesh_temp.name, f"mesh.{format}")
                zipf.write(texture_temp.name, "texture.png")
            zip_temp.close()

            logging.info(f"Mesh and texture saved to: {zip_temp.name}")

            # Clean up individual files
            os.unlink(mesh_temp.name)
            os.unlink(texture_temp.name)

            return FileResponse(
                zip_temp.name,
                media_type="application/zip",
                filename="mesh_with_texture.zip",
                headers={"X-File-Path": zip_temp.name}
            )
        else:
            # Save mesh with vertex colors (colors from the model, which learned from your image)
            # The vertex colors represent the true colors from the input image as interpreted by the model
            temp_file = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
            try:
                mesh.export(temp_file.name)
            except AttributeError as e:
                if "ptp" in str(e) and format == "glb":
                    # Fallback to OBJ if GLB export fails due to NumPy 2.0 compatibility
                    logging.warning(f"GLB export failed due to NumPy compatibility, falling back to OBJ format")
                    temp_file.close()
                    os.unlink(temp_file.name)
                    temp_file = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
                    format = "obj"
                    mesh.export(temp_file.name)
                else:
                    raise
            temp_file.close()

            logging.info(f"Mesh with vertex colors (from image) saved to: {temp_file.name}")

            return FileResponse(
                temp_file.name,
                media_type="application/octet-stream",
                filename=f"mesh.{format}",
                headers={"X-File-Path": temp_file.name}
            )

    except Exception as e:
        logging.error(f"Error generating mesh: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating mesh: {str(e)}")


@app.post("/generate-base64")
async def generate_mesh_base64(
    image: UploadFile = File(...),
    do_remove_background: bool = Form(True),
    foreground_ratio: float = Form(0.85),
    mc_resolution: int = Form(256),
    format: str = Form("obj"),
    bake_texture_flag: bool = Form(True),
    texture_resolution: int = Form(2048),
    orientation: str = Form("standard")
):
    """
    Generate a 3D mesh with texture and return as base64 encoded strings.
    Useful for frontend that wants to handle the mesh data directly.

    Returns JSON with mesh and texture (if baked) as base64 strings.
    """
    try:
        if format not in ["obj", "glb"]:
            raise HTTPException(status_code=400, detail="Format must be 'obj' or 'glb'")

        # Read and validate image
        image_data = await image.read()
        input_image = Image.open(BytesIO(image_data))

        if input_image.mode not in ["RGB", "RGBA"]:
            input_image = input_image.convert("RGB")

        logging.info(f"Processing image: {image.filename}")

        # Preprocess image
        processed_image = preprocess_image(input_image, do_remove_background, foreground_ratio)

        # Generate mesh
        logging.info("Running model...")
        with torch.no_grad():
            scene_codes = model([processed_image], device=device)

        # Check if xatlas is available for texture baking
        xatlas_available = False
        if bake_texture_flag:
            try:
                import xatlas
                xatlas_available = True
                logging.info("xatlas found, texture baking enabled")
            except ImportError:
                logging.warning("xatlas not available - texture baking requires xatlas. Using vertex colors instead.")
                bake_texture_flag = False

        # ALWAYS extract mesh with vertex colors (colors from the model trained on the image)
        # has_vertex_color=True extracts colors that the model learned from your input image
        logging.info("Extracting mesh with vertex colors (true colors from image)...")
        meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=mc_resolution)
        mesh = meshes[0]

        # Apply orientation transformation for better 3D viewing
        # Options: "standard" (Y-up, Z-forward), "gradio" (Gradio viewer), "none" (original)
        if orientation not in ["standard", "gradio", "none"]:
            orientation = "standard"
        logging.info(f"Applying mesh orientation: {orientation}")
        mesh = apply_mesh_orientation(mesh, orientation=orientation)

        if bake_texture_flag and xatlas_available:
            # Bake texture (creates texture atlas from model colors)
            logging.info("Baking texture...")
            bake_output = bake_texture(mesh, model, scene_codes[0], texture_resolution)

            # Save mesh with UV mapping
            mesh_temp = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
            xatlas.export(
                mesh_temp.name,
                mesh.vertices[bake_output["vmapping"]],
                bake_output["indices"],
                bake_output["uvs"],
                mesh.vertex_normals[bake_output["vmapping"]]
            )

            # Save texture
            texture_temp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            Image.fromarray(
                (bake_output["colors"] * 255.0).astype(np.uint8)
            ).transpose(Image.FLIP_TOP_BOTTOM).save(texture_temp.name)

            # Read and encode to base64
            with open(mesh_temp.name, "rb") as f:
                mesh_data = f.read()
            with open(texture_temp.name, "rb") as f:
                texture_data = f.read()

            mesh_base64 = base64.b64encode(mesh_data).decode("utf-8")
            texture_base64 = base64.b64encode(texture_data).decode("utf-8")

            # Clean up
            os.unlink(mesh_temp.name)
            os.unlink(texture_temp.name)

            return JSONResponse({
                "success": True,
                "format": format,
                "mesh": mesh_base64,
                "texture": texture_base64,
                "mesh_size": len(mesh_data),
                "texture_size": len(texture_data),
                "has_texture": True
            })
        else:
            # Save mesh without texture
            temp_file = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
            try:
                mesh.export(temp_file.name)
            except AttributeError as e:
                if "ptp" in str(e) and format == "glb":
                    # Fallback to OBJ if GLB export fails due to NumPy 2.0 compatibility
                    logging.warning(f"GLB export failed due to NumPy compatibility, falling back to OBJ format")
                    temp_file.close()
                    os.unlink(temp_file.name)
                    temp_file = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
                    format = "obj"
                    mesh.export(temp_file.name)
                else:
                    raise

            # Read file and encode to base64
            with open(temp_file.name, "rb") as f:
                mesh_data = f.read()

            mesh_base64 = base64.b64encode(mesh_data).decode("utf-8")

            # Clean up
            os.unlink(temp_file.name)

            return JSONResponse({
                "success": True,
                "format": format,
                "mesh": mesh_base64,
                "mesh_size": len(mesh_data),
                "has_texture": False
            })

    except Exception as e:
        logging.error(f"Error generating mesh: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating mesh: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    # Use 127.0.0.1 for localhost access, or 0.0.0.0 for network access
    # For local development, use 127.0.0.1
    uvicorn.run(app, host="127.0.0.1", port=8000, reload=False)
app.py
ADDED
@@ -0,0 +1,171 @@
import logging
import os
import tempfile

import gradio as gr
import numpy as np
import rembg
import torch
from PIL import Image
from functools import partial

from tsr.system import TSR
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation


if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)

# adjust the chunk size to balance between speed and memory usage
model.renderer.set_chunk_size(8192)
model.to(device)

rembg_session = rembg.new_session()


def check_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image uploaded!")


def preprocess(input_image, do_remove_background, foreground_ratio):
    def fill_background(image):
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        return image

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)
    return image


def generate(image, mc_resolution, formats=["obj", "glb"]):
    scene_codes = model(image, device=device)
    mesh = model.extract_mesh(scene_codes, True, resolution=mc_resolution)[0]
    mesh = to_gradio_3d_orientation(mesh)
    rv = []
    for format in formats:
        mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
        mesh.export(mesh_path.name)
        rv.append(mesh_path.name)
    return rv


def run_example(image_pil):
    preprocessed = preprocess(image_pil, False, 0.9)
    mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
    return preprocessed, mesh_name_obj, mesh_name_glb


with gr.Blocks(title="TripoSR") as interface:
    gr.Markdown(
        """
    # TripoSR Demo
    [TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).

    **Tips:**
    1. If the result is unsatisfactory, try changing the foreground ratio; it can improve the output.
    2. It's better to disable "Remove Background" for the provided examples (except for the last one) since they have already been preprocessed.
    3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background and its content is centered and occupies more than 70% of the image width or height.
    """
    )
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
                    mc_resolution = gr.Slider(
                        label="Marching Cubes Resolution",
                        minimum=32,
                        maximum=320,
                        value=256,
                        step=32
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("OBJ"):
                output_model_obj = gr.Model3D(
                    label="Output Model (OBJ Format)",
                    interactive=False,
                )
                gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
            with gr.Tab("GLB"):
                output_model_glb = gr.Model3D(
                    label="Output Model (GLB Format)",
                    interactive=False,
                )
                gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
    with gr.Row(variant="panel"):
        gr.Examples(
            examples=[
                "examples/hamburger.png",
                "examples/poly_fox.png",
                "examples/robot.png",
                "examples/teapot.png",
                "examples/tiger_girl.png",
                "examples/horse.png",
                "examples/flamingo.png",
                "examples/unicorn.png",
                "examples/chair.png",
                "examples/iso_house.png",
                "examples/marble.png",
                "examples/police_woman.png",
                "examples/captured.jpeg",
            ],
            inputs=[input_image],
            outputs=[processed_image, output_model_obj, output_model_glb],
            cache_examples=False,
            fn=partial(run_example),
            label="Examples",
            examples_per_page=20,
        )
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    ).success(
        fn=generate,
        inputs=[processed_image, mc_resolution],
        outputs=[output_model_obj, output_model_glb],
    )


# For Hugging Face Spaces, we just need to assign the interface
# The launch() is handled automatically by Spaces
app = interface
deploy_colab.ipynb
ADDED
@@ -0,0 +1,268 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# 🎨 TripoSR API on Google Colab\n",
        "\n",
        "This notebook deploys the TripoSR API on Google Colab with a public URL using ngrok.\n",
        "\n",
        "## 📋 Setup Instructions\n",
        "\n",
        "1. **Enable GPU**: Runtime → Change runtime type → GPU (T4)\n",
        "2. **Get ngrok token**: Sign up at [ngrok.com](https://ngrok.com) and get your authtoken\n",
        "3. **Run all cells** in order\n",
        "4. **Copy the public URL** from the output\n",
        "\n",
        "⚠️ **Note**: The session will expire after 12 hours of inactivity or 24 hours maximum."
      ],
      "metadata": {
        "id": "header"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Check GPU availability\n",
        "!nvidia-smi"
      ],
      "metadata": {
        "id": "check_gpu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Clone the repository\n",
        "!git clone https://github.com/Ahmedbelaid/TripoSR-api.git\n",
        "%cd TripoSR-api"
      ],
      "metadata": {
        "id": "clone_repo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Install dependencies\n",
        "!pip install -q torch torchvision\n",
        "!pip install -q -r requirements.txt"
      ],
      "metadata": {
        "id": "install_deps"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Install ngrok for public URL\n",
        "!pip install -q pyngrok\n",
        "\n",
        "from pyngrok import ngrok, conf\n",
        "import getpass\n",
        "\n",
        "# Get ngrok authtoken\n",
        "print(\"Get your authtoken from: https://dashboard.ngrok.com/get-started/your-authtoken\")\n",
        "authtoken = getpass.getpass(\"Enter your ngrok authtoken: \")\n",
        "\n",
        "# Set ngrok authtoken\n",
        "conf.get_default().auth_token = authtoken"
      ],
      "metadata": {
        "id": "setup_ngrok"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Start the API server in the background\n",
        "import subprocess\n",
        "import time\n",
        "\n",
        "# Start server\n",
        "process = subprocess.Popen(\n",
        "    [\"python\", \"-m\", \"uvicorn\", \"api_server:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"],\n",
        "    stdout=subprocess.PIPE,\n",
        "    stderr=subprocess.PIPE\n",
        ")\n",
        "\n",
        "# Wait for server to start\n",
        "print(\"Starting server...\")\n",
        "time.sleep(10)\n",
        "\n",
        "# Create ngrok tunnel\n",
        "public_url = ngrok.connect(8000)\n",
        "\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"🚀 TripoSR API is now running!\")\n",
        "print(\"=\"*60)\n",
        "print(f\"\\n📡 Public URL: {public_url}\")\n",
        "print(f\"\\n🔍 Health Check: {public_url}/health\")\n",
        "print(f\"\\n📝 API Docs: {public_url}/docs\")\n",
        "print(\"\\n\" + \"=\"*60)\n",
        "print(\"\\n⚠️ Keep this notebook running to keep the API active\")\n",
        "print(\"⚠️ Session will expire after 12 hours of inactivity\\n\")"
      ],
      "metadata": {
        "id": "start_server"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Test the API\n",
        "import requests\n",
        "\n",
        "# Health check\n",
        "response = requests.get(f\"{public_url}/health\")\n",
        "print(\"Health Check Response:\")\n",
        "print(response.json())"
      ],
      "metadata": {
        "id": "test_api"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
|
| 155 |
+
"source": [
|
| 156 |
+
"## 🧪 Example: Generate 3D Model\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"Upload an image and generate a 3D model:"
|
| 159 |
+
],
|
| 160 |
+
"metadata": {
|
| 161 |
+
"id": "example_header"
|
| 162 |
+
}
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"source": [
|
| 167 |
+
"from google.colab import files\n",
|
| 168 |
+
"import requests\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"# Upload an image\n",
|
| 171 |
+
"print(\"Upload an image file:\")\n",
|
| 172 |
+
"uploaded = files.upload()\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"# Get the filename\n",
|
| 175 |
+
"filename = list(uploaded.keys())[0]\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"# Generate 3D model\n",
|
| 178 |
+
"print(f\"\\nGenerating 3D model from {filename}...\")\n",
|
| 179 |
+
"with open(filename, 'rb') as f:\n",
|
| 180 |
+
" files_dict = {'image': f}\n",
|
| 181 |
+
" data = {\n",
|
| 182 |
+
" 'format': 'obj',\n",
|
| 183 |
+
" 'bake_texture_flag': True,\n",
|
| 184 |
+
" 'mc_resolution': 256\n",
|
| 185 |
+
" }\n",
|
| 186 |
+
" response = requests.post(f\"{public_url}/generate\", files=files_dict, data=data)\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"if response.status_code == 200:\n",
|
| 189 |
+
" # Save the output\n",
|
| 190 |
+
" output_filename = 'output.zip'\n",
|
| 191 |
+
" with open(output_filename, 'wb') as f:\n",
|
| 192 |
+
" f.write(response.content)\n",
|
| 193 |
+
" print(f\"\\n✅ Success! 3D model saved to {output_filename}\")\n",
|
| 194 |
+
" \n",
|
| 195 |
+
" # Download the file\n",
|
| 196 |
+
" files.download(output_filename)\n",
|
| 197 |
+
"else:\n",
|
| 198 |
+
" print(f\"\\n❌ Error: {response.status_code}\")\n",
|
| 199 |
+
" print(response.text)"
|
| 200 |
+
],
|
| 201 |
+
"metadata": {
|
| 202 |
+
"id": "generate_example"
|
| 203 |
+
},
|
| 204 |
+
"execution_count": null,
|
| 205 |
+
"outputs": []
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"cell_type": "markdown",
|
| 209 |
+
"source": [
|
| 210 |
+
"## 📊 Monitor Server Logs\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"Run this cell to see server logs:"
|
| 213 |
+
],
|
| 214 |
+
"metadata": {
|
| 215 |
+
"id": "logs_header"
|
| 216 |
+
}
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"cell_type": "code",
|
| 220 |
+
"source": [
|
| 221 |
+
"# View server logs\n",
|
| 222 |
+
"import time\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"print(\"Server logs (press Stop to exit):\")\n",
|
| 225 |
+
"print(\"=\"*60)\n",
|
| 226 |
+
"\n",
|
| 227 |
+
"while True:\n",
|
| 228 |
+
" output = process.stdout.readline()\n",
|
| 229 |
+
" if output:\n",
|
| 230 |
+
" print(output.decode().strip())\n",
|
| 231 |
+
" error = process.stderr.readline()\n",
|
| 232 |
+
" if error:\n",
|
| 233 |
+
" print(error.decode().strip())\n",
|
| 234 |
+
" time.sleep(0.1)"
|
| 235 |
+
],
|
| 236 |
+
"metadata": {
|
| 237 |
+
"id": "view_logs"
|
| 238 |
+
},
|
| 239 |
+
"execution_count": null,
|
| 240 |
+
"outputs": []
|
| 241 |
+
},
|
| 242 |
+
{
|
| 243 |
+
"cell_type": "markdown",
|
| 244 |
+
"source": [
|
| 245 |
+
"## 🛑 Stop Server\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"Run this cell to stop the server:"
|
| 248 |
+
],
|
| 249 |
+
"metadata": {
|
| 250 |
+
"id": "stop_header"
|
| 251 |
+
}
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "code",
|
| 255 |
+
"source": [
|
| 256 |
+
"# Stop the server\n",
|
| 257 |
+
"process.terminate()\n",
|
| 258 |
+
"ngrok.disconnect(public_url)\n",
|
| 259 |
+
"print(\"Server stopped.\")"
|
| 260 |
+
],
|
| 261 |
+
"metadata": {
|
| 262 |
+
"id": "stop_server"
|
| 263 |
+
},
|
| 264 |
+
"execution_count": null,
|
| 265 |
+
"outputs": []
|
| 266 |
+
}
|
| 267 |
+
]
|
| 268 |
+
}
|
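Once the notebook prints the ngrok URL, the tunneled FastAPI server can be called from any other machine. A minimal client sketch, assuming the placeholder URL is replaced with the one printed by the notebook and that a local chair.png exists; the form fields mirror the notebook's own /generate cell:

import requests

# Placeholder: paste the ngrok URL printed by the notebook here.
PUBLIC_URL = "https://<your-ngrok-subdomain>.ngrok-free.app"

# Quick liveness check against the tunneled FastAPI server.
print(requests.get(f"{PUBLIC_URL}/health", timeout=30).json())

# Submit an image and save the returned archive, mirroring the notebook cell.
with open("chair.png", "rb") as f:
    resp = requests.post(
        f"{PUBLIC_URL}/generate",
        files={"image": f},
        data={"format": "obj", "bake_texture_flag": True, "mc_resolution": 256},
        timeout=600,  # mesh generation can take a while on a T4
    )
resp.raise_for_status()
with open("output.zip", "wb") as out:
    out.write(resp.content)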
docker-compose.yml
ADDED
|
@@ -0,0 +1,31 @@
|
version: '3.8'

services:
  triposr-api:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: triposr-api
    ports:
      - "8000:8000"
    environment:
      - CUDA_VISIBLE_DEVICES=0  # Use first GPU, set to empty string to use CPU
    volumes:
      # Optional: Mount a volume for output files if you want to persist them
      - ./output:/app/output
    # Uncomment the following lines if you have NVIDIA GPU and want to use it
    # deploy:
    #   resources:
    #     reservations:
    #       devices:
    #         - driver: nvidia
    #           count: 1
    #           capabilities: [gpu]
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 60s
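Because the compose file maps port 8000 and declares a /health healthcheck with a 60 s start_period, a client should wait for the model to finish loading before sending work. A small polling sketch, assuming `docker compose up -d` has already been run on the same host:

import time
import requests

# Poll the healthcheck endpoint declared in docker-compose.yml.
# The service may need up to ~60 s (start_period) to load the model on first boot.
for attempt in range(30):
    try:
        r = requests.get("http://localhost:8000/health", timeout=5)
        if r.ok:
            print("API ready:", r.json())
            break
    except requests.ConnectionError:
        pass
    time.sleep(10)
else:
    raise RuntimeError("TripoSR API did not become healthy in time")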
gradio_app.py
ADDED
|
@@ -0,0 +1,187 @@
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import rembg
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from functools import partial
|
| 12 |
+
|
| 13 |
+
from tsr.system import TSR
|
| 14 |
+
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
if torch.cuda.is_available():
|
| 20 |
+
device = "cuda:0"
|
| 21 |
+
else:
|
| 22 |
+
device = "cpu"
|
| 23 |
+
|
| 24 |
+
model = TSR.from_pretrained(
|
| 25 |
+
"stabilityai/TripoSR",
|
| 26 |
+
config_name="config.yaml",
|
| 27 |
+
weight_name="model.ckpt",
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# adjust the chunk size to balance between speed and memory usage
|
| 31 |
+
model.renderer.set_chunk_size(8192)
|
| 32 |
+
model.to(device)
|
| 33 |
+
|
| 34 |
+
rembg_session = rembg.new_session()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def check_input_image(input_image):
|
| 38 |
+
if input_image is None:
|
| 39 |
+
raise gr.Error("No image uploaded!")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def preprocess(input_image, do_remove_background, foreground_ratio):
|
| 43 |
+
def fill_background(image):
|
| 44 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 45 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
| 46 |
+
image = Image.fromarray((image * 255.0).astype(np.uint8))
|
| 47 |
+
return image
|
| 48 |
+
|
| 49 |
+
if do_remove_background:
|
| 50 |
+
image = input_image.convert("RGB")
|
| 51 |
+
image = remove_background(image, rembg_session)
|
| 52 |
+
image = resize_foreground(image, foreground_ratio)
|
| 53 |
+
image = fill_background(image)
|
| 54 |
+
else:
|
| 55 |
+
image = input_image
|
| 56 |
+
if image.mode == "RGBA":
|
| 57 |
+
image = fill_background(image)
|
| 58 |
+
return image
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def generate(image, mc_resolution, formats=["obj", "glb"]):
|
| 62 |
+
scene_codes = model(image, device=device)
|
| 63 |
+
mesh = model.extract_mesh(scene_codes, True, resolution=mc_resolution)[0]
|
| 64 |
+
mesh = to_gradio_3d_orientation(mesh)
|
| 65 |
+
rv = []
|
| 66 |
+
for format in formats:
|
| 67 |
+
mesh_path = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
|
| 68 |
+
mesh.export(mesh_path.name)
|
| 69 |
+
rv.append(mesh_path.name)
|
| 70 |
+
return rv
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def run_example(image_pil):
|
| 74 |
+
preprocessed = preprocess(image_pil, False, 0.9)
|
| 75 |
+
mesh_name_obj, mesh_name_glb = generate(preprocessed, 256, ["obj", "glb"])
|
| 76 |
+
return preprocessed, mesh_name_obj, mesh_name_glb
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
with gr.Blocks(title="TripoSR") as interface:
|
| 80 |
+
gr.Markdown(
|
| 81 |
+
"""
|
| 82 |
+
# TripoSR Demo
|
| 83 |
+
[TripoSR](https://github.com/VAST-AI-Research/TripoSR) is a state-of-the-art open-source model for **fast** feedforward 3D reconstruction from a single image, collaboratively developed by [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/).
|
| 84 |
+
|
| 85 |
+
**Tips:**
|
| 86 |
+
1. If you find the result unsatisfactory, try changing the foreground ratio; it may improve the results.
|
| 87 |
+
2. It's better to disable "Remove Background" for the provided examples (except for the last one) since they have already been preprocessed.
|
| 88 |
+
3. Otherwise, disable the "Remove Background" option only if your input image is RGBA with a transparent background, and the image content is centered and occupies more than 70% of the image width or height.
|
| 89 |
+
"""
|
| 90 |
+
)
|
| 91 |
+
with gr.Row(variant="panel"):
|
| 92 |
+
with gr.Column():
|
| 93 |
+
with gr.Row():
|
| 94 |
+
input_image = gr.Image(
|
| 95 |
+
label="Input Image",
|
| 96 |
+
image_mode="RGBA",
|
| 97 |
+
sources="upload",
|
| 98 |
+
type="pil",
|
| 99 |
+
elem_id="content_image",
|
| 100 |
+
)
|
| 101 |
+
processed_image = gr.Image(label="Processed Image", interactive=False)
|
| 102 |
+
with gr.Row():
|
| 103 |
+
with gr.Group():
|
| 104 |
+
do_remove_background = gr.Checkbox(
|
| 105 |
+
label="Remove Background", value=True
|
| 106 |
+
)
|
| 107 |
+
foreground_ratio = gr.Slider(
|
| 108 |
+
label="Foreground Ratio",
|
| 109 |
+
minimum=0.5,
|
| 110 |
+
maximum=1.0,
|
| 111 |
+
value=0.85,
|
| 112 |
+
step=0.05,
|
| 113 |
+
)
|
| 114 |
+
mc_resolution = gr.Slider(
|
| 115 |
+
label="Marching Cubes Resolution",
|
| 116 |
+
minimum=32,
|
| 117 |
+
maximum=320,
|
| 118 |
+
value=256,
|
| 119 |
+
step=32
|
| 120 |
+
)
|
| 121 |
+
with gr.Row():
|
| 122 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
| 123 |
+
with gr.Column():
|
| 124 |
+
with gr.Tab("OBJ"):
|
| 125 |
+
output_model_obj = gr.Model3D(
|
| 126 |
+
label="Output Model (OBJ Format)",
|
| 127 |
+
interactive=False,
|
| 128 |
+
)
|
| 129 |
+
gr.Markdown("Note: The model shown here is flipped. Download to get correct results.")
|
| 130 |
+
with gr.Tab("GLB"):
|
| 131 |
+
output_model_glb = gr.Model3D(
|
| 132 |
+
label="Output Model (GLB Format)",
|
| 133 |
+
interactive=False,
|
| 134 |
+
)
|
| 135 |
+
gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
|
| 136 |
+
with gr.Row(variant="panel"):
|
| 137 |
+
gr.Examples(
|
| 138 |
+
examples=[
|
| 139 |
+
"examples/hamburger.png",
|
| 140 |
+
"examples/poly_fox.png",
|
| 141 |
+
"examples/robot.png",
|
| 142 |
+
"examples/teapot.png",
|
| 143 |
+
"examples/tiger_girl.png",
|
| 144 |
+
"examples/horse.png",
|
| 145 |
+
"examples/flamingo.png",
|
| 146 |
+
"examples/unicorn.png",
|
| 147 |
+
"examples/chair.png",
|
| 148 |
+
"examples/iso_house.png",
|
| 149 |
+
"examples/marble.png",
|
| 150 |
+
"examples/police_woman.png",
|
| 151 |
+
"examples/captured.jpeg",
|
| 152 |
+
],
|
| 153 |
+
inputs=[input_image],
|
| 154 |
+
outputs=[processed_image, output_model_obj, output_model_glb],
|
| 155 |
+
cache_examples=False,
|
| 156 |
+
fn=partial(run_example),
|
| 157 |
+
label="Examples",
|
| 158 |
+
examples_per_page=20,
|
| 159 |
+
)
|
| 160 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
| 161 |
+
fn=preprocess,
|
| 162 |
+
inputs=[input_image, do_remove_background, foreground_ratio],
|
| 163 |
+
outputs=[processed_image],
|
| 164 |
+
).success(
|
| 165 |
+
fn=generate,
|
| 166 |
+
inputs=[processed_image, mc_resolution],
|
| 167 |
+
outputs=[output_model_obj, output_model_glb],
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
if __name__ == '__main__':
|
| 173 |
+
parser = argparse.ArgumentParser()
|
| 174 |
+
parser.add_argument('--username', type=str, default=None, help='Username for authentication')
|
| 175 |
+
parser.add_argument('--password', type=str, default=None, help='Password for authentication')
|
| 176 |
+
parser.add_argument('--port', type=int, default=7860, help='Port to run the server listener on')
|
| 177 |
+
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
| 178 |
+
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
|
| 179 |
+
parser.add_argument("--queuesize", type=int, default=1, help="launch gradio queue max_size")
|
| 180 |
+
args = parser.parse_args()
|
| 181 |
+
interface.queue(max_size=args.queuesize)
|
| 182 |
+
interface.launch(
|
| 183 |
+
auth=(args.username, args.password) if (args.username and args.password) else None,
|
| 184 |
+
share=args.share,
|
| 185 |
+
server_name="0.0.0.0" if args.listen else None,
|
| 186 |
+
server_port=args.port
|
| 187 |
+
)
|
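The preprocess/generate helpers above can also be driven without the Gradio UI. A hedged sketch of a headless driver script; note that importing gradio_app loads the TSR weights at import time, and the example image path is one of the bundled examples:

# Hypothetical headless driver for the helpers defined in gradio_app.py.
from PIL import Image
from gradio_app import preprocess, generate  # importing this loads the model

img = Image.open("examples/chair.png")
processed = preprocess(img, do_remove_background=True, foreground_ratio=0.85)
obj_path, glb_path = generate(processed, mc_resolution=256, formats=["obj", "glb"])
print("Wrote", obj_path, "and", glb_path)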
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
torch>=2.0.0
omegaconf==2.3.0
Pillow==10.1.0
einops==0.7.0
git+https://github.com/tatsy/torchmcubes.git
transformers==4.35.0
trimesh==4.0.5
rembg
huggingface-hub
imageio[ffmpeg]
gradio==3.50.2
xatlas==0.0.9
moderngl==5.10.0
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
python-multipart
run.py
ADDED
|
@@ -0,0 +1,197 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import rembg
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from tsr.system import TSR
|
| 12 |
+
from tsr.utils import remove_background, resize_foreground, save_video
|
| 13 |
+
from tsr.bake_texture import bake_texture
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Timer:
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.items = {}
|
| 19 |
+
self.time_scale = 1000.0 # ms
|
| 20 |
+
self.time_unit = "ms"
|
| 21 |
+
|
| 22 |
+
def start(self, name: str) -> None:
|
| 23 |
+
if torch.cuda.is_available():
|
| 24 |
+
torch.cuda.synchronize()
|
| 25 |
+
self.items[name] = time.time()
|
| 26 |
+
logging.info(f"{name} ...")
|
| 27 |
+
|
| 28 |
+
def end(self, name: str) -> float:
|
| 29 |
+
if name not in self.items:
|
| 30 |
+
return
|
| 31 |
+
if torch.cuda.is_available():
|
| 32 |
+
torch.cuda.synchronize()
|
| 33 |
+
start_time = self.items.pop(name)
|
| 34 |
+
delta = time.time() - start_time
|
| 35 |
+
t = delta * self.time_scale
|
| 36 |
+
logging.info(f"{name} finished in {t:.2f}{self.time_unit}.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
timer = Timer()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logging.basicConfig(
|
| 43 |
+
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
|
| 44 |
+
)
|
| 45 |
+
parser = argparse.ArgumentParser()
|
| 46 |
+
parser.add_argument("image", type=str, nargs="+", help="Path to input image(s).")
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--device",
|
| 49 |
+
default="cuda:0",
|
| 50 |
+
type=str,
|
| 51 |
+
help="Device to use. If no CUDA-compatible device is found, will fallback to 'cpu'. Default: 'cuda:0'",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--pretrained-model-name-or-path",
|
| 55 |
+
default="stabilityai/TripoSR",
|
| 56 |
+
type=str,
|
| 57 |
+
help="Path to the pretrained model. Could be either a huggingface model id is or a local path. Default: 'stabilityai/TripoSR'",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--chunk-size",
|
| 61 |
+
default=8192,
|
| 62 |
+
type=int,
|
| 63 |
+
help="Evaluation chunk size for surface extraction and rendering. Smaller chunk size reduces VRAM usage but increases computation time. 0 for no chunking. Default: 8192",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--mc-resolution",
|
| 67 |
+
default=256,
|
| 68 |
+
type=int,
|
| 69 |
+
help="Marching cubes grid resolution. Default: 256"
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--no-remove-bg",
|
| 73 |
+
action="store_true",
|
| 74 |
+
help="If specified, the background will NOT be automatically removed from the input image, and the input image should be an RGB image with gray background and properly-sized foreground. Default: false",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--foreground-ratio",
|
| 78 |
+
default=0.85,
|
| 79 |
+
type=float,
|
| 80 |
+
help="Ratio of the foreground size to the image size. Only used when --no-remove-bg is not specified. Default: 0.85",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--output-dir",
|
| 84 |
+
default="output/",
|
| 85 |
+
type=str,
|
| 86 |
+
help="Output directory to save the results. Default: 'output/'",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--model-save-format",
|
| 90 |
+
default="obj",
|
| 91 |
+
type=str,
|
| 92 |
+
choices=["obj", "glb"],
|
| 93 |
+
help="Format to save the extracted mesh. Default: 'obj'",
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--bake-texture",
|
| 97 |
+
action="store_true",
|
| 98 |
+
help="Bake a texture atlas for the extracted mesh, instead of vertex colors",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--texture-resolution",
|
| 102 |
+
default=2048,
|
| 103 |
+
type=int,
|
| 104 |
+
help="Texture atlas resolution, only useful with --bake-texture. Default: 2048"
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--render",
|
| 108 |
+
action="store_true",
|
| 109 |
+
help="If specified, save a NeRF-rendered video. Default: false",
|
| 110 |
+
)
|
| 111 |
+
args = parser.parse_args()
|
| 112 |
+
|
| 113 |
+
output_dir = args.output_dir
|
| 114 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 115 |
+
|
| 116 |
+
device = args.device
|
| 117 |
+
if not torch.cuda.is_available():
|
| 118 |
+
device = "cpu"
|
| 119 |
+
|
| 120 |
+
timer.start("Initializing model")
|
| 121 |
+
model = TSR.from_pretrained(
|
| 122 |
+
args.pretrained_model_name_or_path,
|
| 123 |
+
config_name="config.yaml",
|
| 124 |
+
weight_name="model.ckpt",
|
| 125 |
+
)
|
| 126 |
+
model.renderer.set_chunk_size(args.chunk_size)
|
| 127 |
+
model.to(device)
|
| 128 |
+
timer.end("Initializing model")
|
| 129 |
+
|
| 130 |
+
timer.start("Processing images")
|
| 131 |
+
images = []
|
| 132 |
+
|
| 133 |
+
if args.no_remove_bg:
|
| 134 |
+
rembg_session = None
|
| 135 |
+
else:
|
| 136 |
+
rembg_session = rembg.new_session()
|
| 137 |
+
|
| 138 |
+
for i, image_path in enumerate(args.image):
|
| 139 |
+
if args.no_remove_bg:
|
| 140 |
+
image = np.array(Image.open(image_path).convert("RGB"))
|
| 141 |
+
else:
|
| 142 |
+
image = remove_background(Image.open(image_path), rembg_session)
|
| 143 |
+
image = resize_foreground(image, args.foreground_ratio)
|
| 144 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 145 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
| 146 |
+
image = Image.fromarray((image * 255.0).astype(np.uint8))
|
| 147 |
+
if not os.path.exists(os.path.join(output_dir, str(i))):
|
| 148 |
+
os.makedirs(os.path.join(output_dir, str(i)))
|
| 149 |
+
image.save(os.path.join(output_dir, str(i), f"input.png"))
|
| 150 |
+
images.append(image)
|
| 151 |
+
timer.end("Processing images")
|
| 152 |
+
|
| 153 |
+
for i, image in enumerate(images):
|
| 154 |
+
logging.info(f"Running image {i + 1}/{len(images)} ...")
|
| 155 |
+
|
| 156 |
+
timer.start("Running model")
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
scene_codes = model([image], device=device)
|
| 159 |
+
timer.end("Running model")
|
| 160 |
+
|
| 161 |
+
if args.render:
|
| 162 |
+
timer.start("Rendering")
|
| 163 |
+
render_images = model.render(scene_codes, n_views=30, return_type="pil")
|
| 164 |
+
for ri, render_image in enumerate(render_images[0]):
|
| 165 |
+
render_image.save(os.path.join(output_dir, str(i), f"render_{ri:03d}.png"))
|
| 166 |
+
save_video(
|
| 167 |
+
render_images[0], os.path.join(output_dir, str(i), f"render.mp4"), fps=30
|
| 168 |
+
)
|
| 169 |
+
timer.end("Rendering")
|
| 170 |
+
|
| 171 |
+
timer.start("Extracting mesh")
|
| 172 |
+
meshes = model.extract_mesh(scene_codes, not args.bake_texture, resolution=args.mc_resolution)
|
| 173 |
+
timer.end("Extracting mesh")
|
| 174 |
+
|
| 175 |
+
out_mesh_path = os.path.join(output_dir, str(i), f"mesh.{args.model_save_format}")
|
| 176 |
+
if args.bake_texture:
|
| 177 |
+
try:
|
| 178 |
+
import xatlas
|
| 179 |
+
except ImportError:
|
| 180 |
+
raise ImportError(
|
| 181 |
+
"xatlas is required for texture baking. Please install it with: pip install xatlas==0.0.9\n"
|
| 182 |
+
"Note: This requires Microsoft Visual C++ Build Tools to compile."
|
| 183 |
+
)
|
| 184 |
+
out_texture_path = os.path.join(output_dir, str(i), "texture.png")
|
| 185 |
+
|
| 186 |
+
timer.start("Baking texture")
|
| 187 |
+
bake_output = bake_texture(meshes[0], model, scene_codes[0], args.texture_resolution)
|
| 188 |
+
timer.end("Baking texture")
|
| 189 |
+
|
| 190 |
+
timer.start("Exporting mesh and texture")
|
| 191 |
+
xatlas.export(out_mesh_path, meshes[0].vertices[bake_output["vmapping"]], bake_output["indices"], bake_output["uvs"], meshes[0].vertex_normals[bake_output["vmapping"]])
|
| 192 |
+
Image.fromarray((bake_output["colors"] * 255.0).astype(np.uint8)).transpose(Image.FLIP_TOP_BOTTOM).save(out_texture_path)
|
| 193 |
+
timer.end("Exporting mesh and texture")
|
| 194 |
+
else:
|
| 195 |
+
timer.start("Exporting mesh")
|
| 196 |
+
meshes[0].export(out_mesh_path)
|
| 197 |
+
timer.end("Exporting mesh")
|
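run.py is a plain CLI, so it can also be invoked from another script. A sketch using subprocess with flags taken from the argparse definitions above; `product.png` is a placeholder input path:

import subprocess

# Drive the run.py CLI programmatically; flags mirror its argparse options.
subprocess.run(
    [
        "python", "run.py", "product.png",
        "--output-dir", "output/",
        "--mc-resolution", "256",
        "--model-save-format", "glb",
        "--render",  # also writes render_*.png frames and render.mp4
    ],
    check=True,
)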
tsr/__pycache__/bake_texture.cpython-313.pyc
ADDED
|
Binary file (7.68 kB). View file
|
|
|
tsr/__pycache__/system.cpython-313.pyc
ADDED
|
Binary file (9.9 kB). View file
|
|
|
tsr/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (24 kB). View file
|
|
|
tsr/bake_texture.py
ADDED
|
@@ -0,0 +1,191 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import trimesh
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import xatlas
|
| 8 |
+
import moderngl
|
| 9 |
+
_HAS_XATLAS = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
_HAS_XATLAS = False
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def make_atlas(mesh, texture_resolution, texture_padding):
|
| 15 |
+
if not _HAS_XATLAS:
|
| 16 |
+
raise ImportError(
|
| 17 |
+
"xatlas is required for texture baking. Please install it with: pip install xatlas==0.0.9\n"
|
| 18 |
+
"Note: This requires Microsoft Visual C++ Build Tools to compile."
|
| 19 |
+
)
|
| 20 |
+
atlas = xatlas.Atlas()
|
| 21 |
+
atlas.add_mesh(mesh.vertices, mesh.faces)
|
| 22 |
+
options = xatlas.PackOptions()
|
| 23 |
+
options.resolution = texture_resolution
|
| 24 |
+
options.padding = texture_padding
|
| 25 |
+
options.bilinear = True
|
| 26 |
+
atlas.generate(pack_options=options)
|
| 27 |
+
vmapping, indices, uvs = atlas[0]
|
| 28 |
+
return {
|
| 29 |
+
"vmapping": vmapping,
|
| 30 |
+
"indices": indices,
|
| 31 |
+
"uvs": uvs,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def rasterize_position_atlas(
|
| 36 |
+
mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
|
| 37 |
+
):
|
| 38 |
+
if not _HAS_XATLAS:
|
| 39 |
+
raise ImportError(
|
| 40 |
+
"moderngl is required for texture baking. Please install it with: pip install moderngl==5.10.0\n"
|
| 41 |
+
"Note: This requires Microsoft Visual C++ Build Tools to compile."
|
| 42 |
+
)
|
| 43 |
+
ctx = moderngl.create_context(standalone=True)
|
| 44 |
+
basic_prog = ctx.program(
|
| 45 |
+
vertex_shader="""
|
| 46 |
+
#version 330
|
| 47 |
+
in vec2 in_uv;
|
| 48 |
+
in vec3 in_pos;
|
| 49 |
+
out vec3 v_pos;
|
| 50 |
+
void main() {
|
| 51 |
+
v_pos = in_pos;
|
| 52 |
+
gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
|
| 53 |
+
}
|
| 54 |
+
""",
|
| 55 |
+
fragment_shader="""
|
| 56 |
+
#version 330
|
| 57 |
+
in vec3 v_pos;
|
| 58 |
+
out vec4 o_col;
|
| 59 |
+
void main() {
|
| 60 |
+
o_col = vec4(v_pos, 1.0);
|
| 61 |
+
}
|
| 62 |
+
""",
|
| 63 |
+
)
|
| 64 |
+
gs_prog = ctx.program(
|
| 65 |
+
vertex_shader="""
|
| 66 |
+
#version 330
|
| 67 |
+
in vec2 in_uv;
|
| 68 |
+
in vec3 in_pos;
|
| 69 |
+
out vec3 vg_pos;
|
| 70 |
+
void main() {
|
| 71 |
+
vg_pos = in_pos;
|
| 72 |
+
gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
|
| 73 |
+
}
|
| 74 |
+
""",
|
| 75 |
+
geometry_shader="""
|
| 76 |
+
#version 330
|
| 77 |
+
uniform float u_resolution;
|
| 78 |
+
uniform float u_dilation;
|
| 79 |
+
layout (triangles) in;
|
| 80 |
+
layout (triangle_strip, max_vertices = 12) out;
|
| 81 |
+
in vec3 vg_pos[];
|
| 82 |
+
out vec3 vf_pos;
|
| 83 |
+
void lineSegment(int aidx, int bidx) {
|
| 84 |
+
vec2 a = gl_in[aidx].gl_Position.xy;
|
| 85 |
+
vec2 b = gl_in[bidx].gl_Position.xy;
|
| 86 |
+
vec3 aCol = vg_pos[aidx];
|
| 87 |
+
vec3 bCol = vg_pos[bidx];
|
| 88 |
+
|
| 89 |
+
vec2 dir = normalize((b - a) * u_resolution);
|
| 90 |
+
vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
|
| 91 |
+
|
| 92 |
+
gl_Position = vec4(a + offset, 0.0, 1.0);
|
| 93 |
+
vf_pos = aCol;
|
| 94 |
+
EmitVertex();
|
| 95 |
+
gl_Position = vec4(a - offset, 0.0, 1.0);
|
| 96 |
+
vf_pos = aCol;
|
| 97 |
+
EmitVertex();
|
| 98 |
+
gl_Position = vec4(b + offset, 0.0, 1.0);
|
| 99 |
+
vf_pos = bCol;
|
| 100 |
+
EmitVertex();
|
| 101 |
+
gl_Position = vec4(b - offset, 0.0, 1.0);
|
| 102 |
+
vf_pos = bCol;
|
| 103 |
+
EmitVertex();
|
| 104 |
+
}
|
| 105 |
+
void main() {
|
| 106 |
+
lineSegment(0, 1);
|
| 107 |
+
lineSegment(1, 2);
|
| 108 |
+
lineSegment(2, 0);
|
| 109 |
+
EndPrimitive();
|
| 110 |
+
}
|
| 111 |
+
""",
|
| 112 |
+
fragment_shader="""
|
| 113 |
+
#version 330
|
| 114 |
+
in vec3 vf_pos;
|
| 115 |
+
out vec4 o_col;
|
| 116 |
+
void main() {
|
| 117 |
+
o_col = vec4(vf_pos, 1.0);
|
| 118 |
+
}
|
| 119 |
+
""",
|
| 120 |
+
)
|
| 121 |
+
uvs = atlas_uvs.flatten().astype("f4")
|
| 122 |
+
pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
|
| 123 |
+
indices = atlas_indices.flatten().astype("i4")
|
| 124 |
+
vbo_uvs = ctx.buffer(uvs)
|
| 125 |
+
vbo_pos = ctx.buffer(pos)
|
| 126 |
+
ibo = ctx.buffer(indices)
|
| 127 |
+
vao_content = [
|
| 128 |
+
vbo_uvs.bind("in_uv", layout="2f"),
|
| 129 |
+
vbo_pos.bind("in_pos", layout="3f"),
|
| 130 |
+
]
|
| 131 |
+
basic_vao = ctx.vertex_array(basic_prog, vao_content, ibo)
|
| 132 |
+
gs_vao = ctx.vertex_array(gs_prog, vao_content, ibo)
|
| 133 |
+
fbo = ctx.framebuffer(
|
| 134 |
+
color_attachments=[
|
| 135 |
+
ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
|
| 136 |
+
]
|
| 137 |
+
)
|
| 138 |
+
fbo.use()
|
| 139 |
+
fbo.clear(0.0, 0.0, 0.0, 0.0)
|
| 140 |
+
gs_prog["u_resolution"].value = texture_resolution
|
| 141 |
+
gs_prog["u_dilation"].value = texture_padding
|
| 142 |
+
gs_vao.render()
|
| 143 |
+
basic_vao.render()
|
| 144 |
+
|
| 145 |
+
fbo_bytes = fbo.color_attachments[0].read()
|
| 146 |
+
fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
|
| 147 |
+
texture_resolution, texture_resolution, 4
|
| 148 |
+
)
|
| 149 |
+
return fbo_np
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
|
| 153 |
+
positions = torch.tensor(positions_texture.reshape(-1, 4)[:, :-1])
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
queried_grid = model.renderer.query_triplane(
|
| 156 |
+
model.decoder,
|
| 157 |
+
positions,
|
| 158 |
+
scene_code,
|
| 159 |
+
)
|
| 160 |
+
rgb_f = queried_grid["color"].numpy().reshape(-1, 3)
|
| 161 |
+
rgba_f = np.insert(rgb_f, 3, positions_texture.reshape(-1, 4)[:, -1], axis=1)
|
| 162 |
+
rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
|
| 163 |
+
return rgba_f.reshape(texture_resolution, texture_resolution, 4)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def bake_texture(mesh, model, scene_code, texture_resolution):
|
| 167 |
+
if not _HAS_XATLAS:
|
| 168 |
+
raise ImportError(
|
| 169 |
+
"xatlas and moderngl are required for texture baking. Please install them with:\n"
|
| 170 |
+
" pip install xatlas==0.0.9 moderngl==5.10.0\n"
|
| 171 |
+
"Note: These require Microsoft Visual C++ Build Tools to compile."
|
| 172 |
+
)
|
| 173 |
+
texture_padding = round(max(2, texture_resolution / 256))
|
| 174 |
+
atlas = make_atlas(mesh, texture_resolution, texture_padding)
|
| 175 |
+
positions_texture = rasterize_position_atlas(
|
| 176 |
+
mesh,
|
| 177 |
+
atlas["vmapping"],
|
| 178 |
+
atlas["indices"],
|
| 179 |
+
atlas["uvs"],
|
| 180 |
+
texture_resolution,
|
| 181 |
+
texture_padding,
|
| 182 |
+
)
|
| 183 |
+
colors_texture = positions_to_colors(
|
| 184 |
+
model, scene_code, positions_texture, texture_resolution
|
| 185 |
+
)
|
| 186 |
+
return {
|
| 187 |
+
"vmapping": atlas["vmapping"],
|
| 188 |
+
"indices": atlas["indices"],
|
| 189 |
+
"uvs": atlas["uvs"],
|
| 190 |
+
"colors": colors_texture,
|
| 191 |
+
}
|
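bake_texture() returns the xatlas vertex mapping, face indices, UVs, and a float RGBA atlas. A hedged sketch of consuming that output the same way run.py does; it assumes `mesh`, `model`, and `scene_codes` already exist from a prior TSR run (e.g. meshes = model.extract_mesh(scene_codes, False, resolution=256)):

import numpy as np
import xatlas
from PIL import Image
from tsr.bake_texture import bake_texture

# Assumes `mesh`, `model`, and `scene_codes` come from a prior TSR run.
out = bake_texture(mesh, model, scene_codes[0], texture_resolution=2048)

# Write the UV-mapped mesh and the baked atlas, mirroring run.py.
xatlas.export(
    "mesh.obj",
    mesh.vertices[out["vmapping"]],
    out["indices"],
    out["uvs"],
    mesh.vertex_normals[out["vmapping"]],
)
Image.fromarray((out["colors"] * 255.0).astype(np.uint8)).transpose(
    Image.FLIP_TOP_BOTTOM
).save("texture.png")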
tsr/models/__pycache__/isosurface.cpython-313.pyc
ADDED
|
Binary file (4.26 kB). View file
|
|
|
tsr/models/__pycache__/nerf_renderer.cpython-313.pyc
ADDED
|
Binary file (8.46 kB). View file
|
|
|
tsr/models/__pycache__/network_utils.cpython-313.pyc
ADDED
|
Binary file (5.88 kB). View file
|
|
|
tsr/models/isosurface.py
ADDED
|
@@ -0,0 +1,64 @@
|
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

try:
    from torchmcubes import marching_cubes
    _HAS_TORCHMCUBES = True
except ImportError:
    _HAS_TORCHMCUBES = False
    marching_cubes = None


class IsosurfaceHelper(nn.Module):
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        raise NotImplementedError


class MarchingCubeHelper(IsosurfaceHelper):
    def __init__(self, resolution: int) -> None:
        super().__init__()
        if not _HAS_TORCHMCUBES:
            raise ImportError(
                "torchmcubes is required for mesh extraction. Please install it with:\n"
                "  pip install git+https://github.com/tatsy/torchmcubes.git\n"
                "Note: This requires Microsoft Visual C++ Build Tools to compile."
            )
        self.resolution = resolution
        self.mc_func: Callable = marching_cubes
        self._grid_vertices: Optional[torch.FloatTensor] = None

    @property
    def grid_vertices(self) -> torch.FloatTensor:
        if self._grid_vertices is None:
            # keep the vertices on CPU so that we can support very large resolution
            x, y, z = (
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
                torch.linspace(*self.points_range, self.resolution),
            )
            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
            verts = torch.cat(
                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
            ).reshape(-1, 3)
            self._grid_vertices = verts
        return self._grid_vertices

    def forward(
        self,
        level: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        level = -level.view(self.resolution, self.resolution, self.resolution)
        try:
            v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
        except AttributeError:
            print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
            v_pos, t_pos_idx = self.mc_func(level.detach().cpu(), 0.0)
        v_pos = v_pos[..., [2, 1, 0]]
        v_pos = v_pos / (self.resolution - 1.0)
        return v_pos.to(level.device), t_pos_idx.to(level.device)
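MarchingCubeHelper takes a flat tensor of resolution**3 level values and returns vertices normalized to [0, 1] plus triangle indices. A small self-contained check on an analytic sphere, assuming torchmcubes is installed:

import torch
from tsr.models.isosurface import MarchingCubeHelper

# Toy check of the marching-cubes wrapper on an analytic level set
# (a sphere of radius 0.3 centred in the unit cube).
helper = MarchingCubeHelper(resolution=64)
pts = helper.grid_vertices                   # (64**3, 3) points in [0, 1]
level = 0.3 - (pts - 0.5).norm(dim=-1)       # positive inside the sphere
v_pos, t_pos_idx = helper(level)
print(v_pos.shape, t_pos_idx.shape)          # vertices in [0, 1], triangle indices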
tsr/models/nerf_renderer.py
ADDED
|
@@ -0,0 +1,180 @@
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange, reduce
|
| 7 |
+
|
| 8 |
+
from ..utils import (
|
| 9 |
+
BaseModule,
|
| 10 |
+
chunk_batch,
|
| 11 |
+
get_activation,
|
| 12 |
+
rays_intersect_bbox,
|
| 13 |
+
scale_tensor,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TriplaneNeRFRenderer(BaseModule):
|
| 18 |
+
@dataclass
|
| 19 |
+
class Config(BaseModule.Config):
|
| 20 |
+
radius: float
|
| 21 |
+
|
| 22 |
+
feature_reduction: str = "concat"
|
| 23 |
+
density_activation: str = "trunc_exp"
|
| 24 |
+
density_bias: float = -1.0
|
| 25 |
+
color_activation: str = "sigmoid"
|
| 26 |
+
num_samples_per_ray: int = 128
|
| 27 |
+
randomized: bool = False
|
| 28 |
+
|
| 29 |
+
cfg: Config
|
| 30 |
+
|
| 31 |
+
def configure(self) -> None:
|
| 32 |
+
assert self.cfg.feature_reduction in ["concat", "mean"]
|
| 33 |
+
self.chunk_size = 0
|
| 34 |
+
|
| 35 |
+
def set_chunk_size(self, chunk_size: int):
|
| 36 |
+
assert (
|
| 37 |
+
chunk_size >= 0
|
| 38 |
+
), "chunk_size must be a non-negative integer (0 for no chunking)."
|
| 39 |
+
self.chunk_size = chunk_size
|
| 40 |
+
|
| 41 |
+
def query_triplane(
|
| 42 |
+
self,
|
| 43 |
+
decoder: torch.nn.Module,
|
| 44 |
+
positions: torch.Tensor,
|
| 45 |
+
triplane: torch.Tensor,
|
| 46 |
+
) -> Dict[str, torch.Tensor]:
|
| 47 |
+
input_shape = positions.shape[:-1]
|
| 48 |
+
positions = positions.view(-1, 3)
|
| 49 |
+
|
| 50 |
+
# positions in (-radius, radius)
|
| 51 |
+
# normalized to (-1, 1) for grid sample
|
| 52 |
+
positions = scale_tensor(
|
| 53 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
def _query_chunk(x):
|
| 57 |
+
indices2D: torch.Tensor = torch.stack(
|
| 58 |
+
(x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
|
| 59 |
+
dim=-3,
|
| 60 |
+
)
|
| 61 |
+
out: torch.Tensor = F.grid_sample(
|
| 62 |
+
rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),
|
| 63 |
+
rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
|
| 64 |
+
align_corners=False,
|
| 65 |
+
mode="bilinear",
|
| 66 |
+
)
|
| 67 |
+
if self.cfg.feature_reduction == "concat":
|
| 68 |
+
out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
|
| 69 |
+
elif self.cfg.feature_reduction == "mean":
|
| 70 |
+
out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
|
| 71 |
+
else:
|
| 72 |
+
raise NotImplementedError
|
| 73 |
+
|
| 74 |
+
net_out: Dict[str, torch.Tensor] = decoder(out)
|
| 75 |
+
return net_out
|
| 76 |
+
|
| 77 |
+
if self.chunk_size > 0:
|
| 78 |
+
net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
|
| 79 |
+
else:
|
| 80 |
+
net_out = _query_chunk(positions)
|
| 81 |
+
|
| 82 |
+
net_out["density_act"] = get_activation(self.cfg.density_activation)(
|
| 83 |
+
net_out["density"] + self.cfg.density_bias
|
| 84 |
+
)
|
| 85 |
+
net_out["color"] = get_activation(self.cfg.color_activation)(
|
| 86 |
+
net_out["features"]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
|
| 90 |
+
|
| 91 |
+
return net_out
|
| 92 |
+
|
| 93 |
+
def _forward(
|
| 94 |
+
self,
|
| 95 |
+
decoder: torch.nn.Module,
|
| 96 |
+
triplane: torch.Tensor,
|
| 97 |
+
rays_o: torch.Tensor,
|
| 98 |
+
rays_d: torch.Tensor,
|
| 99 |
+
**kwargs,
|
| 100 |
+
):
|
| 101 |
+
rays_shape = rays_o.shape[:-1]
|
| 102 |
+
rays_o = rays_o.view(-1, 3)
|
| 103 |
+
rays_d = rays_d.view(-1, 3)
|
| 104 |
+
n_rays = rays_o.shape[0]
|
| 105 |
+
|
| 106 |
+
t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
|
| 107 |
+
t_near, t_far = t_near[rays_valid], t_far[rays_valid]
|
| 108 |
+
|
| 109 |
+
t_vals = torch.linspace(
|
| 110 |
+
0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
|
| 111 |
+
)
|
| 112 |
+
t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
|
| 113 |
+
z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None] # (N_rays, N_samples)
|
| 114 |
+
|
| 115 |
+
xyz = (
|
| 116 |
+
rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
|
| 117 |
+
) # (N_rays, N_sample, 3)
|
| 118 |
+
|
| 119 |
+
mlp_out = self.query_triplane(
|
| 120 |
+
decoder=decoder,
|
| 121 |
+
positions=xyz,
|
| 122 |
+
triplane=triplane,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
eps = 1e-10
|
| 126 |
+
# deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
|
| 127 |
+
deltas = t_vals[1:] - t_vals[:-1] # (N_rays, N_samples)
|
| 128 |
+
alpha = 1 - torch.exp(
|
| 129 |
+
-deltas * mlp_out["density_act"][..., 0]
|
| 130 |
+
) # (N_rays, N_samples)
|
| 131 |
+
accum_prod = torch.cat(
|
| 132 |
+
[
|
| 133 |
+
torch.ones_like(alpha[:, :1]),
|
| 134 |
+
torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
|
| 135 |
+
],
|
| 136 |
+
dim=-1,
|
| 137 |
+
)
|
| 138 |
+
weights = alpha * accum_prod # (N_rays, N_samples)
|
| 139 |
+
comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2) # (N_rays, 3)
|
| 140 |
+
opacity_ = weights.sum(dim=-1) # (N_rays)
|
| 141 |
+
|
| 142 |
+
comp_rgb = torch.zeros(
|
| 143 |
+
n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
|
| 144 |
+
)
|
| 145 |
+
opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
|
| 146 |
+
comp_rgb[rays_valid] = comp_rgb_
|
| 147 |
+
opacity[rays_valid] = opacity_
|
| 148 |
+
|
| 149 |
+
comp_rgb += 1 - opacity[..., None]
|
| 150 |
+
comp_rgb = comp_rgb.view(*rays_shape, 3)
|
| 151 |
+
|
| 152 |
+
return comp_rgb
|
| 153 |
+
|
| 154 |
+
def forward(
|
| 155 |
+
self,
|
| 156 |
+
decoder: torch.nn.Module,
|
| 157 |
+
triplane: torch.Tensor,
|
| 158 |
+
rays_o: torch.Tensor,
|
| 159 |
+
rays_d: torch.Tensor,
|
| 160 |
+
) -> Dict[str, torch.Tensor]:
|
| 161 |
+
if triplane.ndim == 4:
|
| 162 |
+
comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
|
| 163 |
+
else:
|
| 164 |
+
comp_rgb = torch.stack(
|
| 165 |
+
[
|
| 166 |
+
self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
|
| 167 |
+
for i in range(triplane.shape[0])
|
| 168 |
+
],
|
| 169 |
+
dim=0,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
return comp_rgb
|
| 173 |
+
|
| 174 |
+
def train(self, mode=True):
|
| 175 |
+
self.randomized = mode and self.cfg.randomized
|
| 176 |
+
return super().train(mode=mode)
|
| 177 |
+
|
| 178 |
+
def eval(self):
|
| 179 |
+
self.randomized = False
|
| 180 |
+
return super().eval()
|
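set_chunk_size() bounds peak memory by evaluating the decoder on fixed-size slices of query points instead of all points at once. The project uses its own chunk_batch helper for this; the sketch below is only an illustration of the idea in plain torch, not the actual implementation:

import torch

# Illustration of the chunking idea behind renderer.set_chunk_size():
# evaluate a point-wise function in fixed-size slices so peak memory is
# bounded by the chunk, not by the total number of query points.
def eval_in_chunks(fn, points: torch.Tensor, chunk_size: int) -> torch.Tensor:
    if chunk_size <= 0:  # 0 means "no chunking" in TSR
        return fn(points)
    outs = [fn(points[i : i + chunk_size]) for i in range(0, points.shape[0], chunk_size)]
    return torch.cat(outs, dim=0)

points = torch.rand(1_000_000, 3)
densities = eval_in_chunks(lambda p: p.norm(dim=-1, keepdim=True), points, chunk_size=8192)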
tsr/models/network_utils.py
ADDED
|
@@ -0,0 +1,124 @@
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
from ..utils import BaseModule
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TriplaneUpsampleNetwork(BaseModule):
|
| 12 |
+
@dataclass
|
| 13 |
+
class Config(BaseModule.Config):
|
| 14 |
+
in_channels: int
|
| 15 |
+
out_channels: int
|
| 16 |
+
|
| 17 |
+
cfg: Config
|
| 18 |
+
|
| 19 |
+
def configure(self) -> None:
|
| 20 |
+
self.upsample = nn.ConvTranspose2d(
|
| 21 |
+
self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
|
| 25 |
+
triplanes_up = rearrange(
|
| 26 |
+
self.upsample(
|
| 27 |
+
rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
|
| 28 |
+
),
|
| 29 |
+
"(B Np) Co Hp Wp -> B Np Co Hp Wp",
|
| 30 |
+
Np=3,
|
| 31 |
+
)
|
| 32 |
+
return triplanes_up
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class NeRFMLP(BaseModule):
|
| 36 |
+
@dataclass
|
| 37 |
+
class Config(BaseModule.Config):
|
| 38 |
+
in_channels: int
|
| 39 |
+
n_neurons: int
|
| 40 |
+
n_hidden_layers: int
|
| 41 |
+
activation: str = "relu"
|
| 42 |
+
bias: bool = True
|
| 43 |
+
weight_init: Optional[str] = "kaiming_uniform"
|
| 44 |
+
bias_init: Optional[str] = None
|
| 45 |
+
|
| 46 |
+
cfg: Config
|
| 47 |
+
|
| 48 |
+
def configure(self) -> None:
|
| 49 |
+
layers = [
|
| 50 |
+
self.make_linear(
|
| 51 |
+
self.cfg.in_channels,
|
| 52 |
+
self.cfg.n_neurons,
|
| 53 |
+
bias=self.cfg.bias,
|
| 54 |
+
weight_init=self.cfg.weight_init,
|
| 55 |
+
bias_init=self.cfg.bias_init,
|
| 56 |
+
),
|
| 57 |
+
self.make_activation(self.cfg.activation),
|
| 58 |
+
]
|
| 59 |
+
for i in range(self.cfg.n_hidden_layers - 1):
|
| 60 |
+
layers += [
|
| 61 |
+
self.make_linear(
|
| 62 |
+
self.cfg.n_neurons,
|
| 63 |
+
self.cfg.n_neurons,
|
| 64 |
+
bias=self.cfg.bias,
|
| 65 |
+
weight_init=self.cfg.weight_init,
|
| 66 |
+
bias_init=self.cfg.bias_init,
|
| 67 |
+
),
|
| 68 |
+
self.make_activation(self.cfg.activation),
|
| 69 |
+
]
|
| 70 |
+
layers += [
|
| 71 |
+
self.make_linear(
|
| 72 |
+
self.cfg.n_neurons,
|
| 73 |
+
4, # density 1 + features 3
|
| 74 |
+
bias=self.cfg.bias,
|
| 75 |
+
weight_init=self.cfg.weight_init,
|
| 76 |
+
bias_init=self.cfg.bias_init,
|
| 77 |
+
)
|
| 78 |
+
]
|
| 79 |
+
self.layers = nn.Sequential(*layers)
|
| 80 |
+
|
| 81 |
+
def make_linear(
|
| 82 |
+
self,
|
| 83 |
+
dim_in,
|
| 84 |
+
dim_out,
|
| 85 |
+
bias=True,
|
| 86 |
+
weight_init=None,
|
| 87 |
+
bias_init=None,
|
| 88 |
+
):
|
| 89 |
+
layer = nn.Linear(dim_in, dim_out, bias=bias)
|
| 90 |
+
|
| 91 |
+
if weight_init is None:
|
| 92 |
+
pass
|
| 93 |
+
elif weight_init == "kaiming_uniform":
|
| 94 |
+
torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
|
| 95 |
+
else:
|
| 96 |
+
raise NotImplementedError
|
| 97 |
+
|
| 98 |
+
if bias:
|
| 99 |
+
if bias_init is None:
|
| 100 |
+
pass
|
| 101 |
+
elif bias_init == "zero":
|
| 102 |
+
torch.nn.init.zeros_(layer.bias)
|
| 103 |
+
else:
|
| 104 |
+
raise NotImplementedError
|
| 105 |
+
|
| 106 |
+
return layer
|
| 107 |
+
|
| 108 |
+
def make_activation(self, activation):
|
| 109 |
+
if activation == "relu":
|
| 110 |
+
return nn.ReLU(inplace=True)
|
| 111 |
+
elif activation == "silu":
|
| 112 |
+
return nn.SiLU(inplace=True)
|
| 113 |
+
else:
|
| 114 |
+
raise NotImplementedError
|
| 115 |
+
|
| 116 |
+
def forward(self, x):
|
| 117 |
+
inp_shape = x.shape[:-1]
|
| 118 |
+
x = x.reshape(-1, x.shape[-1])
|
| 119 |
+
|
| 120 |
+
features = self.layers(x)
|
| 121 |
+
features = features.reshape(*inp_shape, -1)
|
| 122 |
+
out = {"density": features[..., 0:1], "features": features[..., 1:4]}
|
| 123 |
+
|
| 124 |
+
return out
|
tsr/models/tokenizers/__pycache__/image.cpython-313.pyc
ADDED
|
Binary file (3.59 kB). View file
|
|
|
tsr/models/tokenizers/__pycache__/triplane.cpython-313.pyc
ADDED
|
Binary file (2.77 kB). View file
|
|
|
tsr/models/tokenizers/image.py
ADDED
|
@@ -0,0 +1,66 @@
|
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers.models.vit.modeling_vit import ViTModel

from ...utils import BaseModule


class DINOSingleImageTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(
                    repo_id=self.cfg.pretrained_model_name_or_path,
                    filename="config.json",
                )
            )
        )

        if self.cfg.enable_gradient_checkpointing:
            self.model.encoder.gradient_checkpointing = True

        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        images = (images - self.image_mean) / self.image_std
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
        )
        local_features, global_features = out.last_hidden_state, out.pooler_output
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
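Before the DINO ViT runs, the tokenizer normalizes inputs with the ImageNet statistics registered as buffers above. A stand-alone illustration of that step on a dummy batch:

import torch

# Stand-alone illustration of the normalization applied inside
# DINOSingleImageTokenizer.forward() before the DINO ViT is called.
image_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1)
image_std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1)

images = torch.rand(2, 1, 3, 512, 512)   # (B, N_views, C, H, W) in [0, 1]
normalized = (images - image_mean) / image_std
print(normalized.shape)                   # same shape, ImageNet-normalized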
tsr/models/tokenizers/triplane.py
ADDED
|
@@ -0,0 +1,45 @@
|
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from einops import rearrange, repeat

from ...utils import BaseModule


class Triplane1DTokenizer(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        plane_size: int
        num_channels: int

    cfg: Config

    def configure(self) -> None:
        self.embeddings = nn.Parameter(
            torch.randn(
                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
                dtype=torch.float32,
            )
            * 1
            / math.sqrt(self.cfg.num_channels)
        )

    def forward(self, batch_size: int) -> torch.Tensor:
        return rearrange(
            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
        )

    def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
        batch_size, Ct, Nt = tokens.shape
        assert Nt == self.cfg.plane_size**2 * 3
        assert Ct == self.cfg.num_channels
        return rearrange(
            tokens,
            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
            Np=3,
            Hp=self.cfg.plane_size,
            Wp=self.cfg.plane_size,
        )
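The tokenizer flattens three learned feature planes into a single 1D token sequence, and detokenize() inverts the layout. A shape-only sketch of that round trip with einops, using arbitrary plane_size/num_channels values and skipping the BaseModule config machinery:

import torch
from einops import rearrange, repeat

# Shape-only illustration of the triplane token layout used by
# Triplane1DTokenizer (plane_size=32, num_channels=16 chosen arbitrarily).
plane_size, num_channels, batch = 32, 16, 2
planes = torch.randn(3, num_channels, plane_size, plane_size)

tokens = rearrange(
    repeat(planes, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch),
    "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
)
print(tokens.shape)   # (2, 16, 3 * 32 * 32) -- one 1D token sequence per sample

restored = rearrange(
    tokens, "B Ct (Np Hp Wp) -> B Np Ct Hp Wp", Np=3, Hp=plane_size, Wp=plane_size
)
assert restored.shape == (batch, 3, num_channels, plane_size, plane_size)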
tsr/models/transformer/__pycache__/attention.cpython-313.pyc
ADDED
|
Binary file (23.5 kB). View file
|
|
|
tsr/models/transformer/__pycache__/basic_transformer_block.cpython-313.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
tsr/models/transformer/__pycache__/transformer_1d.cpython-313.pyc
ADDED
|
Binary file (7.52 kB). View file
|
|
|
tsr/models/transformer/attention.py
ADDED
|
@@ -0,0 +1,653 @@
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# --------
|
| 16 |
+
#
|
| 17 |
+
# Modified 2024 by the Tripo AI and Stability AI Team.
|
| 18 |
+
#
|
| 19 |
+
# Copyright (c) 2024 Tripo AI & Stability AI
|
| 20 |
+
#
|
| 21 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 22 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 23 |
+
# in the Software without restriction, including without limitation the rights
|
| 24 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 25 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 26 |
+
# furnished to do so, subject to the following conditions:
|
| 27 |
+
#
|
| 28 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 29 |
+
# copies or substantial portions of the Software.
|
| 30 |
+
#
|
| 31 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 32 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 33 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 34 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 35 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 36 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 37 |
+
# SOFTWARE.
|
| 38 |
+
|
| 39 |
+
from typing import Optional
|
| 40 |
+
|
| 41 |
+
import torch
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
from torch import nn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Attention(nn.Module):
|
| 47 |
+
r"""
|
| 48 |
+
A cross attention layer.
|
| 49 |
+
|
| 50 |
+
Parameters:
|
| 51 |
+
query_dim (`int`):
|
| 52 |
+
The number of channels in the query.
|
| 53 |
+
cross_attention_dim (`int`, *optional*):
|
| 54 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
| 55 |
+
heads (`int`, *optional*, defaults to 8):
|
| 56 |
+
The number of heads to use for multi-head attention.
|
| 57 |
+
dim_head (`int`, *optional*, defaults to 64):
|
| 58 |
+
The number of channels in each head.
|
| 59 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 60 |
+
The dropout probability to use.
|
| 61 |
+
bias (`bool`, *optional*, defaults to False):
|
| 62 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
| 63 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
| 64 |
+
Set to `True` to upcast the attention computation to `float32`.
|
| 65 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
| 66 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
| 67 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
| 68 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
| 69 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
| 70 |
+
The number of groups to use for the group norm in the cross attention.
|
| 71 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
| 72 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
| 73 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
| 74 |
+
The number of groups to use for the group norm in the attention.
|
| 75 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
| 76 |
+
The number of channels to use for the spatial normalization.
|
| 77 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
| 78 |
+
Set to `True` to use a bias in the output linear layer.
|
| 79 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
| 80 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
| 81 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
| 82 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
| 83 |
+
`added_kv_proj_dim` is not `None`.
|
| 84 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
| 85 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
| 86 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
| 87 |
+
A factor to rescale the output by dividing it with this value.
|
| 88 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
| 89 |
+
Set to `True` to add the residual connection to the output.
|
| 90 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
| 91 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
| 92 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
| 93 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
| 94 |
+
`AttnProcessor` otherwise.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
query_dim: int,
|
| 100 |
+
cross_attention_dim: Optional[int] = None,
|
| 101 |
+
heads: int = 8,
|
| 102 |
+
dim_head: int = 64,
|
| 103 |
+
dropout: float = 0.0,
|
| 104 |
+
bias: bool = False,
|
| 105 |
+
upcast_attention: bool = False,
|
| 106 |
+
upcast_softmax: bool = False,
|
| 107 |
+
cross_attention_norm: Optional[str] = None,
|
| 108 |
+
cross_attention_norm_num_groups: int = 32,
|
| 109 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 110 |
+
norm_num_groups: Optional[int] = None,
|
| 111 |
+
out_bias: bool = True,
|
| 112 |
+
scale_qk: bool = True,
|
| 113 |
+
only_cross_attention: bool = False,
|
| 114 |
+
eps: float = 1e-5,
|
| 115 |
+
rescale_output_factor: float = 1.0,
|
| 116 |
+
residual_connection: bool = False,
|
| 117 |
+
_from_deprecated_attn_block: bool = False,
|
| 118 |
+
processor: Optional["AttnProcessor"] = None,
|
| 119 |
+
out_dim: int = None,
|
| 120 |
+
):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 123 |
+
self.query_dim = query_dim
|
| 124 |
+
self.cross_attention_dim = (
|
| 125 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 126 |
+
)
|
| 127 |
+
self.upcast_attention = upcast_attention
|
| 128 |
+
self.upcast_softmax = upcast_softmax
|
| 129 |
+
self.rescale_output_factor = rescale_output_factor
|
| 130 |
+
self.residual_connection = residual_connection
|
| 131 |
+
self.dropout = dropout
|
| 132 |
+
self.fused_projections = False
|
| 133 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 134 |
+
|
| 135 |
+
# we make use of this private variable to know whether this class is loaded
|
| 136 |
+
# with an deprecated state dict so that we can convert it on the fly
|
| 137 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
| 138 |
+
|
| 139 |
+
self.scale_qk = scale_qk
|
| 140 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
| 141 |
+
|
| 142 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 143 |
+
# for slice_size > 0 the attention score computation
|
| 144 |
+
# is split across the batch axis to save memory
|
| 145 |
+
# You can set slice_size with `set_attention_slice`
|
| 146 |
+
self.sliceable_head_dim = heads
|
| 147 |
+
|
| 148 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 149 |
+
self.only_cross_attention = only_cross_attention
|
| 150 |
+
|
| 151 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
| 152 |
+
raise ValueError(
|
| 153 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if norm_num_groups is not None:
|
| 157 |
+
self.group_norm = nn.GroupNorm(
|
| 158 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
| 159 |
+
)
|
| 160 |
+
else:
|
| 161 |
+
self.group_norm = None
|
| 162 |
+
|
| 163 |
+
self.spatial_norm = None
|
| 164 |
+
|
| 165 |
+
if cross_attention_norm is None:
|
| 166 |
+
self.norm_cross = None
|
| 167 |
+
elif cross_attention_norm == "layer_norm":
|
| 168 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
| 169 |
+
elif cross_attention_norm == "group_norm":
|
| 170 |
+
if self.added_kv_proj_dim is not None:
|
| 171 |
+
# The given `encoder_hidden_states` are initially of shape
|
| 172 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
| 173 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
| 174 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
| 175 |
+
# the number of channels for the group norm.
|
| 176 |
+
norm_cross_num_channels = added_kv_proj_dim
|
| 177 |
+
else:
|
| 178 |
+
norm_cross_num_channels = self.cross_attention_dim
|
| 179 |
+
|
| 180 |
+
self.norm_cross = nn.GroupNorm(
|
| 181 |
+
num_channels=norm_cross_num_channels,
|
| 182 |
+
num_groups=cross_attention_norm_num_groups,
|
| 183 |
+
eps=1e-5,
|
| 184 |
+
affine=True,
|
| 185 |
+
)
|
| 186 |
+
else:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
linear_cls = nn.Linear
|
| 192 |
+
|
| 193 |
+
self.linear_cls = linear_cls
|
| 194 |
+
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
| 195 |
+
|
| 196 |
+
if not self.only_cross_attention:
|
| 197 |
+
# only relevant for the `AddedKVProcessor` classes
|
| 198 |
+
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
| 199 |
+
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
| 200 |
+
else:
|
| 201 |
+
self.to_k = None
|
| 202 |
+
self.to_v = None
|
| 203 |
+
|
| 204 |
+
if self.added_kv_proj_dim is not None:
|
| 205 |
+
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
| 206 |
+
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
| 207 |
+
|
| 208 |
+
self.to_out = nn.ModuleList([])
|
| 209 |
+
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
| 210 |
+
self.to_out.append(nn.Dropout(dropout))
|
| 211 |
+
|
| 212 |
+
# set attention processor
|
| 213 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
| 214 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
| 215 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
| 216 |
+
if processor is None:
|
| 217 |
+
processor = (
|
| 218 |
+
AttnProcessor2_0()
|
| 219 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
| 220 |
+
else AttnProcessor()
|
| 221 |
+
)
|
| 222 |
+
self.set_processor(processor)
|
| 223 |
+
|
| 224 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
| 225 |
+
self.processor = processor
|
| 226 |
+
|
| 227 |
+
def forward(
|
| 228 |
+
self,
|
| 229 |
+
hidden_states: torch.FloatTensor,
|
| 230 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 231 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 232 |
+
**cross_attention_kwargs,
|
| 233 |
+
) -> torch.Tensor:
|
| 234 |
+
r"""
|
| 235 |
+
The forward method of the `Attention` class.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
hidden_states (`torch.Tensor`):
|
| 239 |
+
The hidden states of the query.
|
| 240 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
| 241 |
+
The hidden states of the encoder.
|
| 242 |
+
attention_mask (`torch.Tensor`, *optional*):
|
| 243 |
+
The attention mask to use. If `None`, no mask is applied.
|
| 244 |
+
**cross_attention_kwargs:
|
| 245 |
+
Additional keyword arguments to pass along to the cross attention.
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
`torch.Tensor`: The output of the attention layer.
|
| 249 |
+
"""
|
| 250 |
+
# The `Attention` class can call different attention processors / attention functions
|
| 251 |
+
# here we simply pass along all tensors to the selected processor class
|
| 252 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
| 253 |
+
return self.processor(
|
| 254 |
+
self,
|
| 255 |
+
hidden_states,
|
| 256 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 257 |
+
attention_mask=attention_mask,
|
| 258 |
+
**cross_attention_kwargs,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 262 |
+
r"""
|
| 263 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
|
| 264 |
+
is the number of heads initialized while constructing the `Attention` class.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
`torch.Tensor`: The reshaped tensor.
|
| 271 |
+
"""
|
| 272 |
+
head_size = self.heads
|
| 273 |
+
batch_size, seq_len, dim = tensor.shape
|
| 274 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
| 275 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(
|
| 276 |
+
batch_size // head_size, seq_len, dim * head_size
|
| 277 |
+
)
|
| 278 |
+
return tensor
|
| 279 |
+
|
| 280 |
+
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
| 281 |
+
r"""
|
| 282 |
+
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
|
| 283 |
+
the number of heads initialized while constructing the `Attention` class.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
tensor (`torch.Tensor`): The tensor to reshape.
|
| 287 |
+
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
|
| 288 |
+
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
`torch.Tensor`: The reshaped tensor.
|
| 292 |
+
"""
|
| 293 |
+
head_size = self.heads
|
| 294 |
+
batch_size, seq_len, dim = tensor.shape
|
| 295 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
| 296 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
| 297 |
+
|
| 298 |
+
if out_dim == 3:
|
| 299 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
| 300 |
+
|
| 301 |
+
return tensor
|
| 302 |
+
|
| 303 |
+
def get_attention_scores(
|
| 304 |
+
self,
|
| 305 |
+
query: torch.Tensor,
|
| 306 |
+
key: torch.Tensor,
|
| 307 |
+
attention_mask: torch.Tensor = None,
|
| 308 |
+
) -> torch.Tensor:
|
| 309 |
+
r"""
|
| 310 |
+
Compute the attention scores.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
query (`torch.Tensor`): The query tensor.
|
| 314 |
+
key (`torch.Tensor`): The key tensor.
|
| 315 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
`torch.Tensor`: The attention probabilities/scores.
|
| 319 |
+
"""
|
| 320 |
+
dtype = query.dtype
|
| 321 |
+
if self.upcast_attention:
|
| 322 |
+
query = query.float()
|
| 323 |
+
key = key.float()
|
| 324 |
+
|
| 325 |
+
if attention_mask is None:
|
| 326 |
+
baddbmm_input = torch.empty(
|
| 327 |
+
query.shape[0],
|
| 328 |
+
query.shape[1],
|
| 329 |
+
key.shape[1],
|
| 330 |
+
dtype=query.dtype,
|
| 331 |
+
device=query.device,
|
| 332 |
+
)
|
| 333 |
+
beta = 0
|
| 334 |
+
else:
|
| 335 |
+
baddbmm_input = attention_mask
|
| 336 |
+
beta = 1
|
| 337 |
+
|
| 338 |
+
attention_scores = torch.baddbmm(
|
| 339 |
+
baddbmm_input,
|
| 340 |
+
query,
|
| 341 |
+
key.transpose(-1, -2),
|
| 342 |
+
beta=beta,
|
| 343 |
+
alpha=self.scale,
|
| 344 |
+
)
|
| 345 |
+
del baddbmm_input
|
| 346 |
+
|
| 347 |
+
if self.upcast_softmax:
|
| 348 |
+
attention_scores = attention_scores.float()
|
| 349 |
+
|
| 350 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 351 |
+
del attention_scores
|
| 352 |
+
|
| 353 |
+
attention_probs = attention_probs.to(dtype)
|
| 354 |
+
|
| 355 |
+
return attention_probs
|
| 356 |
+
|
| 357 |
+
def prepare_attention_mask(
|
| 358 |
+
self,
|
| 359 |
+
attention_mask: torch.Tensor,
|
| 360 |
+
target_length: int,
|
| 361 |
+
batch_size: int,
|
| 362 |
+
out_dim: int = 3,
|
| 363 |
+
) -> torch.Tensor:
|
| 364 |
+
r"""
|
| 365 |
+
Prepare the attention mask for the attention computation.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
attention_mask (`torch.Tensor`):
|
| 369 |
+
The attention mask to prepare.
|
| 370 |
+
target_length (`int`):
|
| 371 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
| 372 |
+
batch_size (`int`):
|
| 373 |
+
The batch size, which is used to repeat the attention mask.
|
| 374 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
| 375 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
`torch.Tensor`: The prepared attention mask.
|
| 379 |
+
"""
|
| 380 |
+
head_size = self.heads
|
| 381 |
+
if attention_mask is None:
|
| 382 |
+
return attention_mask
|
| 383 |
+
|
| 384 |
+
current_length: int = attention_mask.shape[-1]
|
| 385 |
+
if current_length != target_length:
|
| 386 |
+
if attention_mask.device.type == "mps":
|
| 387 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
| 388 |
+
# Instead, we can manually construct the padding tensor.
|
| 389 |
+
padding_shape = (
|
| 390 |
+
attention_mask.shape[0],
|
| 391 |
+
attention_mask.shape[1],
|
| 392 |
+
target_length,
|
| 393 |
+
)
|
| 394 |
+
padding = torch.zeros(
|
| 395 |
+
padding_shape,
|
| 396 |
+
dtype=attention_mask.dtype,
|
| 397 |
+
device=attention_mask.device,
|
| 398 |
+
)
|
| 399 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
| 400 |
+
else:
|
| 401 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
| 402 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
| 403 |
+
# remaining_length: int = target_length - current_length
|
| 404 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
| 405 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
| 406 |
+
|
| 407 |
+
if out_dim == 3:
|
| 408 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
| 409 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
| 410 |
+
elif out_dim == 4:
|
| 411 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 412 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
| 413 |
+
|
| 414 |
+
return attention_mask
|
| 415 |
+
|
| 416 |
+
def norm_encoder_hidden_states(
|
| 417 |
+
self, encoder_hidden_states: torch.Tensor
|
| 418 |
+
) -> torch.Tensor:
|
| 419 |
+
r"""
|
| 420 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
| 421 |
+
`Attention` class.
|
| 422 |
+
|
| 423 |
+
Args:
|
| 424 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
| 425 |
+
|
| 426 |
+
Returns:
|
| 427 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
| 428 |
+
"""
|
| 429 |
+
assert (
|
| 430 |
+
self.norm_cross is not None
|
| 431 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
| 432 |
+
|
| 433 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
| 434 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
| 435 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
| 436 |
+
# Group norm norms along the channels dimension and expects
|
| 437 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
| 438 |
+
# to norm along the hidden dimension, so we need to move
|
| 439 |
+
# (batch_size, sequence_length, hidden_size) ->
|
| 440 |
+
# (batch_size, hidden_size, sequence_length)
|
| 441 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
| 442 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
| 443 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
| 444 |
+
else:
|
| 445 |
+
assert False
|
| 446 |
+
|
| 447 |
+
return encoder_hidden_states
|
| 448 |
+
|
| 449 |
+
@torch.no_grad()
|
| 450 |
+
def fuse_projections(self, fuse=True):
|
| 451 |
+
is_cross_attention = self.cross_attention_dim != self.query_dim
|
| 452 |
+
device = self.to_q.weight.data.device
|
| 453 |
+
dtype = self.to_q.weight.data.dtype
|
| 454 |
+
|
| 455 |
+
if not is_cross_attention:
|
| 456 |
+
# fetch weight matrices.
|
| 457 |
+
concatenated_weights = torch.cat(
|
| 458 |
+
[self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
|
| 459 |
+
)
|
| 460 |
+
in_features = concatenated_weights.shape[1]
|
| 461 |
+
out_features = concatenated_weights.shape[0]
|
| 462 |
+
|
| 463 |
+
# create a new single projection layer and copy over the weights.
|
| 464 |
+
self.to_qkv = self.linear_cls(
|
| 465 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
| 466 |
+
)
|
| 467 |
+
self.to_qkv.weight.copy_(concatenated_weights)
|
| 468 |
+
|
| 469 |
+
else:
|
| 470 |
+
concatenated_weights = torch.cat(
|
| 471 |
+
[self.to_k.weight.data, self.to_v.weight.data]
|
| 472 |
+
)
|
| 473 |
+
in_features = concatenated_weights.shape[1]
|
| 474 |
+
out_features = concatenated_weights.shape[0]
|
| 475 |
+
|
| 476 |
+
self.to_kv = self.linear_cls(
|
| 477 |
+
in_features, out_features, bias=False, device=device, dtype=dtype
|
| 478 |
+
)
|
| 479 |
+
self.to_kv.weight.copy_(concatenated_weights)
|
| 480 |
+
|
| 481 |
+
self.fused_projections = fuse
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class AttnProcessor:
|
| 485 |
+
r"""
|
| 486 |
+
Default processor for performing attention-related computations.
|
| 487 |
+
"""
|
| 488 |
+
|
| 489 |
+
def __call__(
|
| 490 |
+
self,
|
| 491 |
+
attn: Attention,
|
| 492 |
+
hidden_states: torch.FloatTensor,
|
| 493 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 494 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 495 |
+
) -> torch.Tensor:
|
| 496 |
+
residual = hidden_states
|
| 497 |
+
|
| 498 |
+
input_ndim = hidden_states.ndim
|
| 499 |
+
|
| 500 |
+
if input_ndim == 4:
|
| 501 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 502 |
+
hidden_states = hidden_states.view(
|
| 503 |
+
batch_size, channel, height * width
|
| 504 |
+
).transpose(1, 2)
|
| 505 |
+
|
| 506 |
+
batch_size, sequence_length, _ = (
|
| 507 |
+
hidden_states.shape
|
| 508 |
+
if encoder_hidden_states is None
|
| 509 |
+
else encoder_hidden_states.shape
|
| 510 |
+
)
|
| 511 |
+
attention_mask = attn.prepare_attention_mask(
|
| 512 |
+
attention_mask, sequence_length, batch_size
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if attn.group_norm is not None:
|
| 516 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
| 517 |
+
1, 2
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
query = attn.to_q(hidden_states)
|
| 521 |
+
|
| 522 |
+
if encoder_hidden_states is None:
|
| 523 |
+
encoder_hidden_states = hidden_states
|
| 524 |
+
elif attn.norm_cross:
|
| 525 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
| 526 |
+
encoder_hidden_states
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
key = attn.to_k(encoder_hidden_states)
|
| 530 |
+
value = attn.to_v(encoder_hidden_states)
|
| 531 |
+
|
| 532 |
+
query = attn.head_to_batch_dim(query)
|
| 533 |
+
key = attn.head_to_batch_dim(key)
|
| 534 |
+
value = attn.head_to_batch_dim(value)
|
| 535 |
+
|
| 536 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
| 537 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 538 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 539 |
+
|
| 540 |
+
# linear proj
|
| 541 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 542 |
+
# dropout
|
| 543 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 544 |
+
|
| 545 |
+
if input_ndim == 4:
|
| 546 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 547 |
+
batch_size, channel, height, width
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if attn.residual_connection:
|
| 551 |
+
hidden_states = hidden_states + residual
|
| 552 |
+
|
| 553 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 554 |
+
|
| 555 |
+
return hidden_states
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class AttnProcessor2_0:
|
| 559 |
+
r"""
|
| 560 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
| 561 |
+
"""
|
| 562 |
+
|
| 563 |
+
def __init__(self):
|
| 564 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 565 |
+
raise ImportError(
|
| 566 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
def __call__(
|
| 570 |
+
self,
|
| 571 |
+
attn: Attention,
|
| 572 |
+
hidden_states: torch.FloatTensor,
|
| 573 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 574 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 575 |
+
) -> torch.FloatTensor:
|
| 576 |
+
residual = hidden_states
|
| 577 |
+
|
| 578 |
+
input_ndim = hidden_states.ndim
|
| 579 |
+
|
| 580 |
+
if input_ndim == 4:
|
| 581 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 582 |
+
hidden_states = hidden_states.view(
|
| 583 |
+
batch_size, channel, height * width
|
| 584 |
+
).transpose(1, 2)
|
| 585 |
+
|
| 586 |
+
batch_size, sequence_length, _ = (
|
| 587 |
+
hidden_states.shape
|
| 588 |
+
if encoder_hidden_states is None
|
| 589 |
+
else encoder_hidden_states.shape
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
if attention_mask is not None:
|
| 593 |
+
attention_mask = attn.prepare_attention_mask(
|
| 594 |
+
attention_mask, sequence_length, batch_size
|
| 595 |
+
)
|
| 596 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 597 |
+
# (batch, heads, source_length, target_length)
|
| 598 |
+
attention_mask = attention_mask.view(
|
| 599 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
if attn.group_norm is not None:
|
| 603 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
| 604 |
+
1, 2
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
query = attn.to_q(hidden_states)
|
| 608 |
+
|
| 609 |
+
if encoder_hidden_states is None:
|
| 610 |
+
encoder_hidden_states = hidden_states
|
| 611 |
+
elif attn.norm_cross:
|
| 612 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
| 613 |
+
encoder_hidden_states
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
key = attn.to_k(encoder_hidden_states)
|
| 617 |
+
value = attn.to_v(encoder_hidden_states)
|
| 618 |
+
|
| 619 |
+
inner_dim = key.shape[-1]
|
| 620 |
+
head_dim = inner_dim // attn.heads
|
| 621 |
+
|
| 622 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 623 |
+
|
| 624 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 625 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 626 |
+
|
| 627 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
| 628 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 629 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 630 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
| 634 |
+
batch_size, -1, attn.heads * head_dim
|
| 635 |
+
)
|
| 636 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 637 |
+
|
| 638 |
+
# linear proj
|
| 639 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 640 |
+
# dropout
|
| 641 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 642 |
+
|
| 643 |
+
if input_ndim == 4:
|
| 644 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 645 |
+
batch_size, channel, height, width
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
if attn.residual_connection:
|
| 649 |
+
hidden_states = hidden_states + residual
|
| 650 |
+
|
| 651 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 652 |
+
|
| 653 |
+
return hidden_states
|
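A short smoke test of the Attention layer defined above, exercising both the self-attention and cross-attention paths on dummy tensors; the sizes are illustrative and the import assumes the `tsr` package is importable from the repo root. On PyTorch 2.x the layer selects `AttnProcessor2_0` (scaled_dot_product_attention) automatically, otherwise it falls back to `AttnProcessor`.

```python
import torch

from tsr.models.transformer.attention import Attention

# self-attention: queries, keys and values all come from the same 512-dim tokens
self_attn = Attention(query_dim=512, heads=8, dim_head=64)
x = torch.randn(2, 1024, 512)                   # (batch, tokens, channels)
print(self_attn(x).shape)                       # torch.Size([2, 1024, 512])

# cross-attention: keys/values come from a 768-dim conditioning sequence
cross_attn = Attention(query_dim=512, cross_attention_dim=768, heads=8, dim_head=64)
context = torch.randn(2, 257, 768)
print(cross_attn(x, encoder_hidden_states=context).shape)   # torch.Size([2, 1024, 512])
```
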
tsr/models/transformer/basic_transformer_block.py
ADDED
|
@@ -0,0 +1,334 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# --------
#
# Modified 2024 by the Tripo AI and Stability AI Team.
#
# Copyright (c) 2024 Tripo AI & Stability AI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from .attention import Attention


class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        assert norm_type == "layer_norm"

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
        )

        hidden_states = attn_output + hidden_states

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = self.norm2(hidden_states)

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice)
                    for hid_slice in norm_hidden_states.chunk(
                        num_chunks, dim=self._chunk_dim
                    )
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
            dtype=gate.dtype
        )

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        linear_cls = nn.Linear

        self.proj = linear_cls(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        args = ()
        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    r"""
    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)

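A minimal sketch of how BasicTransformerBlock above composes pre-norm self-attention, cross-attention and the GEGLU feed-forward; the dimensions are illustrative and the import again assumes the repo root is on the Python path.

```python
import torch

from tsr.models.transformer.basic_transformer_block import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=512,
    num_attention_heads=8,
    attention_head_dim=64,
    cross_attention_dim=768,    # enables the second (cross) attention layer
)

hidden = torch.randn(2, 1024, 512)       # e.g. triplane tokens
context = torch.randn(2, 257, 768)       # e.g. image-conditioning tokens
out = block(hidden, encoder_hidden_states=context)
print(out.shape)                         # torch.Size([2, 1024, 512])
```
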
tsr/models/transformer/transformer_1d.py
ADDED
|
@@ -0,0 +1,219 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# --------
#
# Modified 2024 by the Tripo AI and Stability AI Team.
#
# Copyright (c) 2024 Tripo AI & Stability AI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ...utils import BaseModule
from .basic_transformer_block import BasicTransformerBlock


class Transformer1D(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 88
        in_channels: Optional[int] = None
        out_channels: Optional[int] = None
        num_layers: int = 1
        dropout: float = 0.0
        norm_num_groups: int = 32
        cross_attention_dim: Optional[int] = None
        attention_bias: bool = False
        activation_fn: str = "geglu"
        only_cross_attention: bool = False
        double_self_attention: bool = False
        upcast_attention: bool = False
        norm_type: str = "layer_norm"
        norm_elementwise_affine: bool = True
        gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        self.num_attention_heads = self.cfg.num_attention_heads
        self.attention_head_dim = self.cfg.attention_head_dim
        inner_dim = self.num_attention_heads * self.attention_head_dim

        linear_cls = nn.Linear

        # 2. Define input layers
        self.in_channels = self.cfg.in_channels

        self.norm = torch.nn.GroupNorm(
            num_groups=self.cfg.norm_num_groups,
            num_channels=self.cfg.in_channels,
            eps=1e-6,
            affine=True,
        )
        self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.cfg.dropout,
                    cross_attention_dim=self.cfg.cross_attention_dim,
                    activation_fn=self.cfg.activation_fn,
                    attention_bias=self.cfg.attention_bias,
                    only_cross_attention=self.cfg.only_cross_attention,
                    double_self_attention=self.cfg.double_self_attention,
                    upcast_attention=self.cfg.upcast_attention,
                    norm_type=self.cfg.norm_type,
                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
                )
                for d in range(self.cfg.num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = (
            self.cfg.in_channels
            if self.cfg.out_channels is None
            else self.cfg.out_channels
        )

        self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)

        self.gradient_checkpointing = self.cfg.gradient_checkpointing

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        The [`Transformer1DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.

        Returns:
            torch.FloatTensor
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (
                1 - encoder_attention_mask.to(hidden_states.dtype)
            ) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch, _, seq_len = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 1).reshape(
            batch, seq_len, inner_dim
        )
        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_reentrant=False,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = (
            hidden_states.reshape(batch, seq_len, inner_dim)
            .permute(0, 2, 1)
            .contiguous()
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
output = hidden_states + residual
|
| 218 |
+
|
| 219 |
+
return output
|
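For reference, a minimal standalone sketch of the mask-to-bias conversion performed at the top of forward(); the shapes and values below are illustrative only and are not taken from the model configuration.

# Sketch: converting a 0/1 key-padding mask into an additive attention bias.
# All shapes here are arbitrary illustration values (assumptions, not config).
import torch

batch, key_tokens = 2, 5
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0]], dtype=torch.float32)  # 1 = keep, 0 = discard

bias = (1 - mask) * -10000.0   # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)       # (batch, 1, key_tokens), broadcasts over query tokens

scores = torch.randn(batch, 4, key_tokens)  # (batch, query_tokens, key_tokens)
attn = (scores + bias).softmax(dim=-1)
print(attn[0, :, 3:])  # near-zero probabilities for the masked keys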
tsr/system.py
ADDED
@@ -0,0 +1,205 @@
import math
import os
from dataclasses import dataclass, field
from typing import List, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image

from .models.isosurface import MarchingCubeHelper
from .utils import (
    BaseModule,
    ImagePreprocessor,
    find_class,
    get_spherical_cameras,
    scale_tensor,
)


class TSR(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        cond_image_size: int

        image_tokenizer_cls: str
        image_tokenizer: dict

        tokenizer_cls: str
        tokenizer: dict

        backbone_cls: str
        backbone: dict

        post_processor_cls: str
        post_processor: dict

        decoder_cls: str
        decoder: dict

        renderer_cls: str
        renderer: dict

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model

    def configure(self):
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        self.isosurface_helper = None

    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]

        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )

        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        tokens: torch.Tensor = self.tokenizer(batch_size)

        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )

        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes

    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            elif return_type == "np":
                return image.detach().cpu().numpy()
            elif return_type == "pil":
                return Image.fromarray(
                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
                )
            else:
                raise NotImplementedError

        images = []
        for scene_code in scene_codes:
            images_ = []
            for i in range(n_views):
                with torch.no_grad():
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                images_.append(process_output(image))
            images.append(images_)

        return images

    def set_marching_cubes_resolution(self, resolution: int):
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)

    def extract_mesh(self, scene_codes, has_vertex_color, resolution: int = 256, threshold: float = 25.0):
        self.set_marching_cubes_resolution(resolution)
        meshes = []
        for scene_code in scene_codes:
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    scale_tensor(
                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
                        self.isosurface_helper.points_range,
                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
                    ),
                    scene_code,
                )["density_act"]
            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
            v_pos = scale_tensor(
                v_pos,
                self.isosurface_helper.points_range,
                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
            )
            color = None
            if has_vertex_color:
                with torch.no_grad():
                    color = self.renderer.query_triplane(
                        self.decoder,
                        v_pos,
                        scene_code,
                    )["color"]
            mesh = trimesh.Trimesh(
                vertices=v_pos.cpu().numpy(),
                faces=t_pos_idx.cpu().numpy(),
                vertex_colors=color.cpu().numpy() if has_vertex_color else None,
            )
            meshes.append(mesh)
        return meshes
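For reference, a hypothetical end-to-end use of the TSR class above. The repository id and the config/weight file names are assumptions (they follow the upstream TripoSR release) and are not defined in this file; the input path is illustrative.

# Sketch only: load the model, encode an image, render views, extract a mesh.
import torch
from PIL import Image

from tsr.system import TSR

device = "cuda" if torch.cuda.is_available() else "cpu"

model = TSR.from_pretrained(
    "stabilityai/TripoSR",        # assumed repo id
    config_name="config.yaml",    # assumed config file name
    weight_name="model.ckpt",     # assumed weight file name
)
model.to(device)

image = Image.open("input.png").convert("RGB")  # illustrative input path
with torch.no_grad():
    scene_codes = model(image, device=device)

views = model.render(scene_codes, n_views=8, return_type="pil")
meshes = model.extract_mesh(scene_codes, has_vertex_color=True, resolution=256)
meshes[0].export("output.obj")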
tsr/utils.py
ADDED
@@ -0,0 +1,510 @@
import importlib
import math
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio
import numpy as np
import PIL.Image
import rembg
import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh
from omegaconf import DictConfig, OmegaConf
from PIL import Image


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
    return scfg


def find_class(cls_string):
    module_string = ".".join(cls_string.split(".")[:-1])
    cls_name = cls_string.split(".")[-1]
    module = importlib.import_module(module_string, package=None)
    cls = getattr(module, cls_name)
    return cls


def get_intrinsic_from_fov(fov, H, W, bs=-1):
    focal_length = 0.5 * H / np.tan(0.5 * fov)
    intrinsic = np.identity(3, dtype=np.float32)
    intrinsic[0, 0] = focal_length
    intrinsic[1, 1] = focal_length
    intrinsic[0, 2] = W / 2.0
    intrinsic[1, 2] = H / 2.0

    if bs > 0:
        intrinsic = intrinsic[None].repeat(bs, axis=0)

    return torch.from_numpy(intrinsic)


class BaseModule(nn.Module):
    @dataclass
    class Config:
        pass

    cfg: Config  # add this to every subclass of BaseModule to enable static type checking

    def __init__(
        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
    ) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self.configure(*args, **kwargs)

    def configure(self, *args, **kwargs) -> None:
        raise NotImplementedError


class ImagePreprocessor:
    def convert_and_resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        size: int,
    ):
        if isinstance(image, PIL.Image.Image):
            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
        elif isinstance(image, np.ndarray):
            if image.dtype == np.uint8:
                image = torch.from_numpy(image.astype(np.float32) / 255.0)
            else:
                image = torch.from_numpy(image)
        elif isinstance(image, torch.Tensor):
            pass

        batched = image.ndim == 4

        if not batched:
            image = image[None, ...]
        image = F.interpolate(
            image.permute(0, 3, 1, 2),
            (size, size),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).permute(0, 2, 3, 1)
        if not batched:
            image = image[0]
        return image

    def __call__(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        size: int,
    ) -> Any:
        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
            image = self.convert_and_resize(image, size)
        else:
            if not isinstance(image, list):
                image = [image]
            image = [self.convert_and_resize(im, size) for im in image]
            image = torch.stack(image, dim=0)
        return image


def rays_intersect_bbox(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    radius: float,
    near: float = 0.0,
    valid_thresh: float = 0.01,
):
    input_shape = rays_o.shape[:-1]
    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
    rays_d_valid = torch.where(
        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
    )
    if type(radius) in [int, float]:
        radius = torch.FloatTensor(
            [[-radius, radius], [-radius, radius], [-radius, radius]]
        ).to(rays_o.device)
    radius = (
        1.0 - 1.0e-3
    ) * radius  # tighten the radius to make sure the intersection point lies in the bounding box
    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
    t_far = torch.maximum(interx0, interx1).amin(dim=-1)

    # check whether a ray intersects the bbox or not
    rays_valid = t_far - t_near > valid_thresh

    t_near[torch.where(~rays_valid)] = 0.0
    t_far[torch.where(~rays_valid)] = 0.0

    t_near = t_near.view(*input_shape, 1)
    t_far = t_far.view(*input_shape, 1)
    rays_valid = rays_valid.view(*input_shape)

    return t_near, t_far, rays_valid


def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    if chunk_size <= 0:
        return func(*args, **kwargs)
    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."
    out = defaultdict(list)
    out_type = None
    # max(1, B) to support B == 0
    for i in range(0, max(1, B), chunk_size):
        out_chunk = func(
            *[
                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
            chunk_length = len(out_chunk)
            out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
        elif isinstance(out_chunk, dict):
            pass
        else:
            print(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
            exit(1)
        for k, v in out_chunk.items():
            v = v if torch.is_grad_enabled() else v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all([vv is None for vv in v]):
            # allow None in return value
            out_merged[k] = None
        elif all([isinstance(vv, torch.Tensor) for vv in v]):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    if out_type is torch.Tensor:
        return out_merged[0]
    elif out_type in [tuple, list]:
        return out_type([out_merged[i] for i in range(chunk_length)])
    elif out_type is dict:
        return out_merged


ValidScale = Union[Tuple[float, float], torch.FloatTensor]


def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
    if inp_scale is None:
        inp_scale = (0, 1)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    if isinstance(tgt_scale, torch.FloatTensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]
    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
    return dat


def get_activation(name) -> Callable:
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "softplus":
        return lambda x: F.softplus(x)
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")


def get_ray_directions(
    H: int,
    W: int,
    focal: Union[float, Tuple[float, float]],
    principal: Optional[Tuple[float, float]] = None,
    use_pixel_centers: bool = True,
    normalize: bool = True,
) -> torch.FloatTensor:
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point, and whether to use pixel centers
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    pixel_center = 0.5 if use_pixel_centers else 0

    if isinstance(focal, float):
        fx, fy = focal, focal
        cx, cy = W / 2, H / 2
    else:
        fx, fy = focal
        assert principal is not None
        cx, cy = principal

    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32) + pixel_center,
        torch.arange(H, dtype=torch.float32) + pixel_center,
        indexing="xy",
    )

    directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)

    if normalize:
        directions = F.normalize(directions, dim=-1)

    return directions


def get_rays(
    directions,
    c2w,
    keepdim=False,
    normalize=False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    # Rotate ray directions from camera coordinate to the world coordinate
    assert directions.shape[-1] == 3

    if directions.ndim == 2:  # (N_rays, 3)
        if c2w.ndim == 2:  # (4, 4)
            c2w = c2w[None, :, :]
        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 3:  # (H, W, 3)
        assert c2w.ndim in [2, 3]
        if c2w.ndim == 2:  # (4, 4)
            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
                -1
            )  # (H, W, 3)
            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
        elif c2w.ndim == 3:  # (B, 4, 4)
            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
                -1
            )  # (B, H, W, 3)
            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
    elif directions.ndim == 4:  # (B, H, W, 3)
        assert c2w.ndim == 3  # (B, 4, 4)
        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
            -1
        )  # (B, H, W, 3)
        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)

    if normalize:
        rays_d = F.normalize(rays_d, dim=-1)
    if not keepdim:
        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)

    return rays_o, rays_d


def get_spherical_cameras(
    n_views: int,
    elevation_deg: float,
    camera_distance: float,
    fovy_deg: float,
    height: int,
    width: int,
):
    # Use 0 to 360*(n_views-1)/n_views to avoid duplicate first/last position
    # This ensures full 360-degree coverage without overlap
    azimuth_deg = torch.linspace(0, 360.0 * (n_views - 1) / n_views, n_views)
    elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
    camera_distances = torch.full_like(elevation_deg, camera_distance)

    elevation = elevation_deg * math.pi / 180
    azimuth = azimuth_deg * math.pi / 180

    # convert spherical coordinates to cartesian coordinates
    # right hand coordinate system, x back, y right, z up
    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
    camera_positions = torch.stack(
        [
            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
            camera_distances * torch.sin(elevation),
        ],
        dim=-1,
    )

    # default scene center at origin
    center = torch.zeros_like(camera_positions)
    # default camera up direction as +z
    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)

    fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180

    lookat = F.normalize(center - camera_positions, dim=-1)
    right = F.normalize(torch.cross(lookat, up), dim=-1)
    up = F.normalize(torch.cross(right, lookat), dim=-1)
    c2w3x4 = torch.cat(
        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
        dim=-1,
    )
    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
    c2w[:, 3, 3] = 1.0

    # get directions by dividing directions_unit_focal by focal length
    focal_length = 0.5 * height / torch.tan(0.5 * fovy)
    directions_unit_focal = get_ray_directions(
        H=height,
        W=width,
        focal=1.0,
    )
    directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
    directions[:, :, :, :2] = (
        directions[:, :, :, :2] / focal_length[:, None, None, None]
    )
    # must use normalize=True to normalize directions here
    rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)

    return rays_o, rays_d


def remove_background(
    image: PIL.Image.Image,
    rembg_session: Any = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )

    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((0, 0), (0, 0), (0, 0)),
    )
    new_image = PIL.Image.fromarray(new_image)
    return new_image


def save_video(
    frames: List[PIL.Image.Image],
    output_path: str,
    fps: int = 30,
):
    # use imageio to save video
    frames = [np.array(frame) for frame in frames]
    writer = imageio.get_writer(output_path, fps=fps)
    for frame in frames:
        writer.append_data(frame)
    writer.close()


def to_gradio_3d_orientation(mesh):
    mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi / 2, [0, 1, 0]))
    return mesh


def to_standard_3d_orientation(mesh):
    """
    Convert mesh to standard 3D viewer orientation (Y-up, Z-forward).
    This is a more standard orientation that works better with most 3D viewers.
    """
    # Rotate -90 degrees around X axis (to make Y up instead of Z)
    mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
    return mesh


def apply_mesh_orientation(mesh, orientation="standard"):
    """
    Apply orientation transformation to mesh.

    Args:
        mesh: Trimesh mesh object
        orientation: Orientation type
            - "standard": Standard 3D viewer orientation (Y-up, Z-forward)
            - "gradio": Gradio 3D viewer orientation
            - "none": No transformation (original orientation)

    Returns:
        Transformed mesh
    """
    if orientation == "standard":
        return to_standard_3d_orientation(mesh)
    elif orientation == "gradio":
        return to_gradio_3d_orientation(mesh)
    elif orientation == "none":
        return mesh
    else:
        raise ValueError(f"Unknown orientation: {orientation}. Must be 'standard', 'gradio', or 'none'")
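For reference, a small self-contained sketch of how the camera and batching helpers above fit together; all numeric values are illustrative (the 1.9 distance and 40-degree fov mirror the render defaults in tsr/system.py, the rest are arbitrary).

# Sketch only: exercise get_spherical_cameras, scale_tensor, and chunk_batch.
import torch

from tsr.utils import chunk_batch, get_spherical_cameras, scale_tensor

# Eight cameras on a ring at 0 degrees elevation, 1.9 units from the origin.
rays_o, rays_d = get_spherical_cameras(
    n_views=8, elevation_deg=0.0, camera_distance=1.9,
    fovy_deg=40.0, height=64, width=64,
)
print(rays_o.shape, rays_d.shape)  # torch.Size([8, 64, 64, 3]) for both

# scale_tensor remaps values from one range to another, e.g. [0, 1] -> [-0.87, 0.87].
grid = torch.rand(1024, 3)
grid_world = scale_tensor(grid, (0, 1), (-0.87, 0.87))

# chunk_batch evaluates a function over the leading dimension in pieces,
# which keeps memory bounded when querying large point grids.
densities = chunk_batch(lambda x: (x ** 2).sum(-1, keepdim=True), 256, grid_world)
print(densities.shape)  # torch.Size([1024, 1])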