Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import shlex | |
| import subprocess | |
| from pathlib import Path | |
| from tempfile import TemporaryDirectory | |
| from textwrap import dedent | |
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| from transformers import CLIPTokenizer | |
def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepts ``"#rrggbb"`` / ``"rrggbb"`` (the format returned by
    ``st.color_picker``) and also the CSS shorthand ``"#rgb"`` / ``"rgb"``,
    which is expanded by doubling each digit (``"#0f0"`` -> ``"#00ff00"``).
    Hex digits are case-insensitive.

    Raises:
        ValueError: if *s* is not a 3- or 6-digit hexadecimal color.
    """
    value = s.lstrip("#")
    if len(value) == 3:
        value = "".join(ch * 2 for ch in value)
    # Validate explicitly: int(..., 16) alone tolerates things like "+1" or " 1"
    # inside a slice and gives confusing errors for short strings.
    if len(value) != 6 or not set(value.lower()) <= set("0123456789abcdef"):
        raise ValueError(f"Invalid hex color: {s!r}")
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
st.header("Color Textual Inversion")

# Collapsible panel whose body is the Markdown text stored in info.txt.
with st.expander(label="info"):
    st.markdown(Path("info.txt").read_text(encoding="utf-8"))

# "Duplicate Space" badge linking to the Space's duplicate page; it is raw HTML,
# so it must be rendered with unsafe_allow_html.
duplicate_button = """<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Bingsu/color_textual_inversion?duplicate=true"><img style="margin: 0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>"""
st.markdown(duplicate_button, unsafe_allow_html=True)
# Color selection: a picker in the narrow column, and a read-only text box
# echoing the picked hex value in the wide one.
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)

# The embedding (placeholder token) name defaults to the hex code without "#".
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
init_token = st.text_input("Initializer token", "init token name")

# An empty name cannot serve as a placeholder token; stop before creating any
# scratch directories.
if not emb_name.strip():
    st.warning("Embedding name must not be empty")
    st.stop()

# A 128x128 RGB image filled with the chosen color: the single training image.
rgb = hex_to_rgb(color)
img_array = np.full((128, 128, 3), rgb, dtype=np.uint8)

# Scratch directories (under the working directory) for the training data and
# for the training script's output.
dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
dataset_path = Path(dataset_temp.name)
output_temp = TemporaryDirectory(prefix="output_", dir=".")
output_path = Path(output_temp.name)

# Use a fixed file name: emb_name is free user input, and putting it into the
# path allowed writing outside the dataset directory (e.g. "../../x").
img_path = dataset_path / "color.png"
Image.fromarray(img_array).save(img_path)
# Training hyper-parameters, shown in the sidebar.
with st.sidebar:
    model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
    steps = st.slider("Steps", 1, 100, value=1, step=1)
    learning_rate_text = st.text_input("Learning rate", "0.001")

    # The learning rate is free text: report bad input instead of letting
    # float() raise an uncaught ValueError that crashes the page.
    try:
        learning_rate = float(learning_rate_text)
    except ValueError:
        st.error(f"Learning rate must be a number, got {learning_rate_text!r}")
        st.stop()
    if not np.isfinite(learning_rate) or learning_rate <= 0:
        st.error("Learning rate must be a positive, finite number")
        st.stop()
# Load only the model's tokenizer; it is used to validate the tokens before
# launching the training subprocess.
try:
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
except OSError as err:
    st.error(f"Could not load the tokenizer of {model_name!r}: {err}")
    st.stop()

# Case 1: the initializer must tokenize to exactly one token. An empty string
# tokenizes to [] and slipped past the previous `> 1` check.
token = tokenizer.tokenize(init_token)
if len(token) != 1:
    st.warning("Initializer token must be a single token")
    st.stop()

# Case 2: the embedding name (placeholder token) must not already exist in the
# tokenizer; add_tokens returns 0 when nothing new was added.
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
# argv for running textual_inversion.py via `accelerate launch`. It is built as
# a list (no shell and no shlex re-parsing) so that user-supplied values such
# as the model name, embedding name or initializer token are always passed as
# single arguments. They cannot break quoting or inject extra options.
cmd = [
    "accelerate",
    "launch",
    "textual_inversion.py",
    f"--pretrained_model_name_or_path={model_name}",
    f"--train_data_dir={dataset_path.as_posix()}",
    "--learnable_property=style",
    f"--placeholder_token={emb_name}",
    f"--initializer_token={init_token}",
    "--resolution=128",
    "--train_batch_size=1",
    "--repeats=1",
    "--gradient_accumulation_steps=1",
    f"--max_train_steps={steps}",
    f"--learning_rate={learning_rate}",
    f"--output_dir={output_path.as_posix()}",
    "--only_save_embeds",
]

# File where the training script is expected to leave the learned embedding.
result_path = output_path / "learned_embeds.bin"
captured = ""  # accumulated training log shown to the user
start_button = st.button("Start")
# Slot right below the Start button, filled with a download button once an
# embedding file exists.
download_button = st.empty()

# finally: remove the scratch directories on every path, including st.stop()
# (which raises) and unexpected errors, not only after a successful download.
try:
    if start_button:
        with st.spinner("Training..."):
            placeholder = st.empty()
            # stderr is merged into stdout and that single pipe is drained.
            # Piping both streams but reading only stderr could let stdout
            # fill its pipe buffer and hang the child process.
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                encoding="utf-8",
                errors="replace",
            ) as p:
                while line := p.stdout.readline():
                    captured += line
                    placeholder.code(captured, language="bash")
            # Leaving the `with` waits for the process, so returncode is set.
            if p.returncode != 0:
                st.error(f"Training exited with code {p.returncode}")

    if not result_path.exists():
        st.stop()

    # fix unknown file volume bug: presumably the saved tensors are views into
    # a larger storage, and torch.save serializes the whole underlying storage.
    # detach().clone() gives each tensor its own compact storage. It also works
    # for dtypes/tensors that .numpy() rejects (e.g. bfloat16, requires_grad).
    trained_emb = torch.load(result_path, map_location="cpu")
    trained_emb = {k: v.detach().clone() for k, v in trained_emb.items()}
    torch.save(trained_emb, result_path)

    file = result_path.read_bytes()
    download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
    st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")
finally:
    dataset_temp.cleanup()
    output_temp.cleanup()