nroggendorff committed
Commit eac149d · verified · 1 Parent(s): 4381bf7

Update train.py

Files changed (1)
  1. train.py +76 -1340
train.py CHANGED
@@ -1,1370 +1,106 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """Fine-tuning script for Stable Diffusion XL for text2image."""
17
-
18
- import argparse
19
- import functools
20
- import gc
21
- import logging
22
- import math
23
- import os
24
- import random
25
- import shutil
26
- from contextlib import nullcontext
27
- from pathlib import Path
28
-
29
- import accelerate
30
- import datasets
31
- import numpy as np
32
  import torch
33
- import torch.nn.functional as F
34
- import torch.utils.checkpoint
35
- import transformers
36
- from accelerate import Accelerator
37
- from accelerate.logging import get_logger
38
- from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
39
- from datasets import concatenate_datasets, load_dataset
40
- from huggingface_hub import create_repo, upload_folder
41
- from packaging import version
42
- from torchvision import transforms
43
- from torchvision.transforms.functional import crop
44
- from tqdm.auto import tqdm
45
- from transformers import AutoTokenizer, PretrainedConfig
46
-
47
- import diffusers
48
- from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
49
- from diffusers.optimization import get_scheduler
50
- from diffusers.training_utils import EMAModel, compute_snr
51
- from diffusers.utils import check_min_version, is_wandb_available
52
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
53
- from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
54
- from diffusers.utils.torch_utils import is_compiled_module
55
-
56
-
57
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
58
- check_min_version("0.36.0.dev0")
59
-
60
- logger = get_logger(__name__)
61
- if is_torch_npu_available():
62
- import torch_npu
63
-
64
- torch.npu.config.allow_internal_format = False
65
-
66
- DATASET_NAME_MAPPING = {
67
- "lambdalabs/naruto-blip-captions": ("image", "text"),
68
- }
69
-
70
-
71
- def save_model_card(
72
- repo_id: str,
73
- images: list = None,
74
- validation_prompt: str = None,
75
- base_model: str = None,
76
- dataset_name: str = None,
77
- repo_folder: str = None,
78
- vae_path: str = None,
79
- ):
80
- img_str = ""
81
- if images is not None:
82
- for i, image in enumerate(images):
83
- image.save(os.path.join(repo_folder, f"image_{i}.png"))
84
- img_str += f"![img_{i}](./image_{i}.png)\n"
85
-
86
- model_description = f"""
87
- # Text-to-image finetuning - {repo_id}
88
-
89
- This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
90
- {img_str}
91
-
92
- Special VAE used for training: {vae_path}.
93
- """
94
-
95
- model_card = load_or_create_model_card(
96
- repo_id_or_path=repo_id,
97
- from_training=True,
98
- license="creativeml-openrail-m",
99
- base_model=base_model,
100
- model_description=model_description,
101
- inference=True,
102
- )
103
-
104
- tags = [
105
- "stable-diffusion-xl",
106
- "stable-diffusion-xl-diffusers",
107
- "text-to-image",
108
- "diffusers-training",
109
- "diffusers",
110
- ]
111
- model_card = populate_model_card(model_card, tags=tags)
112
-
113
- model_card.save(os.path.join(repo_folder, "README.md"))
114
-
115
-
116
- def import_model_class_from_model_name_or_path(
117
- pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
118
- ):
119
- text_encoder_config = PretrainedConfig.from_pretrained(
120
- pretrained_model_name_or_path, subfolder=subfolder, revision=revision
121
- )
122
- model_class = text_encoder_config.architectures[0]
123
-
124
- if model_class == "CLIPTextModel":
125
- from transformers import CLIPTextModel
126
-
127
- return CLIPTextModel
128
- elif model_class == "CLIPTextModelWithProjection":
129
- from transformers import CLIPTextModelWithProjection
130
-
131
- return CLIPTextModelWithProjection
132
- else:
133
- raise ValueError(f"{model_class} is not supported.")
134
-
135
-
136
- def parse_args(input_args=None):
137
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
138
- parser.add_argument(
139
- "--pretrained_model_name_or_path",
140
- type=str,
141
- default=None,
142
- required=True,
143
- help="Path to pretrained model or model identifier from huggingface.co/models.",
144
- )
145
- parser.add_argument(
146
- "--pretrained_vae_model_name_or_path",
147
- type=str,
148
- default=None,
149
- help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
150
- )
151
- parser.add_argument(
152
- "--revision",
153
- type=str,
154
- default=None,
155
- required=False,
156
- help="Revision of pretrained model identifier from huggingface.co/models.",
157
- )
158
- parser.add_argument(
159
- "--variant",
160
- type=str,
161
- default=None,
162
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
163
- )
164
- parser.add_argument(
165
- "--dataset_name",
166
- type=str,
167
- default=None,
168
- help=(
169
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
170
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
171
- " or to a folder containing files that 🤗 Datasets can understand."
172
- ),
173
- )
174
- parser.add_argument(
175
- "--dataset_config_name",
176
- type=str,
177
- default=None,
178
- help="The config of the Dataset, leave as None if there's only one config.",
179
- )
180
- parser.add_argument(
181
- "--train_data_dir",
182
- type=str,
183
- default=None,
184
- help=(
185
- "A folder containing the training data. Folder contents must follow the structure described in"
186
- " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
187
- " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
188
- ),
189
- )
190
- parser.add_argument(
191
- "--image_column", type=str, default="image", help="The column of the dataset containing an image."
192
- )
193
- parser.add_argument(
194
- "--caption_column",
195
- type=str,
196
- default="text",
197
- help="The column of the dataset containing a caption or a list of captions.",
198
- )
199
- parser.add_argument(
200
- "--validation_prompt",
201
- type=str,
202
- default=None,
203
- help="A prompt that is used during validation to verify that the model is learning.",
204
- )
205
- parser.add_argument(
206
- "--num_validation_images",
207
- type=int,
208
- default=4,
209
- help="Number of images that should be generated during validation with `validation_prompt`.",
210
- )
211
- parser.add_argument(
212
- "--validation_epochs",
213
- type=int,
214
- default=1,
215
- help=(
216
- "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
217
- " `args.validation_prompt` multiple times: `args.num_validation_images`."
218
- ),
219
- )
220
- parser.add_argument(
221
- "--max_train_samples",
222
- type=int,
223
- default=None,
224
- help=(
225
- "For debugging purposes or quicker training, truncate the number of training examples to this "
226
- "value if set."
227
- ),
228
- )
229
- parser.add_argument(
230
- "--proportion_empty_prompts",
231
- type=float,
232
- default=0,
233
- help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
234
- )
235
- parser.add_argument(
236
- "--output_dir",
237
- type=str,
238
- default="sdxl-model-finetuned",
239
- help="The output directory where the model predictions and checkpoints will be written.",
240
- )
241
- parser.add_argument(
242
- "--cache_dir",
243
- type=str,
244
- default=None,
245
- help="The directory where the downloaded models and datasets will be stored.",
246
- )
247
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
248
- parser.add_argument(
249
- "--resolution",
250
- type=int,
251
- default=1024,
252
- help=(
253
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
254
- " resolution"
255
- ),
256
- )
257
- parser.add_argument(
258
- "--center_crop",
259
- default=False,
260
- action="store_true",
261
- help=(
262
- "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
263
- " cropped. The images will be resized to the resolution first before cropping."
264
- ),
265
- )
266
- parser.add_argument(
267
- "--random_flip",
268
- action="store_true",
269
- help="whether to randomly flip images horizontally",
270
- )
271
- parser.add_argument(
272
- "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
273
- )
274
- parser.add_argument("--num_train_epochs", type=int, default=100)
275
- parser.add_argument(
276
- "--max_train_steps",
277
- type=int,
278
- default=None,
279
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
280
- )
281
- parser.add_argument(
282
- "--checkpointing_steps",
283
- type=int,
284
- default=500,
285
- help=(
286
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
287
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
288
- " training using `--resume_from_checkpoint`."
289
- ),
290
- )
291
- parser.add_argument(
292
- "--checkpoints_total_limit",
293
- type=int,
294
- default=None,
295
- help=("Max number of checkpoints to store."),
296
- )
297
- parser.add_argument(
298
- "--resume_from_checkpoint",
299
- type=str,
300
- default=None,
301
- help=(
302
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
303
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
304
- ),
305
- )
306
- parser.add_argument(
307
- "--gradient_accumulation_steps",
308
- type=int,
309
- default=1,
310
- help="Number of updates steps to accumulate before performing a backward/update pass.",
311
- )
312
- parser.add_argument(
313
- "--gradient_checkpointing",
314
- action="store_true",
315
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
316
- )
317
- parser.add_argument(
318
- "--learning_rate",
319
- type=float,
320
- default=1e-4,
321
- help="Initial learning rate (after the potential warmup period) to use.",
322
- )
323
- parser.add_argument(
324
- "--scale_lr",
325
- action="store_true",
326
- default=False,
327
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
328
- )
329
- parser.add_argument(
330
- "--lr_scheduler",
331
- type=str,
332
- default="constant",
333
- help=(
334
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
335
- ' "constant", "constant_with_warmup"]'
336
- ),
337
- )
338
- parser.add_argument(
339
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
340
- )
341
- parser.add_argument(
342
- "--timestep_bias_strategy",
343
- type=str,
344
- default="none",
345
- choices=["earlier", "later", "range", "none"],
346
- help=(
347
- "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
348
- " Choices: ['earlier', 'later', 'range', 'none']."
349
- " The default is 'none', which means no bias is applied, and training proceeds normally."
350
- " The value of 'later' will increase the frequency of the model's final training timesteps."
351
- ),
352
- )
353
- parser.add_argument(
354
- "--timestep_bias_multiplier",
355
- type=float,
356
- default=1.0,
357
- help=(
358
- "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
359
- " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
360
- ),
361
- )
362
- parser.add_argument(
363
- "--timestep_bias_begin",
364
- type=int,
365
- default=0,
366
- help=(
367
- "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
368
- " Defaults to zero, which equates to having no specific bias."
369
- ),
370
- )
371
- parser.add_argument(
372
- "--timestep_bias_end",
373
- type=int,
374
- default=1000,
375
- help=(
376
- "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
377
- " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
378
- ),
379
- )
380
- parser.add_argument(
381
- "--timestep_bias_portion",
382
- type=float,
383
- default=0.25,
384
- help=(
385
- "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
386
- " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
387
- " whether the biased portions are in the earlier or later timesteps."
388
- ),
389
- )
390
- parser.add_argument(
391
- "--snr_gamma",
392
- type=float,
393
- default=None,
394
- help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
395
- "More details here: https://huggingface.co/papers/2303.09556.",
396
- )
397
- parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
398
- parser.add_argument(
399
- "--allow_tf32",
400
- action="store_true",
401
- help=(
402
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
403
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
404
- ),
405
- )
406
- parser.add_argument(
407
- "--dataloader_num_workers",
408
- type=int,
409
- default=0,
410
- help=(
411
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
412
- ),
413
- )
414
- parser.add_argument(
415
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
416
- )
417
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
418
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
419
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
420
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
421
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
422
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
423
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
424
- parser.add_argument(
425
- "--prediction_type",
426
- type=str,
427
- default=None,
428
- help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
429
- )
430
- parser.add_argument(
431
- "--hub_model_id",
432
- type=str,
433
- default=None,
434
- help="The name of the repository to keep in sync with the local `output_dir`.",
435
- )
436
- parser.add_argument(
437
- "--logging_dir",
438
- type=str,
439
- default="logs",
440
- help=(
441
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
442
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
443
- ),
444
- )
445
- parser.add_argument(
446
- "--report_to",
447
- type=str,
448
- default="tensorboard",
449
- help=(
450
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
451
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
452
- ),
453
- )
454
- parser.add_argument(
455
- "--mixed_precision",
456
- type=str,
457
- default=None,
458
- choices=["no", "fp16", "bf16"],
459
- help=(
460
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
461
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
462
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
463
- ),
464
- )
465
- parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
466
- parser.add_argument(
467
- "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
468
- )
469
- parser.add_argument(
470
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
471
- )
472
- parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
473
- parser.add_argument(
474
- "--image_interpolation_mode",
475
- type=str,
476
- default="lanczos",
477
- choices=[
478
- f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
479
- ],
480
- help="The image interpolation method to use for resizing images.",
481
- )
482
-
483
- if input_args is not None:
484
- args = parser.parse_args(input_args)
485
- else:
486
- args = parser.parse_args()
487
-
488
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
489
- if env_local_rank != -1 and env_local_rank != args.local_rank:
490
- args.local_rank = env_local_rank
491
-
492
- # Sanity checks
493
- if args.dataset_name is None and args.train_data_dir is None:
494
- raise ValueError("Need either a dataset name or a training folder.")
495
- if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
496
- raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
497
-
498
- return args
499
-
500
-
501
- # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
502
- def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
503
- prompt_embeds_list = []
504
- prompt_batch = batch[caption_column]
505
-
506
- captions = []
507
- for caption in prompt_batch:
508
- if random.random() < proportion_empty_prompts:
509
- captions.append("")
510
- elif isinstance(caption, str):
511
- captions.append(caption)
512
- elif isinstance(caption, (list, np.ndarray)):
513
- # take a random caption if there are multiple
514
- captions.append(random.choice(caption) if is_train else caption[0])
515
-
516
- with torch.no_grad():
517
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
518
- text_inputs = tokenizer(
519
- captions,
520
- padding="max_length",
521
- max_length=tokenizer.model_max_length,
522
- truncation=True,
523
- return_tensors="pt",
524
- )
525
- text_input_ids = text_inputs.input_ids
526
- prompt_embeds = text_encoder(
527
- text_input_ids.to(text_encoder.device),
528
- output_hidden_states=True,
529
- return_dict=False,
530
- )
531
-
532
- # We are only ALWAYS interested in the pooled output of the final text encoder
533
- pooled_prompt_embeds = prompt_embeds[0]
534
- prompt_embeds = prompt_embeds[-1][-2]
535
- bs_embed, seq_len, _ = prompt_embeds.shape
536
- prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
537
- prompt_embeds_list.append(prompt_embeds)
538
-
539
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
540
- pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
541
- return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
542
-
543
-
544
- def compute_vae_encodings(batch, vae):
545
- images = batch.pop("pixel_values")
546
- pixel_values = torch.stack(list(images))
547
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
548
- pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
549
-
550
- with torch.no_grad():
551
- model_input = vae.encode(pixel_values).latent_dist.sample()
552
- model_input = model_input * vae.config.scaling_factor
553
-
554
- # There might have slightly performance improvement
555
- # by changing model_input.cpu() to accelerator.gather(model_input)
556
- return {"model_input": model_input.cpu()}
557
-
558
-
559
- def generate_timestep_weights(args, num_timesteps):
560
- weights = torch.ones(num_timesteps)
561
-
562
- # Determine the indices to bias
563
- num_to_bias = int(args.timestep_bias_portion * num_timesteps)
564
-
565
- if args.timestep_bias_strategy == "later":
566
- bias_indices = slice(-num_to_bias, None)
567
- elif args.timestep_bias_strategy == "earlier":
568
- bias_indices = slice(0, num_to_bias)
569
- elif args.timestep_bias_strategy == "range":
570
- # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500.
571
- range_begin = args.timestep_bias_begin
572
- range_end = args.timestep_bias_end
573
- if range_begin < 0:
574
- raise ValueError(
575
- "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
576
- )
577
- if range_end > num_timesteps:
578
- raise ValueError(
579
- "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
580
- )
581
- bias_indices = slice(range_begin, range_end)
582
- else: # 'none' or any other string
583
- return weights
584
- if args.timestep_bias_multiplier <= 0:
585
- return ValueError(
586
- "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
587
- " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
588
- " A timestep bias multiplier less than or equal to 0 is not allowed."
589
- )
590
-
591
- # Apply the bias
592
- weights[bias_indices] *= args.timestep_bias_multiplier
593
 
594
- # Normalize
595
- weights /= weights.sum()
596
 
597
- return weights
598
-
599
-
600
- def main(args):
601
- if args.report_to == "wandb" and args.hub_token is not None:
602
- raise ValueError(
603
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
604
- " Please use `hf auth login` to authenticate with the Hub."
605
- )
606
-
607
- logging_dir = Path(args.output_dir, args.logging_dir)
608
-
609
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
610
-
611
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
612
- # due to pytorch#99272, MPS does not yet support bfloat16.
613
- raise ValueError(
614
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
615
- )
616
-
617
- accelerator = Accelerator(
618
- gradient_accumulation_steps=args.gradient_accumulation_steps,
619
- mixed_precision=args.mixed_precision,
620
- log_with=args.report_to,
621
- project_config=accelerator_project_config,
622
- )
623
-
624
- # Disable AMP for MPS.
625
- if torch.backends.mps.is_available():
626
- accelerator.native_amp = False
627
-
628
- if args.report_to == "wandb":
629
- if not is_wandb_available():
630
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
631
- import wandb
632
-
633
- # Make one log on every process with the configuration for debugging.
634
- logging.basicConfig(
635
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
636
- datefmt="%m/%d/%Y %H:%M:%S",
637
- level=logging.INFO,
638
- )
639
- logger.info(accelerator.state, main_process_only=False)
640
- if accelerator.is_local_main_process:
641
- datasets.utils.logging.set_verbosity_warning()
642
- transformers.utils.logging.set_verbosity_warning()
643
- diffusers.utils.logging.set_verbosity_info()
644
- else:
645
- datasets.utils.logging.set_verbosity_error()
646
- transformers.utils.logging.set_verbosity_error()
647
- diffusers.utils.logging.set_verbosity_error()
648
-
649
- # If passed along, set the training seed now.
650
- if args.seed is not None:
651
- set_seed(args.seed)
652
-
653
- # Handle the repository creation
654
- if accelerator.is_main_process:
655
- if args.output_dir is not None:
656
- os.makedirs(args.output_dir, exist_ok=True)
657
-
658
- if args.push_to_hub:
659
- repo_id = create_repo(
660
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
661
- ).repo_id
662
-
663
- # Load the tokenizers
664
- tokenizer_one = AutoTokenizer.from_pretrained(
665
- args.pretrained_model_name_or_path,
666
- subfolder="tokenizer",
667
- revision=args.revision,
668
- use_fast=False,
669
- )
670
- tokenizer_two = AutoTokenizer.from_pretrained(
671
- args.pretrained_model_name_or_path,
672
- subfolder="tokenizer_2",
673
- revision=args.revision,
674
- use_fast=False,
675
- )
676
-
677
- # import correct text encoder classes
678
- text_encoder_cls_one = import_model_class_from_model_name_or_path(
679
- args.pretrained_model_name_or_path, args.revision
680
- )
681
- text_encoder_cls_two = import_model_class_from_model_name_or_path(
682
- args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
683
- )
684
-
685
- # Load scheduler and models
686
- noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
687
- # Check for terminal SNR in combination with SNR Gamma
688
- text_encoder_one = text_encoder_cls_one.from_pretrained(
689
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
690
- )
691
- text_encoder_two = text_encoder_cls_two.from_pretrained(
692
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
693
- )
694
- vae_path = (
695
- args.pretrained_model_name_or_path
696
- if args.pretrained_vae_model_name_or_path is None
697
- else args.pretrained_vae_model_name_or_path
698
- )
699
- vae = AutoencoderKL.from_pretrained(
700
- vae_path,
701
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
702
- revision=args.revision,
703
- variant=args.variant,
704
  )
705
- unet = UNet2DConditionModel.from_pretrained(
706
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
707
- )
708
-
709
- # Freeze vae and text encoders.
710
- vae.requires_grad_(False)
711
- text_encoder_one.requires_grad_(False)
712
- text_encoder_two.requires_grad_(False)
713
- # Set unet as trainable.
714
- unet.train()
715
-
716
- # For mixed precision training we cast all non-trainable weights to half-precision
717
- # as these weights are only used for inference, keeping weights in full precision is not required.
718
- weight_dtype = torch.float32
719
- if accelerator.mixed_precision == "fp16":
720
- weight_dtype = torch.float16
721
- elif accelerator.mixed_precision == "bf16":
722
- weight_dtype = torch.bfloat16
723
-
724
- # Move unet, vae and text_encoder to device and cast to weight_dtype
725
- # The VAE is in float32 to avoid NaN losses.
726
- vae.to(accelerator.device, dtype=torch.float32)
727
- text_encoder_one.to(accelerator.device, dtype=weight_dtype)
728
- text_encoder_two.to(accelerator.device, dtype=weight_dtype)
729
-
730
- # Create EMA for the unet.
731
- if args.use_ema:
732
- ema_unet = UNet2DConditionModel.from_pretrained(
733
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
734
- )
735
- ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
736
- if args.enable_npu_flash_attention:
737
- if is_torch_npu_available():
738
- logger.info("npu flash attention enabled.")
739
- unet.enable_npu_flash_attention()
740
- else:
741
- raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
742
- if args.enable_xformers_memory_efficient_attention:
743
- if is_xformers_available():
744
- import xformers
745
-
746
- xformers_version = version.parse(xformers.__version__)
747
- if xformers_version == version.parse("0.0.16"):
748
- logger.warning(
749
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
750
- )
751
- unet.enable_xformers_memory_efficient_attention()
752
- else:
753
- raise ValueError("xformers is not available. Make sure it is installed correctly")
754
-
755
- # `accelerate` 0.16.0 will have better support for customized saving
756
- if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
757
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
758
- def save_model_hook(models, weights, output_dir):
759
- if accelerator.is_main_process:
760
- if args.use_ema:
761
- ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
762
-
763
- for i, model in enumerate(models):
764
- model.save_pretrained(os.path.join(output_dir, "unet"))
765
-
766
- # make sure to pop weight so that corresponding model is not saved again
767
- if weights:
768
- weights.pop()
769
-
770
- def load_model_hook(models, input_dir):
771
- if args.use_ema:
772
- load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
773
- ema_unet.load_state_dict(load_model.state_dict())
774
- ema_unet.to(accelerator.device)
775
- del load_model
776
-
777
- for _ in range(len(models)):
778
- # pop models so that they are not loaded again
779
- model = models.pop()
780
-
781
- # load diffusers style into model
782
- load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
783
- model.register_to_config(**load_model.config)
784
 
785
- model.load_state_dict(load_model.state_dict())
786
- del load_model
787
 
788
- accelerator.register_save_state_pre_hook(save_model_hook)
789
- accelerator.register_load_state_pre_hook(load_model_hook)
790
-
791
- if args.gradient_checkpointing:
792
- unet.enable_gradient_checkpointing()
793
-
794
- # Enable TF32 for faster training on Ampere GPUs,
795
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
796
- if args.allow_tf32:
797
- torch.backends.cuda.matmul.allow_tf32 = True
798
-
799
- if args.scale_lr:
800
- args.learning_rate = (
801
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
802
- )
803
-
804
- # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
805
- if args.use_8bit_adam:
806
- try:
807
- import bitsandbytes as bnb
808
- except ImportError:
809
- raise ImportError(
810
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
811
- )
812
-
813
- optimizer_class = bnb.optim.AdamW8bit
814
- else:
815
- optimizer_class = torch.optim.AdamW
816
-
817
- # Optimizer creation
818
- params_to_optimize = unet.parameters()
819
- optimizer = optimizer_class(
820
- params_to_optimize,
821
- lr=args.learning_rate,
822
- betas=(args.adam_beta1, args.adam_beta2),
823
- weight_decay=args.adam_weight_decay,
824
- eps=args.adam_epsilon,
825
  )
826
 
827
- # Get the datasets: you can either provide your own training and evaluation files (see below)
828
- # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
829
-
830
- # In distributed training, the load_dataset function guarantees that only one local process can concurrently
831
- # download the dataset.
832
- if args.dataset_name is not None:
833
- # Downloading and loading a dataset from the hub.
834
- dataset = load_dataset(
835
- args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
836
- )
837
- else:
838
- data_files = {}
839
- if args.train_data_dir is not None:
840
- data_files["train"] = os.path.join(args.train_data_dir, "**")
841
- dataset = load_dataset(
842
- "imagefolder",
843
- data_files=data_files,
844
- cache_dir=args.cache_dir,
845
- )
846
- # See more about loading custom images at
847
- # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
848
-
849
- # Preprocessing the datasets.
850
- # We need to tokenize inputs and targets.
851
- column_names = dataset["train"].column_names
852
-
853
- # 6. Get the column names for input/target.
854
- dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
855
- if args.image_column is None:
856
- image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
857
- else:
858
- image_column = args.image_column
859
- if image_column not in column_names:
860
- raise ValueError(
861
- f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
862
- )
863
- if args.caption_column is None:
864
- caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
865
- else:
866
- caption_column = args.caption_column
867
- if caption_column not in column_names:
868
- raise ValueError(
869
- f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
870
- )
871
-
872
- # Preprocessing the datasets.
873
- interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
874
- if interpolation is None:
875
- raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
876
- train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
877
- train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
878
- train_flip = transforms.RandomHorizontalFlip(p=1.0)
879
- train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
880
-
881
- def preprocess_train(examples):
882
- images = [image.convert("RGB") for image in examples[image_column]]
883
- # image aug
884
- original_sizes = []
885
- all_images = []
886
- crop_top_lefts = []
887
- for image in images:
888
- original_sizes.append((image.height, image.width))
889
- image = train_resize(image)
890
- if args.random_flip and random.random() < 0.5:
891
- # flip
892
- image = train_flip(image)
893
- if args.center_crop:
894
- y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
895
- x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
896
- image = train_crop(image)
897
- else:
898
- y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
899
- image = crop(image, y1, x1, h, w)
900
- crop_top_left = (y1, x1)
901
- crop_top_lefts.append(crop_top_left)
902
- image = train_transforms(image)
903
- all_images.append(image)
904
-
905
- examples["original_sizes"] = original_sizes
906
- examples["crop_top_lefts"] = crop_top_lefts
907
- examples["pixel_values"] = all_images
908
- return examples
909
 
910
- with accelerator.main_process_first():
911
- if args.max_train_samples is not None:
912
- dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
913
- # Set the training transforms
914
- train_dataset = dataset["train"].with_transform(preprocess_train)
915
 
916
- # Let's first compute all the embeddings so that we can free up the text encoders
917
- # from memory. We will pre-compute the VAE encodings too.
918
- text_encoders = [text_encoder_one, text_encoder_two]
919
- tokenizers = [tokenizer_one, tokenizer_two]
920
- compute_embeddings_fn = functools.partial(
921
- encode_prompt,
922
- text_encoders=text_encoders,
923
- tokenizers=tokenizers,
924
- proportion_empty_prompts=args.proportion_empty_prompts,
925
- caption_column=args.caption_column,
926
- )
927
- compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
928
- with accelerator.main_process_first():
929
- from datasets.fingerprint import Hasher
930
-
931
- # fingerprint used by the cache for the other processes to load the result
932
- # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
933
- new_fingerprint = Hasher.hash(args)
934
- new_fingerprint_for_vae = Hasher.hash((vae_path, args))
935
- train_dataset_with_embeddings = train_dataset.map(
936
- compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
937
- )
938
- train_dataset_with_vae = train_dataset.map(
939
- compute_vae_encodings_fn,
940
- batched=True,
941
- batch_size=args.train_batch_size,
942
- new_fingerprint=new_fingerprint_for_vae,
943
- )
944
- precomputed_dataset = concatenate_datasets(
945
- [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
946
- )
947
- precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
948
-
949
- del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
950
- del text_encoders, tokenizers, vae
951
- gc.collect()
952
- if is_torch_npu_available():
953
- torch_npu.npu.empty_cache()
954
- elif torch.cuda.is_available():
955
- torch.cuda.empty_cache()
956
-
957
- def collate_fn(examples):
958
- model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
959
- original_sizes = [example["original_sizes"] for example in examples]
960
- crop_top_lefts = [example["crop_top_lefts"] for example in examples]
961
- prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
962
- pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
963
-
964
- return {
965
- "model_input": model_input,
966
- "prompt_embeds": prompt_embeds,
967
- "pooled_prompt_embeds": pooled_prompt_embeds,
968
- "original_sizes": original_sizes,
969
- "crop_top_lefts": crop_top_lefts,
970
  }
 
 
971
 
972
- # DataLoaders creation:
973
- train_dataloader = torch.utils.data.DataLoader(
974
- precomputed_dataset,
975
- shuffle=True,
976
- collate_fn=collate_fn,
977
- batch_size=args.train_batch_size,
978
- num_workers=args.dataloader_num_workers,
979
- )
980
-
981
- # Scheduler and math around the number of training steps.
982
- overrode_max_train_steps = False
983
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
984
- if args.max_train_steps is None:
985
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
986
- overrode_max_train_steps = True
987
-
988
- lr_scheduler = get_scheduler(
989
- args.lr_scheduler,
990
- optimizer=optimizer,
991
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
992
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
993
- )
994
 
995
- # Prepare everything with our `accelerator`.
996
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
997
- unet, optimizer, train_dataloader, lr_scheduler
998
- )
999
 
1000
- if args.use_ema:
1001
- ema_unet.to(accelerator.device)
1002
 
1003
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1004
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1005
- if overrode_max_train_steps:
1006
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1007
- # Afterwards we recalculate our number of training epochs
1008
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1009
 
1010
- # We need to initialize the trackers we use, and also store our configuration.
1011
- # The trackers initializes automatically on the main process.
1012
- if accelerator.is_main_process:
1013
- accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
 
1014
 
1015
- # Function for unwrapping if torch.compile() was used in accelerate.
1016
- def unwrap_model(model):
1017
- model = accelerator.unwrap_model(model)
1018
- model = model._orig_mod if is_compiled_module(model) else model
1019
- return model
1020
 
1021
- if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
1022
- autocast_ctx = nullcontext()
1023
- else:
1024
- autocast_ctx = torch.autocast(accelerator.device.type)
1025
 
1026
- # Train!
1027
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1028
 
1029
- logger.info("***** Running training *****")
1030
- logger.info(f" Num examples = {len(precomputed_dataset)}")
1031
- logger.info(f" Num Epochs = {args.num_train_epochs}")
1032
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1033
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1034
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1035
- logger.info(f" Total optimization steps = {args.max_train_steps}")
1036
- global_step = 0
1037
- first_epoch = 0
1038
 
1039
- # Potentially load in the weights and states from a previous save
1040
- if args.resume_from_checkpoint:
1041
- if args.resume_from_checkpoint != "latest":
1042
- path = os.path.basename(args.resume_from_checkpoint)
1043
- else:
1044
- # Get the most recent checkpoint
1045
- dirs = os.listdir(args.output_dir)
1046
- dirs = [d for d in dirs if d.startswith("checkpoint")]
1047
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1048
- path = dirs[-1] if len(dirs) > 0 else None
1049
 
1050
- if path is None:
1051
- accelerator.print(
1052
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1053
- )
1054
- args.resume_from_checkpoint = None
1055
- initial_global_step = 0
1056
- else:
1057
- accelerator.print(f"Resuming from checkpoint {path}")
1058
- accelerator.load_state(os.path.join(args.output_dir, path))
1059
- global_step = int(path.split("-")[1])
1060
 
1061
- initial_global_step = global_step
1062
- first_epoch = global_step // num_update_steps_per_epoch
1063
 
1064
- else:
1065
- initial_global_step = 0
 
 
1066
 
1067
- progress_bar = tqdm(
1068
- range(0, args.max_train_steps),
1069
- initial=initial_global_step,
1070
- desc="Steps",
1071
- # Only show the progress bar once on each machine.
1072
- disable=not accelerator.is_local_main_process,
1073
  )
1074
 
1075
- for epoch in range(first_epoch, args.num_train_epochs):
1076
- train_loss = 0.0
1077
- for step, batch in enumerate(train_dataloader):
1078
- with accelerator.accumulate(unet):
1079
- # Sample noise that we'll add to the latents
1080
- model_input = batch["model_input"].to(accelerator.device)
1081
- noise = torch.randn_like(model_input)
1082
- if args.noise_offset:
1083
- # https://www.crosslabs.org//blog/diffusion-with-offset-noise
1084
- noise += args.noise_offset * torch.randn(
1085
- (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
1086
- )
1087
-
1088
- bsz = model_input.shape[0]
1089
- if args.timestep_bias_strategy == "none":
1090
- # Sample a random timestep for each image without bias.
1091
- timesteps = torch.randint(
1092
- 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1093
- )
1094
- else:
1095
- # Sample a random timestep for each image, potentially biased by the timestep weights.
1096
- # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
1097
- weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
1098
- model_input.device
1099
- )
1100
- timesteps = torch.multinomial(weights, bsz, replacement=True).long()
1101
-
1102
- # Add noise to the model input according to the noise magnitude at each timestep
1103
- # (this is the forward diffusion process)
1104
- noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
1105
-
1106
- # time ids
1107
- def compute_time_ids(original_size, crops_coords_top_left):
1108
- # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1109
- target_size = (args.resolution, args.resolution)
1110
- add_time_ids = list(original_size + crops_coords_top_left + target_size)
1111
- add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
1112
- return add_time_ids
1113
-
1114
- add_time_ids = torch.cat(
1115
- [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1116
- )
1117
-
1118
- # Predict the noise residual
1119
- unet_added_conditions = {"time_ids": add_time_ids}
1120
- prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
1121
- pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
1122
- unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
1123
- model_pred = unet(
1124
- noisy_model_input,
1125
- timesteps,
1126
- prompt_embeds,
1127
- added_cond_kwargs=unet_added_conditions,
1128
- return_dict=False,
1129
- )[0]
1130
-
1131
- # Get the target for loss depending on the prediction type
1132
- if args.prediction_type is not None:
1133
- # set prediction_type of scheduler if defined
1134
- noise_scheduler.register_to_config(prediction_type=args.prediction_type)
1135
-
1136
- if noise_scheduler.config.prediction_type == "epsilon":
1137
- target = noise
1138
- elif noise_scheduler.config.prediction_type == "v_prediction":
1139
- target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1140
- elif noise_scheduler.config.prediction_type == "sample":
1141
- # We set the target to latents here, but the model_pred will return the noise sample prediction.
1142
- target = model_input
1143
- # We will have to subtract the noise residual from the prediction to get the target sample.
1144
- model_pred = model_pred - noise
1145
- else:
1146
- raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1147
-
1148
- if args.snr_gamma is None:
1149
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1150
- else:
1151
- # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
1152
- # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1153
- # This is discussed in Section 4.2 of the same paper.
1154
- snr = compute_snr(noise_scheduler, timesteps)
1155
- mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
1156
- dim=1
1157
- )[0]
1158
- if noise_scheduler.config.prediction_type == "epsilon":
1159
- mse_loss_weights = mse_loss_weights / snr
1160
- elif noise_scheduler.config.prediction_type == "v_prediction":
1161
- mse_loss_weights = mse_loss_weights / (snr + 1)
1162
-
1163
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1164
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1165
- loss = loss.mean()
1166
-
1167
- # Gather the losses across all processes for logging (if we use distributed training).
1168
- avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1169
- train_loss += avg_loss.item() / args.gradient_accumulation_steps
1170
-
1171
- # Backpropagate
1172
- accelerator.backward(loss)
1173
- if accelerator.sync_gradients:
1174
- params_to_clip = unet.parameters()
1175
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1176
- optimizer.step()
1177
- lr_scheduler.step()
1178
- optimizer.zero_grad()
1179
 
1180
- # Checks if the accelerator has performed an optimization step behind the scenes
1181
- if accelerator.sync_gradients:
1182
- if args.use_ema:
1183
- ema_unet.step(unet.parameters())
1184
- progress_bar.update(1)
1185
- global_step += 1
1186
- accelerator.log({"train_loss": train_loss}, step=global_step)
1187
- train_loss = 0.0
1188
-
1189
- # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
1190
- if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
1191
- if global_step % args.checkpointing_steps == 0:
1192
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1193
- if args.checkpoints_total_limit is not None:
1194
- checkpoints = os.listdir(args.output_dir)
1195
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1196
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1197
-
1198
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1199
- if len(checkpoints) >= args.checkpoints_total_limit:
1200
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1201
- removing_checkpoints = checkpoints[0:num_to_remove]
1202
-
1203
- logger.info(
1204
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1205
- )
1206
- logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1207
-
1208
- for removing_checkpoint in removing_checkpoints:
1209
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1210
- shutil.rmtree(removing_checkpoint)
1211
-
1212
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1213
- accelerator.save_state(save_path)
1214
- logger.info(f"Saved state to {save_path}")
1215
-
1216
- logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1217
- progress_bar.set_postfix(**logs)
1218
-
1219
- if global_step >= args.max_train_steps:
1220
- break
1221
-
1222
- if accelerator.is_main_process:
1223
- if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1224
- logger.info(
1225
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1226
- f" {args.validation_prompt}."
1227
- )
1228
- if args.use_ema:
1229
- # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1230
- ema_unet.store(unet.parameters())
1231
- ema_unet.copy_to(unet.parameters())
1232
-
1233
- # create pipeline
1234
- vae = AutoencoderKL.from_pretrained(
1235
- vae_path,
1236
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1237
- revision=args.revision,
1238
- variant=args.variant,
1239
- )
1240
- pipeline = StableDiffusionXLPipeline.from_pretrained(
1241
- args.pretrained_model_name_or_path,
1242
- vae=vae,
1243
- unet=accelerator.unwrap_model(unet),
1244
- revision=args.revision,
1245
- variant=args.variant,
1246
- torch_dtype=weight_dtype,
1247
- )
1248
- if args.prediction_type is not None:
1249
- scheduler_args = {"prediction_type": args.prediction_type}
1250
- pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1251
-
1252
- pipeline = pipeline.to(accelerator.device)
1253
- pipeline.set_progress_bar_config(disable=True)
1254
-
1255
- # run inference
1256
- generator = (
1257
- torch.Generator(device=accelerator.device).manual_seed(args.seed)
1258
- if args.seed is not None
1259
- else None
1260
- )
1261
- pipeline_args = {"prompt": args.validation_prompt}
1262
-
1263
- with autocast_ctx:
1264
- images = [
1265
- pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
1266
- for _ in range(args.num_validation_images)
1267
- ]
1268
-
1269
- for tracker in accelerator.trackers:
1270
- if tracker.name == "tensorboard":
1271
- np_images = np.stack([np.asarray(img) for img in images])
1272
- tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1273
- if tracker.name == "wandb":
1274
- tracker.log(
1275
- {
1276
- "validation": [
1277
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1278
- for i, image in enumerate(images)
1279
- ]
1280
- }
1281
- )
1282
-
1283
- del pipeline
1284
- if is_torch_npu_available():
1285
- torch_npu.npu.empty_cache()
1286
- elif torch.cuda.is_available():
1287
- torch.cuda.empty_cache()
1288
-
1289
- if args.use_ema:
1290
- # Switch back to the original UNet parameters.
1291
- ema_unet.restore(unet.parameters())
1292
-
1293
- accelerator.wait_for_everyone()
1294
- if accelerator.is_main_process:
1295
- unet = unwrap_model(unet)
1296
- if args.use_ema:
1297
- ema_unet.copy_to(unet.parameters())
1298
-
1299
- # Serialize pipeline.
1300
- vae = AutoencoderKL.from_pretrained(
1301
- vae_path,
1302
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1303
- revision=args.revision,
1304
- variant=args.variant,
1305
- torch_dtype=weight_dtype,
1306
- )
1307
- pipeline = StableDiffusionXLPipeline.from_pretrained(
1308
- args.pretrained_model_name_or_path,
1309
- unet=unet,
1310
- vae=vae,
1311
- revision=args.revision,
1312
- variant=args.variant,
1313
- torch_dtype=weight_dtype,
1314
- )
1315
- if args.prediction_type is not None:
1316
- scheduler_args = {"prediction_type": args.prediction_type}
1317
- pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1318
- pipeline.save_pretrained(args.output_dir)
1319
-
1320
- # run inference
1321
- images = []
1322
- if args.validation_prompt and args.num_validation_images > 0:
1323
- pipeline = pipeline.to(accelerator.device)
1324
- generator = (
1325
- torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
1326
- )
1327
-
1328
- with autocast_ctx:
1329
- images = [
1330
- pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1331
- for _ in range(args.num_validation_images)
1332
- ]
1333
-
1334
- for tracker in accelerator.trackers:
1335
- if tracker.name == "tensorboard":
1336
- np_images = np.stack([np.asarray(img) for img in images])
1337
- tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1338
- if tracker.name == "wandb":
1339
- tracker.log(
1340
- {
1341
- "test": [
1342
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1343
- for i, image in enumerate(images)
1344
- ]
1345
- }
1346
- )
1347
-
1348
- if args.push_to_hub:
1349
- save_model_card(
1350
- repo_id=repo_id,
1351
- images=images,
1352
- validation_prompt=args.validation_prompt,
1353
- base_model=args.pretrained_model_name_or_path,
1354
- dataset_name=args.dataset_name,
1355
- repo_folder=args.output_dir,
1356
- vae_path=args.pretrained_vae_model_name_or_path,
1357
- )
1358
- upload_folder(
1359
- repo_id=repo_id,
1360
- folder_path=args.output_dir,
1361
- commit_message="End of training",
1362
- ignore_patterns=["step_*", "epoch_*"],
1363
- )
1364
-
1365
- accelerator.end_training()
1366
 
 
1367
 
1368
- if __name__ == "__main__":
1369
- args = parse_args()
1370
- main(args)
 
+ # %%
  import torch
+ from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig


+ def load_model(model_name="datalab-to/chandra", device_id=0):
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_use_double_quant=True,
      )

+     processor = AutoProcessor.from_pretrained(model_name)

+     model = AutoModelForVision2Seq.from_pretrained(
+         model_name,
+         quantization_config=bnb_config,
+         dtype=torch.bfloat16,
+         device_map={"": device_id},
      )

+     return processor, model

+ # %%
+ def caption_batch(batch, processor, model):
+     images = batch["image"]

+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": image},
+                 {
+                     "type": "text",
+                     "text": "Describe the image, and skip mentioning that it's illustrated or from anime.",
+                 },
+             ],
          }
+         for image in images
+     ]

+     inputs = processor.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_dict=True,
+         return_tensors="pt",
+     ).to(model.device)

+     with torch.no_grad():
+         generated = model.generate(**inputs)

+     decoded = processor.batch_decode(generated)
+     captions = [d.split("<|im_start|>assistant\n")[-1] for d in decoded]

+     return {"image": images, "text": captions}

+ # %%
+ import datasets
+ from datasets import Dataset
+ from typing import cast
+ from concurrent.futures import ThreadPoolExecutor


+ input_dataset = "none-yet/anime-captions"
+ output_dataset = "nroggendorff/anime-captions"

+ loaded = datasets.load_dataset(input_dataset, split="train")

+ if isinstance(loaded, datasets.DatasetDict):
+     ds = cast(Dataset, loaded["train"])
+ else:
+     ds = cast(Dataset, loaded)

+ num_gpus = torch.cuda.device_count()
+ models = [load_model(device_id=i) for i in range(num_gpus)]

+ batch_size = 8
+ shard_size = len(ds) // num_gpus


+ def process_shard(shard_idx, processor, model):
+     start = shard_idx * shard_size
+     end = start + shard_size if shard_idx < num_gpus - 1 else len(ds)
+     shard = ds.select(range(start, end))

+     return shard.map(
+         lambda batch: caption_batch(batch, processor, model),
+         batched=True,
+         batch_size=batch_size,
+         remove_columns=shard.column_names,
      )


+ with ThreadPoolExecutor(max_workers=num_gpus) as executor:
+     futures = [
+         executor.submit(process_shard, i, proc, model)
+         for i, (proc, model) in enumerate(models)
+     ]
+     shards = [f.result() for f in futures]

+ ds = datasets.concatenate_datasets(shards)

+ # %%
+ ds.push_to_hub(output_dataset, create_pr=True)
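
For anyone adapting this commit, a minimal sanity-check cell such as the one below can be run after the load_model and caption_batch cells from the new train.py have been executed, to caption a single image before launching the full multi-GPU shard map. This is only a sketch, not part of the committed file; the image path is a placeholder, and it assumes at least one CUDA device plus the bitsandbytes/transformers setup used above.

# %%
# Hypothetical smoke-test cell (not in the commit): caption one image with a
# single quantized model instance before starting the sharded run.
from PIL import Image

processor, model = load_model(device_id=0)  # reuses the quantized setup above

sample = Image.open("sample.png").convert("RGB")  # placeholder image path
result = caption_batch({"image": [sample]}, processor, model)
print(result["text"][0])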