telcom commited on
Commit
10ac12a
·
verified ·
1 Parent(s): 62ff71a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -57
app.py CHANGED
@@ -10,22 +10,11 @@ import torch
10
  from diffusers import (
11
  StableDiffusionPipeline,
12
  StableDiffusionImg2ImgPipeline,
13
- StableDiffusionXLPipeline,
14
- StableDiffusionXLImg2ImgPipeline,
15
  EulerAncestralDiscreteScheduler,
16
  )
 
17
  from huggingface_hub import login
18
 
19
- # ============================================================
20
- # Optional GPU decorator (Spaces)
21
- # ============================================================
22
- try:
23
- import spaces
24
- GPU_DECORATOR = spaces.GPU
25
- except Exception:
26
- def GPU_DECORATOR(fn):
27
- return fn
28
-
29
  # ============================================================
30
  # Config
31
  # ============================================================
@@ -41,42 +30,38 @@ device = torch.device("cuda" if cuda_available else "cpu")
41
  dtype = torch.float16 if cuda_available else torch.float32
42
 
43
  MAX_SEED = np.iinfo(np.int32).max
44
- MAX_IMAGE_SIZE = 1216 if cuda_available else 768
45
 
46
  pipe_txt2img = None
47
  pipe_img2img = None
48
- is_sdxl = False
49
  model_loaded = False
50
  load_error = None
51
 
52
  # ============================================================
53
- # Load model (AUTO detect SDXL vs SD)
54
  # ============================================================
55
  try:
56
- from_pretrained_kwargs = dict(
57
- torch_dtype=dtype,
58
  revision=REVISION,
59
- )
60
-
61
- if HF_TOKEN:
62
- from_pretrained_kwargs["token"] = HF_TOKEN
63
 
64
- # Try SDXL first
65
- try:
66
- pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
67
- MODEL_ID, **from_pretrained_kwargs
68
- )
69
- is_sdxl = True
70
- except Exception:
71
- pipe_txt2img = StableDiffusionPipeline.from_pretrained(
72
- MODEL_ID, **from_pretrained_kwargs
73
- )
74
- is_sdxl = False
75
 
 
76
  pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
77
  pipe_txt2img.scheduler.config
78
  )
79
- pipe_txt2img = pipe_txt2img.to(device)
80
 
81
  # Memory optimisations
82
  try:
@@ -92,16 +77,15 @@ try:
92
 
93
  pipe_txt2img.set_progress_bar_config(disable=True)
94
 
95
- # Create img2img pipeline
96
- if is_sdxl:
97
- pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
98
- else:
99
- pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_txt2img.components)
100
-
101
  pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
102
  pipe_img2img.scheduler.config
103
  )
104
- pipe_img2img = pipe_img2img.to(device)
 
 
 
105
 
106
  model_loaded = True
107
 
@@ -112,13 +96,12 @@ except Exception as e:
112
  # ============================================================
113
  # Helpers
114
  # ============================================================
115
- def _make_error_image(w, h, text):
116
  return Image.new("RGB", (w, h), (30, 30, 40))
117
 
118
  # ============================================================
119
  # Inference
120
  # ============================================================
121
- @GPU_DECORATOR
122
  def infer(
123
  prompt,
124
  negative_prompt,
@@ -135,19 +118,13 @@ def infer(
135
  height = int(height)
136
 
137
  if not model_loaded:
138
- return _make_error_image(width, height, "Model not loaded"), load_error
139
 
140
  if randomize_seed:
141
  seed = random.randint(0, MAX_SEED)
142
 
143
  generator = torch.Generator(device=device).manual_seed(seed)
144
 
145
- common_kwargs = dict(
146
- guidance_scale=float(guidance_scale),
147
- num_inference_steps=int(num_inference_steps),
148
- generator=generator,
149
- )
150
-
151
  try:
152
  with torch.inference_mode():
153
  if init_image is not None:
@@ -156,7 +133,9 @@ def infer(
156
  negative_prompt=negative_prompt,
157
  image=init_image,
158
  strength=float(strength),
159
- **common_kwargs,
 
 
160
  ).images[0]
161
  else:
162
  image = pipe_txt2img(
@@ -164,13 +143,15 @@ def infer(
164
  negative_prompt=negative_prompt,
165
  width=width,
166
  height=height,
167
- **common_kwargs,
 
 
168
  ).images[0]
169
 
170
- return image, f"Seed: {seed} | {'SDXL' if is_sdxl else 'SD 1.x'}"
171
 
172
  except Exception as e:
173
- return _make_error_image(width, height, "Generation error"), str(e)
174
 
175
  finally:
176
  gc.collect()
@@ -180,8 +161,7 @@ def infer(
180
  # ============================================================
181
  # UI
182
  # ============================================================
183
- with gr.Blocks(title="Text-to-Image / Image-to-Image") as demo:
184
-
185
  gr.Markdown("## Stable Diffusion Generator")
186
 
187
  if not model_loaded:
@@ -196,11 +176,11 @@ with gr.Blocks(title="Text-to-Image / Image-to-Image") as demo:
196
 
197
  with gr.Accordion("Advanced Settings", open=False):
198
  negative_prompt = gr.Textbox(label="Negative prompt", value="")
199
- seed = gr.Slider(0, MAX_SEED, value=0, step=1, label="Seed")
200
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
201
  width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width")
202
  height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height")
203
- guidance_scale = gr.Slider(0, 20, step=0.1, value=7.5, label="Guidance scale")
204
  num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps")
205
  strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength")
206
 
 
10
  from diffusers import (
11
  StableDiffusionPipeline,
12
  StableDiffusionImg2ImgPipeline,
 
 
13
  EulerAncestralDiscreteScheduler,
14
  )
15
+ from transformers import CLIPTokenizer, CLIPTextModel
16
  from huggingface_hub import login
17
 
 
 
 
 
 
 
 
 
 
 
18
  # ============================================================
19
  # Config
20
  # ============================================================
 
30
  dtype = torch.float16 if cuda_available else torch.float32
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
33
+ MAX_IMAGE_SIZE = 768 if not cuda_available else 1024
34
 
35
  pipe_txt2img = None
36
  pipe_img2img = None
 
37
  model_loaded = False
38
  load_error = None
39
 
40
  # ============================================================
41
+ # Load model (FORCED tokenizer fix)
42
  # ============================================================
43
  try:
44
+ pipe_txt2img = StableDiffusionPipeline.from_pretrained(
45
+ MODEL_ID,
46
  revision=REVISION,
47
+ torch_dtype=dtype,
48
+ safety_checker=None,
49
+ ).to(device)
 
50
 
51
+ # 🔑 FORCE tokenizer + text encoder
52
+ pipe_txt2img.tokenizer = CLIPTokenizer.from_pretrained(
53
+ MODEL_ID, subfolder="tokenizer"
54
+ )
55
+ pipe_txt2img.text_encoder = CLIPTextModel.from_pretrained(
56
+ MODEL_ID,
57
+ subfolder="text_encoder",
58
+ torch_dtype=dtype,
59
+ ).to(device)
 
 
60
 
61
+ # Scheduler
62
  pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
63
  pipe_txt2img.scheduler.config
64
  )
 
65
 
66
  # Memory optimisations
67
  try:
 
77
 
78
  pipe_txt2img.set_progress_bar_config(disable=True)
79
 
80
+ # Img2Img pipeline (share components)
81
+ pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_txt2img.components).to(device)
 
 
 
 
82
  pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
83
  pipe_img2img.scheduler.config
84
  )
85
+
86
+ # Defensive checks
87
+ assert pipe_txt2img.tokenizer is not None
88
+ assert pipe_txt2img.text_encoder is not None
89
 
90
  model_loaded = True
91
 
 
96
  # ============================================================
97
  # Helpers
98
  # ============================================================
99
+ def _make_error_image(w, h):
100
  return Image.new("RGB", (w, h), (30, 30, 40))
101
 
102
  # ============================================================
103
  # Inference
104
  # ============================================================
 
105
  def infer(
106
  prompt,
107
  negative_prompt,
 
118
  height = int(height)
119
 
120
  if not model_loaded:
121
+ return _make_error_image(width, height), load_error
122
 
123
  if randomize_seed:
124
  seed = random.randint(0, MAX_SEED)
125
 
126
  generator = torch.Generator(device=device).manual_seed(seed)
127
 
 
 
 
 
 
 
128
  try:
129
  with torch.inference_mode():
130
  if init_image is not None:
 
133
  negative_prompt=negative_prompt,
134
  image=init_image,
135
  strength=float(strength),
136
+ guidance_scale=float(guidance_scale),
137
+ num_inference_steps=int(num_inference_steps),
138
+ generator=generator,
139
  ).images[0]
140
  else:
141
  image = pipe_txt2img(
 
143
  negative_prompt=negative_prompt,
144
  width=width,
145
  height=height,
146
+ guidance_scale=float(guidance_scale),
147
+ num_inference_steps=int(num_inference_steps),
148
+ generator=generator,
149
  ).images[0]
150
 
151
+ return image, f"Seed: {seed}"
152
 
153
  except Exception as e:
154
+ return _make_error_image(width, height), str(e)
155
 
156
  finally:
157
  gc.collect()
 
161
  # ============================================================
162
  # UI
163
  # ============================================================
164
+ with gr.Blocks(title="Stable Diffusion (Unlearning Model)") as demo:
 
165
  gr.Markdown("## Stable Diffusion Generator")
166
 
167
  if not model_loaded:
 
176
 
177
  with gr.Accordion("Advanced Settings", open=False):
178
  negative_prompt = gr.Textbox(label="Negative prompt", value="")
179
+ seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
180
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
181
  width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width")
182
  height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height")
183
+ guidance_scale = gr.Slider(1, 20, step=0.5, value=7.5, label="Guidance scale")
184
  num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps")
185
  strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength")
186