GitHub Repository: shivamshrirao/diffusers
Path: blob/main/examples/test_examples.py
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from typing import List

from accelerate.utils import write_basic_config

from diffusers import DiffusionPipeline, UNet2DConditionModel
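
# Note: this suite is usually invoked with pytest from the repository root,
# e.g. `python -m pytest examples/test_examples.py` (the exact invocation may
# vary with the environment).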


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


# These utils relate to ensuring the right error message is received when running scripts
class SubprocessCallException(Exception):
    pass


def run_command(command: List[str], return_stdout=False):
    """
    Runs `command` with `subprocess.check_output` and optionally returns the decoded `stdout`. Also properly
    captures the output if an error occurred while running `command`.
    """
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        if return_stdout:
            if hasattr(output, "decode"):
                output = output.decode("utf-8")
            return output
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e
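

# Illustrative use of the helper above (assumed invocation, not part of the suite):
#   out = run_command([sys.executable, "--version"], return_stdout=True)
# A failing command raises SubprocessCallException carrying the captured output.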


# Echo log records to stdout so they appear in captured test output.
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

class ExamplesTestsAccelerate(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Write a default accelerate config once and reuse it for every test launch.
        cls._tmpdir = tempfile.mkdtemp()
        cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")

        write_basic_config(save_location=cls.configPath)
        cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        shutil.rmtree(cls._tmpdir)

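    # Trains a dummy DDPM for one epoch on a tiny dataset and checks that the
    # pipeline components were saved to the output directory.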
    def test_train_unconditional(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/unconditional_image_generation/train_unconditional.py
                --dataset_name hf-internal-testing/dummy_image_class_data
                --model_config_name_or_path diffusers/ddpm_dummy
                --resolution 64
                --output_dir {tmpdir}
                --train_batch_size 2
                --num_epochs 1
                --gradient_accumulation_steps 1
                --ddpm_num_inference_steps 2
                --learning_rate 1e-3
                --lr_warmup_steps 5
                """.split()

            run_command(self._launch_args + test_args, return_stdout=True)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

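    # Runs two textual-inversion steps on a tiny Stable Diffusion checkpoint and
    # checks that the learned placeholder-token embedding was saved.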
    def test_textual_inversion(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/textual_inversion/textual_inversion.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --train_data_dir docs/source/en/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
                --initializer_token a
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))

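    # Runs two DreamBooth fine-tuning steps and checks that the trained UNet and
    # the scheduler config were saved.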
    def test_dreambooth(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

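    # Exercises DreamBooth checkpointing end to end: train with intermediate
    # checkpoints, load one, then resume and verify which checkpoints remain.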
    def test_dreambooth_checkpointing(self):
        instance_prompt = "photo"
        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"

        with tempfile.TemporaryDirectory() as tmpdir:
            # Run training script with checkpointing
            # max_train_steps == 5, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4

            initial_run_args = f"""
                examples/dreambooth/train_dreambooth.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --instance_data_dir docs/source/en/imgs
                --instance_prompt {instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 5
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            # check can run the original fully trained output pipeline
            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
            pipe(instance_prompt, num_inference_steps=2)

            # check checkpoint directories exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

            # check can run an intermediate checkpoint
            unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
            pipe(instance_prompt, num_inference_steps=2)

            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Run training script for 7 total steps resuming from checkpoint 4

            resume_run_args = f"""
                examples/dreambooth/train_dreambooth.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --instance_data_dir docs/source/en/imgs
                --instance_prompt {instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            # check can run new fully trained pipeline
            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
            pipe(instance_prompt, num_inference_steps=2)

            # check old checkpoints do not exist
            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

            # check new checkpoints exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))

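    # Runs two text-to-image fine-tuning steps and checks that the pipeline
    # components were saved.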
    def test_text_to_image(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/text_to_image/train_text_to_image.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

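    # Same checkpoint/resume flow as the DreamBooth test above, but for the
    # text-to-image fine-tuning script.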
    def test_text_to_image_checkpointing(self):
        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
        prompt = "a prompt"

        with tempfile.TemporaryDirectory() as tmpdir:
            # Run training script with checkpointing
            # max_train_steps == 5, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4

            initial_run_args = f"""
                examples/text_to_image/train_text_to_image.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 5
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
            pipe(prompt, num_inference_steps=2)

            # check checkpoint directories exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

            # check can run an intermediate checkpoint
            unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
            pipe(prompt, num_inference_steps=2)

            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Run training script for 7 total steps resuming from checkpoint 4

            resume_run_args = f"""
                examples/text_to_image/train_text_to_image.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            # check can run new fully trained pipeline
            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
            pipe(prompt, num_inference_steps=2)

            # check old checkpoints do not exist
            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

            # check new checkpoints exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))

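    # Repeats the checkpoint/resume flow with --use_ema enabled, additionally
    # exercising the EMA weight handling across save and resume.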
    def test_text_to_image_checkpointing_use_ema(self):
        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
        prompt = "a prompt"

        with tempfile.TemporaryDirectory() as tmpdir:
            # Run training script with checkpointing
            # max_train_steps == 5, checkpointing_steps == 2
            # Should create checkpoints at steps 2, 4

            initial_run_args = f"""
                examples/text_to_image/train_text_to_image.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 5
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --use_ema
                --seed=0
                """.split()

            run_command(self._launch_args + initial_run_args)

            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
            pipe(prompt, num_inference_steps=2)

            # check checkpoint directories exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

            # check can run an intermediate checkpoint
            unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
            pipe(prompt, num_inference_steps=2)

            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Run training script for 7 total steps resuming from checkpoint 4

            resume_run_args = f"""
                examples/text_to_image/train_text_to_image.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --use_ema
                --seed=0
                """.split()

            run_command(self._launch_args + resume_run_args)

            # check can run new fully trained pipeline
            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
            pipe(prompt, num_inference_steps=2)

            # check old checkpoints do not exist
            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

            # check new checkpoints exist
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))