darsoarafa commited on
Commit
845b21b
·
verified ·
1 Parent(s): c6dc8d8

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +1 -14
  2. app.py +253 -0
  3. banner.html +42 -0
  4. footer.html +47 -0
  5. requirements.txt +4 -0
  6. tips.html +26 -0
README.md CHANGED
@@ -1,14 +1 @@
1
- ---
2
- title: Ammar
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- license: bsd
11
- short_description: Fitting Room
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # Ammar Virtual Try-On App
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import logging
3
+ import os
4
+ import time
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import requests
10
+ from gradio.themes.utils import sizes
11
+
12
+ # LOGGING
13
+ logger = logging.getLogger("TRYON")
14
+ logger.setLevel(logging.INFO)
15
+ handler = logging.StreamHandler()
16
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
17
+ handler.setFormatter(formatter)
18
+ logger.addHandler(handler)
19
+
20
+ # IMAGE ASSETS
21
+ ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
22
+
23
+ # API CONFIG
24
+ #FASHN_ENDPOINT_URL = os.environ.get("FASHN_ENDPOINT_URL", "https://api.fashn.ai/v1")
25
+ FASHN_ENDPOINT_URL = "https://api.fashn.ai/v1"
26
+ #FASHN_API_KEY = os.environ.get("FASHN_API_KEY")
27
+ FASHN_API_KEY = "fa-bXvHG3Z8zBBM-cUJuLvRFrFi00BD35ZIis5t7"
28
+ assert FASHN_ENDPOINT_URL, "Please set the FASHN_ENDPOINT_URL environment variable"
29
+ assert FASHN_API_KEY, "Please set the FASHN_API_KEY environment variable"
30
+
31
+ # ----------------- HELPER FUNCTIONS ----------------- #
32
+
33
+ CATEGORY_API_MAPPING = {"Top": "tops", "Bottom": "bottoms", "Full-body": "one-pieces"}
34
+
35
+
36
+ def opencv_load_image_from_http(url: str) -> np.ndarray:
37
+ """Loads an image from a given URL using HTTP GET."""
38
+ with requests.get(url) as response:
39
+ response.raise_for_status()
40
+ image_data = np.frombuffer(response.content, np.uint8)
41
+ image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
42
+ return image
43
+
44
+
45
+ def encode_img_to_base64(img: np.array) -> str:
46
+ """Encodes an image as a JPEG in Base64 format."""
47
+ img = cv2.imencode(".jpg", img)[1].tobytes()
48
+ img = base64.b64encode(img).decode("utf-8")
49
+ img = f"data:image/jpeg;base64,{img}"
50
+ return img
51
+
52
+
53
+ def parse_checkboxes(checkboxes):
54
+ checkboxes = [checkbox.lower().replace(" ", "_") for checkbox in checkboxes]
55
+ checkboxes = {checkbox: True for checkbox in checkboxes}
56
+ return checkboxes
57
+
58
+
59
+ def make_api_request(session, url, headers, data=None, method="GET", max_retries=3, timeout=60):
60
+ for attempt in range(max_retries):
61
+ try:
62
+ if method.upper() == "GET":
63
+ response = session.get(url, headers=headers, timeout=timeout)
64
+ elif method.upper() == "POST":
65
+ response = session.post(url, headers=headers, json=data, timeout=timeout)
66
+ else:
67
+ raise ValueError(f"Unsupported HTTP method: {method}")
68
+
69
+ response.raise_for_status()
70
+ return response.json()
71
+ except requests.exceptions.RequestException as e:
72
+ if attempt == max_retries - 1: # If it's the last attempt
73
+ raise Exception(f"API call failed after {max_retries} attempts: {str(e)}") from e
74
+ print(f"Attempt {attempt + 1} failed. Retrying...")
75
+ time.sleep(2) # Wait for 2 seconds before retrying
76
+
77
+
78
+ # ----------------- CORE FUNCTION ----------------- #
79
+
80
+
81
+ def get_tryon_result(
82
+ model_image,
83
+ garment_image,
84
+ garment_photo_type,
85
+ category,
86
+ nsfw_filter,
87
+ cover_feet,
88
+ adjust_hands,
89
+ restore_background,
90
+ restore_clothes,
91
+ guidance_scale,
92
+ timesteps,
93
+ seed,
94
+ num_samples,
95
+ ):
96
+ logger.info("Starting new try-on request...")
97
+
98
+ # preprocessing: convert to RGB, resize, encode to base64
99
+ model_image, garment_image = map(lambda x: cv2.cvtColor(x, cv2.COLOR_RGB2BGR), [model_image, garment_image])
100
+ model_image, garment_image = map(encode_img_to_base64, [model_image, garment_image])
101
+
102
+ # prepare data for API request
103
+ data = {
104
+ "model_image": model_image,
105
+ "garment_image": garment_image,
106
+ "garment_photo_type": garment_photo_type.lower(),
107
+ "category": CATEGORY_API_MAPPING[category],
108
+ "nsfw_filter": nsfw_filter,
109
+ "cover_feet": cover_feet,
110
+ "adjust_hands": adjust_hands,
111
+ "restore_background": restore_background,
112
+ "restore_clothes": restore_clothes,
113
+ "guidance_scale": guidance_scale,
114
+ "timesteps": timesteps,
115
+ "seed": seed,
116
+ "num_samples": num_samples,
117
+ }
118
+
119
+ # make API request
120
+ session = requests.Session()
121
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {FASHN_API_KEY}"}
122
+
123
+ try:
124
+ response_data = make_api_request(
125
+ session, f"{FASHN_ENDPOINT_URL}/run", headers=headers, data=data, method="POST"
126
+ )
127
+ pred_id = response_data.get("id")
128
+ logger.info(f"Prediction ID: {pred_id}")
129
+ except Exception as e:
130
+ raise gr.Error(f"Status check failed: {str(e)}")
131
+
132
+ # poll the status of the prediction
133
+ start_time = time.time()
134
+ while True:
135
+ if time.time() - start_time > 180: # 3 minutes timeout
136
+ raise gr.Error("Maximum polling time exceeded.")
137
+
138
+ try:
139
+ status_data = make_api_request(
140
+ session, f"{FASHN_ENDPOINT_URL}/status/{pred_id}", headers=headers, method="GET"
141
+ )
142
+ except Exception as e:
143
+ raise gr.Error(f"Status check failed: {str(e)}")
144
+
145
+ if status_data["status"] == "completed":
146
+ logger.info("Prediction completed.")
147
+ break
148
+ elif status_data["status"] not in ["starting", "in_queue", "processing"]:
149
+ raise gr.Error(f"Prediction failed with id {pred_id}: {status_data.get('error')}")
150
+
151
+ logger.info(f"Prediction status: {status_data['status']}")
152
+ time.sleep(3)
153
+
154
+ # get the result images
155
+ result_imgs = []
156
+ for output_url in status_data["output"]:
157
+ result_img = opencv_load_image_from_http(output_url)
158
+ result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
159
+ result_imgs.append(result_img)
160
+
161
+ return result_imgs
162
+
163
+
164
+ # ----------------- GRADIO UI ----------------- #
165
+
166
+
167
+ with open("banner.html", "r") as file:
168
+ banner = file.read()
169
+ with open("tips.html", "r") as file:
170
+ tips = file.read()
171
+ with open("footer.html", "r") as file:
172
+ footer = file.read()
173
+
174
+ CUSTOM_CSS = """
175
+ .image-container img {
176
+ max-width: 384px;
177
+ max-height: 576px;
178
+ margin: 0 auto;
179
+ border-radius: 0px;
180
+ .gradio-container {background-color: #fafafa}
181
+ """
182
+
183
+ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
184
+ gr.HTML(banner)
185
+ gr.HTML(tips)
186
+ with gr.Row():
187
+ with gr.Column():
188
+ model_image = gr.Image(label="Foto Model", type="numpy")
189
+
190
+ with gr.Accordion("Model Image Controls", open=False):
191
+ cover_feet = gr.Checkbox(label="Cover Feet", value=False)
192
+ adjust_hands = gr.Checkbox(label="Adjust Hands", value=False)
193
+ restore_background = gr.Checkbox(label="Restore Background", value=False)
194
+ restore_clothes = gr.Checkbox(label="Restore Clothes", value=False)
195
+ nsfw_filter = gr.Checkbox(label="NSFW Filter", value=True)
196
+
197
+ example_model = gr.Examples(label="Pilih model",
198
+ inputs=model_image,
199
+ examples_per_page=10,
200
+ examples=[
201
+ os.path.join(ASSETS_DIR, "models", img) for img in os.listdir(os.path.join(ASSETS_DIR, "models"))
202
+ ],
203
+ )
204
+ with gr.Column():
205
+ garment_image = gr.Image(label="Produk", type="numpy")
206
+ garment_photo_type = gr.Radio(
207
+ choices=["Auto", "Flat-Lay", "Model"], label="Select Photo Type", value="Auto"
208
+ )
209
+ category = gr.Radio(choices=["Top", "Bottom", "Full-body"], label="Select Category", value="Top")
210
+
211
+ example_garment = gr.Examples(label="Pilih produk",
212
+ inputs=garment_image,
213
+ examples_per_page=10,
214
+ examples=[
215
+ os.path.join(ASSETS_DIR, "garments", img)
216
+ for img in os.listdir(os.path.join(ASSETS_DIR, "garments"))
217
+ ],
218
+ )
219
+
220
+ with gr.Column():
221
+ result_gallery = gr.Gallery(label="Hasil", show_label=True, elem_id="gallery")
222
+ run_button = gr.Button("Coba")
223
+ with gr.Accordion("Sampling Controls", open=False):
224
+ guidance_scale = gr.Slider(minimum=1.5, maximum=3, value=2.0, step=0.1, label="Guidance Scale")
225
+ timesteps = gr.Slider(minimum=10, maximum=50, step=1, value=50, label="Timesteps")
226
+ seed = gr.Number(label="Seed", value=42, precision=0)
227
+ num_samples = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples")
228
+
229
+ run_button.click(
230
+ fn=get_tryon_result,
231
+ inputs=[
232
+ model_image,
233
+ garment_image,
234
+ garment_photo_type,
235
+ category,
236
+ nsfw_filter,
237
+ cover_feet,
238
+ adjust_hands,
239
+ restore_background,
240
+ restore_clothes,
241
+ guidance_scale,
242
+ timesteps,
243
+ seed,
244
+ num_samples,
245
+ ],
246
+ outputs=[result_gallery],
247
+ )
248
+
249
+ gr.HTML(footer)
250
+
251
+
252
+ if __name__ == "__main__":
253
+ demo.launch()
banner.html ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="
2
+ display: flex;
3
+ flex-direction: column;
4
+ justify-content: center;
5
+ align-items: center;
6
+ text-align: center;
7
+ background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
8
+ padding: 24px;
9
+ gap: 24px;
10
+ border-radius: 8px;
11
+ ">
12
+ <div style="display: flex; gap: 8px;">
13
+ <h1 style="
14
+ font-size: 48px;
15
+ color: #fafafa;
16
+ margin: 0;
17
+ font-family: Arial, Helvetica, sans-serif;
18
+ ">
19
+ Virtual Photoshoots
20
+ </h1>
21
+
22
+ </div>
23
+
24
+ <p style="
25
+ margin: 0;
26
+ line-height: 1.6rem;
27
+ font-size: 16px;
28
+ color: #fafafa;
29
+ opacity: 0.8;
30
+ ">
31
+ Virtual Fitting Room<br>
32
+ Dengan cara (1) pilih model atau upload foto diri, (2) pilih Kaos/Produk, (3) klik "Coba" <br>
33
+ </p>
34
+
35
+ <div style="
36
+ display: flex;
37
+ justify-content: center;
38
+ align-items: center;
39
+ text-align: center;
40
+ ">
41
+ </div>
42
+ </div>
footer.html ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="
2
+ display: flex;
3
+ flex-direction: column;
4
+ background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
5
+ padding: 24px;
6
+ gap: 24px;
7
+ border-radius: 8px;
8
+ align-items: center;
9
+ ">
10
+ <div style="display: flex; justify-content: center; gap: 8px;">
11
+ <h1 style="
12
+ font-size: 24px;
13
+ color: #fafafa;
14
+ margin: 0;
15
+ font-family: Arial, Helvetica, sans-serif;
16
+
17
+ ">
18
+
19
+ </h1>
20
+ </div>
21
+ <div style="max-width: 790px; text-align: center; display: flex; flex-direction: column; gap: 12px; font-family: Arial, Helvetica, sans-serif;
22
+ ">
23
+ <div>
24
+ <div style="text-align: center;">
25
+
26
+ </div>
27
+ <p style="
28
+ line-height: 1.6rem;
29
+ font-size: 16px;
30
+ color: #fafafa;
31
+ opacity: 0.8;
32
+ ">
33
+
34
+ </p>
35
+ </div>
36
+ <div>
37
+ <p style="
38
+ line-height: 1.6rem;
39
+ font-size: 16px;
40
+ color: #fafafa;
41
+ opacity: 0.8;
42
+ ">
43
+
44
+ </p>
45
+ </div>
46
+ </div>
47
+ </div>
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ numpy
3
+ requests
4
+ opencv-python
tips.html ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="
2
+ padding: 12px;
3
+ border: 1px solid #333333;
4
+ border-radius: 8px;
5
+ text-align: center;
6
+ display: flex;
7
+ flex-direction: column;
8
+ gap: 8px;
9
+ font-family: Arial, Helvetica, sans-serif;
10
+ ">
11
+ <b style="font-size: 18px;">Tips sukses "Coba Pakai"</b>
12
+
13
+ <ul style="
14
+ display: flex;
15
+ gap: 12px;
16
+ justify-content: center;
17
+ li {
18
+ margin: 0;
19
+ }
20
+ ">
21
+ <li>Rasio foto 2:3</li>
22
+ <li>Satu orang saja</li>
23
+ <li>Pose badan mirip</li>
24
+ <li>Fotonya jelas</li>
25
+ </ul>
26
+ </div>