- #!/usr/bin/env python
- # coding=utf-8
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import copy
- import functools
- import logging
- import math
- import os
- import random
- import shutil
- from contextlib import nullcontext
- from pathlib import Path
- import accelerate
- import numpy as np
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- import transformers
- from accelerate import Accelerator
- from accelerate.logging import get_logger
- from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
- from datasets import load_dataset
- from huggingface_hub import create_repo, upload_folder
- from packaging import version
- from PIL import Image
- from torchvision import transforms
- from tqdm.auto import tqdm
- from transformers import (
- AutoTokenizer,
- CLIPTextModel,
- T5EncoderModel,
- )
- import diffusers
- from diffusers import (
- AutoencoderKL,
- FlowMatchEulerDiscreteScheduler,
- FluxTransformer2DModel,
- )
- from diffusers.models.controlnet_flux import FluxControlNetModel
- from diffusers.optimization import get_scheduler
- from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
- from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
- from diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid
- from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
- from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
- from diffusers.utils.torch_utils import is_compiled_module
- if is_wandb_available():
- import wandb
- # Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
- check_min_version("0.32.0.dev0")
- logger = get_logger(__name__)
- if is_torch_npu_available():
- torch.npu.config.allow_internal_format = False
- def log_validation(
- vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
- ):
- logger.info("Running validation... ")
- if not is_final_validation:
- flux_controlnet = accelerator.unwrap_model(flux_controlnet)
- pipeline = FluxControlNetPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- controlnet=flux_controlnet,
- transformer=flux_transformer,
- torch_dtype=torch.bfloat16,
- )
- else:
- flux_controlnet = FluxControlNetModel.from_pretrained(
- args.output_dir,
- torch_dtype=torch.bfloat16,
- variant=args.save_weight_dtype if args.save_weight_dtype != "fp32" else None,
- )
- pipeline = FluxControlNetPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- controlnet=flux_controlnet,
- transformer=flux_transformer,
- torch_dtype=torch.bfloat16,
- )
- pipeline.to(accelerator.device)
- pipeline.set_progress_bar_config(disable=True)
- if args.enable_xformers_memory_efficient_attention:
- pipeline.enable_xformers_memory_efficient_attention()
- if args.seed is None:
- generator = None
- else:
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
- if len(args.validation_image) == len(args.validation_prompt):
- validation_images = args.validation_image
- validation_prompts = args.validation_prompt
- elif len(args.validation_image) == 1:
- validation_images = args.validation_image * len(args.validation_prompt)
- validation_prompts = args.validation_prompt
- elif len(args.validation_prompt) == 1:
- validation_images = args.validation_image
- validation_prompts = args.validation_prompt * len(args.validation_image)
- else:
- raise ValueError(
- "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
- )
- image_logs = []
- if is_final_validation or torch.backends.mps.is_available():
- autocast_ctx = nullcontext()
- else:
- autocast_ctx = torch.autocast(accelerator.device.type)
- for validation_prompt, validation_image in zip(validation_prompts, validation_images):
- validation_image = load_image(validation_image)
- # inference at 1024px may be needed to get a good image
- validation_image = validation_image.resize((args.resolution, args.resolution))
- images = []
- # pre-compute prompt embeds, pooled prompt embeds, and text ids, because T5 does not support autocast
- prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
- validation_prompt, prompt_2=validation_prompt
- )
- for _ in range(args.num_validation_images):
- with autocast_ctx:
- # need to fix in pipeline_flux_controlnet
- image = pipeline(
- prompt_embeds=prompt_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- control_image=validation_image,
- num_inference_steps=28,
- controlnet_conditioning_scale=0.7,
- guidance_scale=3.5,
- generator=generator,
- ).images[0]
- image = image.resize((args.resolution, args.resolution))
- images.append(image)
- image_logs.append(
- {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
- )
- tracker_key = "test" if is_final_validation else "validation"
- for tracker in accelerator.trackers:
- if tracker.name == "tensorboard":
- for log in image_logs:
- images = log["images"]
- validation_prompt = log["validation_prompt"]
- validation_image = log["validation_image"]
- formatted_images = []
- formatted_images.append(np.asarray(validation_image))
- for image in images:
- formatted_images.append(np.asarray(image))
- formatted_images = np.stack(formatted_images)
- tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
- elif tracker.name == "wandb":
- formatted_images = []
- for log in image_logs:
- images = log["images"]
- validation_prompt = log["validation_prompt"]
- validation_image = log["validation_image"]
- formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
- for image in images:
- image = wandb.Image(image, caption=validation_prompt)
- formatted_images.append(image)
- tracker.log({tracker_key: formatted_images})
- else:
- logger.warning(f"image logging not implemented for {tracker.name}")
- del pipeline
- free_memory()
- return image_logs
- def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
- img_str = ""
- if image_logs is not None:
- img_str = "You can find some example images below.\n\n"
- for i, log in enumerate(image_logs):
- images = log["images"]
- validation_prompt = log["validation_prompt"]
- validation_image = log["validation_image"]
- validation_image.save(os.path.join(repo_folder, "image_control.png"))
- img_str += f"prompt: {validation_prompt}\n"
- images = [validation_image] + images
- make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
- img_str += "\n"
- model_description = f"""
- # controlnet-{repo_id}
- These are controlnet weights trained on {base_model} with a new type of conditioning.
- {img_str}
- ## License
- Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)
- """
- model_card = load_or_create_model_card(
- repo_id_or_path=repo_id,
- from_training=True,
- license="other",
- base_model=base_model,
- model_description=model_description,
- inference=True,
- )
- tags = [
- "flux",
- "flux-diffusers",
- "text-to-image",
- "diffusers",
- "controlnet",
- "diffusers-training",
- ]
- model_card = populate_model_card(model_card, tags=tags)
- model_card.save(os.path.join(repo_folder, "README.md"))
- def parse_args(input_args=None):
- parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- default=None,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--pretrained_vae_model_name_or_path",
- type=str,
- default=None,
- help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
- )
- parser.add_argument(
- "--controlnet_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
- " If not specified controlnet weights are initialized from unet.",
- )
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--tokenizer_name",
- type=str,
- default=None,
- help="Pretrained tokenizer name or path if not the same as model_name",
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="controlnet-model",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--cache_dir",
- type=str,
- default=None,
- help="The directory where the downloaded models and datasets will be stored.",
- )
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--resolution",
- type=int,
- default=512,
- help=(
- "The resolution for input images, all the images in the train/validation dataset will be resized to this"
- " resolution"
- ),
- )
- parser.add_argument(
- "--crops_coords_top_left_h",
- type=int,
- default=0,
- help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
- )
- parser.add_argument(
- "--crops_coords_top_left_w",
- type=int,
- default=0,
- help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
- )
- parser.add_argument(
- "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
- )
- parser.add_argument("--num_train_epochs", type=int, default=1)
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
- )
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
- "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
- "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
- "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
- "instructions."
- ),
- )
- parser.add_argument(
- "--checkpoints_total_limit",
- type=int,
- default=None,
- help=("Max number of checkpoints to store."),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=5e-6,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- default=False,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
- )
- parser.add_argument(
- "--lr_num_cycles",
- type=int,
- default=1,
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
- )
- parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
- parser.add_argument(
- "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
- )
- parser.add_argument(
- "--use_adafactor",
- action="store_true",
- help=(
- "Adafactor is a stochastic optimization method based on Adam that reduces memory usage while retaining"
- "the empirical benefits of adaptivity. This is achieved through maintaining a factored representation "
- "of the squared gradient accumulator across training steps."
- ),
- )
- parser.add_argument(
- "--dataloader_num_workers",
- type=int,
- default=0,
- help=(
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
- ),
- )
- parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
- parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
- parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
- parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
- parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
- parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help=(
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
- ),
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default="tensorboard",
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
- parser.add_argument(
- "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
- )
- parser.add_argument(
- "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
- )
- parser.add_argument(
- "--set_grads_to_none",
- action="store_true",
- help=(
- "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
- " behaviors, so disable this argument if it causes any problems. More info:"
- " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
- ),
- )
- parser.add_argument(
- "--dataset_name",
- type=str,
- default=None,
- help=(
- "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
- " or to a folder containing files that 🤗 Datasets can understand."
- ),
- )
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
- parser.add_argument(
- "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
- )
- parser.add_argument(
- "--conditioning_image_column",
- type=str,
- default="conditioning_image",
- help="The column of the dataset containing the controlnet conditioning image.",
- )
- parser.add_argument(
- "--caption_column",
- type=str,
- default="text",
- help="The column of the dataset containing a caption or a list of captions.",
- )
- parser.add_argument(
- "--max_train_samples",
- type=int,
- default=None,
- help=(
- "For debugging purposes or quicker training, truncate the number of training examples to this "
- "value if set."
- ),
- )
- parser.add_argument(
- "--proportion_empty_prompts",
- type=float,
- default=0,
- help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
- )
- parser.add_argument(
- "--validation_prompt",
- type=str,
- default=None,
- nargs="+",
- help=(
- "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
- " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
- " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
- ),
- )
- parser.add_argument(
- "--validation_image",
- type=str,
- default=None,
- nargs="+",
- help=(
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
- " `--validation_image` that will be used with all `--validation_prompt`s."
- ),
- )
- parser.add_argument(
- "--num_double_layers",
- type=int,
- default=4,
- help="Number of double layers in the controlnet (default: 4).",
- )
- parser.add_argument(
- "--num_single_layers",
- type=int,
- default=4,
- help="Number of single layers in the controlnet (default: 4).",
- )
- parser.add_argument(
- "--num_validation_images",
- type=int,
- default=2,
- help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
- )
- parser.add_argument(
- "--validation_steps",
- type=int,
- default=100,
- help=(
- "Run validation every X steps. Validation consists of running the prompt"
- " `args.validation_prompt` multiple times: `args.num_validation_images`"
- " and logging the images."
- ),
- )
- parser.add_argument(
- "--tracker_project_name",
- type=str,
- default="flux_train_controlnet",
- help=(
- "The `project_name` argument passed to Accelerator.init_trackers for"
- " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
- ),
- )
- parser.add_argument(
- "--jsonl_for_train",
- type=str,
- default=None,
- help="Path to the jsonl file containing the training data.",
- )
- parser.add_argument(
- "--guidance_scale",
- type=float,
- default=3.5,
- help="the guidance scale used for transformer.",
- )
- parser.add_argument(
- "--save_weight_dtype",
- type=str,
- default="fp32",
- choices=[
- "fp16",
- "bf16",
- "fp32",
- ],
- help=("Preserve precision type according to selected weight"),
- )
- parser.add_argument(
- "--weighting_scheme",
- type=str,
- default="logit_normal",
- choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
- help=('The weighting scheme for timestep sampling; defaults to "logit_normal". Use "none" for uniform sampling and a uniform loss.'),
- )
- parser.add_argument(
- "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
- )
- parser.add_argument(
- "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
- )
- parser.add_argument(
- "--mode_scale",
- type=float,
- default=1.29,
- help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
- )
- parser.add_argument(
- "--enable_model_cpu_offload",
- action="store_true",
- help="Enable model cpu offload and save memory.",
- )
- if input_args is not None:
- args = parser.parse_args(input_args)
- else:
- args = parser.parse_args()
- if args.dataset_name is None and args.jsonl_for_train is None:
- raise ValueError("Specify either `--dataset_name` or `--jsonl_for_train`")
- if args.dataset_name is not None and args.jsonl_for_train is not None:
- raise ValueError("Specify only one of `--dataset_name` or `--jsonl_for_train`")
- if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
- raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
- if args.validation_prompt is not None and args.validation_image is None:
- raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
- if args.validation_prompt is None and args.validation_image is not None:
- raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
- if (
- args.validation_image is not None
- and args.validation_prompt is not None
- and len(args.validation_image) != 1
- and len(args.validation_prompt) != 1
- and len(args.validation_image) != len(args.validation_prompt)
- ):
- raise ValueError(
- "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
- " or the same number of `--validation_prompt`s and `--validation_image`s"
- )
- if args.resolution % 8 != 0:
- raise ValueError(
- "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
- )
- return args
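- # Example launch command (illustrative only; the dataset name, paths, and prompt below are
- # placeholders, not defaults shipped with this script):
- #
- #   accelerate launch train_controlnet_flux.py \
- #     --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
- #     --dataset_name="fusing/fill50k" \
- #     --output_dir="controlnet-model" \
- #     --mixed_precision="bf16" \
- #     --resolution=512 \
- #     --train_batch_size=1 \
- #     --gradient_accumulation_steps=4 \
- #     --learning_rate=1e-5 \
- #     --validation_image="./conditioning_image.png" \
- #     --validation_prompt="red circle with blue background" \
- #     --validation_steps=100 \
- #     --max_train_steps=15000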
- def get_train_dataset(args, accelerator):
- dataset = None
- if args.dataset_name is not None:
- # Downloading and loading a dataset from the hub.
- dataset = load_dataset(
- args.dataset_name,
- args.dataset_config_name,
- cache_dir=args.cache_dir,
- )
- if args.jsonl_for_train is not None:
- # load from json
- dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
- dataset = dataset.flatten_indices()
- # Preprocessing the datasets.
- # We need to resolve which dataset columns hold the target image, caption, and conditioning image.
- column_names = dataset["train"].column_names
- if args.image_column is None:
- image_column = column_names[0]
- logger.info(f"image column defaulting to {image_column}")
- else:
- image_column = args.image_column
- if image_column not in column_names:
- raise ValueError(
- f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
- if args.caption_column is None:
- caption_column = column_names[1]
- logger.info(f"caption column defaulting to {caption_column}")
- else:
- caption_column = args.caption_column
- if caption_column not in column_names:
- raise ValueError(
- f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
- if args.conditioning_image_column is None:
- conditioning_image_column = column_names[2]
- logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
- else:
- conditioning_image_column = args.conditioning_image_column
- if conditioning_image_column not in column_names:
- raise ValueError(
- f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
- with accelerator.main_process_first():
- train_dataset = dataset["train"].shuffle(seed=args.seed)
- if args.max_train_samples is not None:
- train_dataset = train_dataset.select(range(args.max_train_samples))
- return train_dataset
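- # When `--jsonl_for_train` is used, each line of the file must provide the three columns
- # consumed above. A sketch of one line (column names assume the script defaults
- # --image_column=image, --conditioning_image_column=conditioning_image, --caption_column=text;
- # the paths themselves are hypothetical):
- #
- #   {"image": "images/0001.png", "conditioning_image": "conditions/0001.png", "text": "a red circle"}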
- def prepare_train_dataset(dataset, accelerator, args):
- image_transforms = transforms.Compose(
- [
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
- transforms.CenterCrop(args.resolution),
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
- conditioning_image_transforms = transforms.Compose(
- [
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
- transforms.CenterCrop(args.resolution),
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
- def preprocess_train(examples):
- images = [
- (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
- for image in examples[args.image_column]
- ]
- images = [image_transforms(image) for image in images]
- conditioning_images = [
- (image.convert("RGB") if not isinstance(image, str) else Image.open(image).convert("RGB"))
- for image in examples[args.conditioning_image_column]
- ]
- conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
- examples["pixel_values"] = images
- examples["conditioning_pixel_values"] = conditioning_images
- return examples
- with accelerator.main_process_first():
- dataset = dataset.with_transform(preprocess_train)
- return dataset
- def collate_fn(examples):
- pixel_values = torch.stack([example["pixel_values"] for example in examples])
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
- conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
- conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
- prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
- pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
- text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples])
- return {
- "pixel_values": pixel_values,
- "conditioning_pixel_values": conditioning_pixel_values,
- "prompt_ids": prompt_ids,
- "unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
- }
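- # For reference, with FLUX.1-dev defaults (CLIP pooled dim 768, T5 hidden dim 4096,
- # max sequence length 512) a batch from this collate_fn is expected to look roughly like:
- #   pixel_values:              (B, 3, resolution, resolution)
- #   conditioning_pixel_values: (B, 3, resolution, resolution)
- #   prompt_ids:                (B, 512, 4096)   # T5 prompt embeddings
- #   pooled_prompt_embeds:      (B, 768)         # pooled CLIP embeddings
- #   time_ids (text_ids):       (B, 512, 3)
- # These dimensions are model-dependent; verify them against the checkpoint you train on.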
- def main(args):
- if args.report_to == "wandb" and args.hub_token is not None:
- raise ValueError(
- "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
- )
- logging_out_dir = Path(args.output_dir, args.logging_dir)
- if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
- # due to pytorch#99272, MPS does not yet support bfloat16.
- raise ValueError(
- "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
- )
- accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))
- accelerator = Accelerator(
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- mixed_precision=args.mixed_precision,
- log_with=args.report_to,
- project_config=accelerator_project_config,
- )
- # Disable AMP for MPS, Apple's backend for accelerating ML on macOS and iOS devices.
- if torch.backends.mps.is_available():
- print("MPS is enabled. Disabling AMP.")
- accelerator.native_amp = False
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- # DEBUG, INFO, WARNING, ERROR, CRITICAL
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
- if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
- # Load the tokenizers
- # load clip tokenizer
- tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- )
- # load t5 tokenizer
- tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- )
- # load clip text encoder
- text_encoder_one = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
- )
- # load t5 text encoder
- text_encoder_two = T5EncoderModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
- )
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="vae",
- revision=args.revision,
- variant=args.variant,
- )
- flux_transformer = FluxTransformer2DModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="transformer",
- revision=args.revision,
- variant=args.variant,
- )
- if args.controlnet_model_name_or_path:
- logger.info("Loading existing controlnet weights")
- flux_controlnet = FluxControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
- else:
- logger.info("Initializing controlnet weights from transformer")
- # we can define the num_layers, num_single_layers,
- flux_controlnet = FluxControlNetModel.from_transformer(
- flux_transformer,
- attention_head_dim=flux_transformer.config["attention_head_dim"],
- num_attention_heads=flux_transformer.config["num_attention_heads"],
- num_layers=args.num_double_layers,
- num_single_layers=args.num_single_layers,
- )
- logger.info("all models loaded successfully")
- noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="scheduler",
- )
- noise_scheduler_copy = copy.deepcopy(noise_scheduler)
- vae.requires_grad_(False)
- flux_transformer.requires_grad_(False)
- text_encoder_one.requires_grad_(False)
- text_encoder_two.requires_grad_(False)
- flux_controlnet.train()
- # build a pipeline so we can reuse its helper functions (e.g. encode_prompt) during preprocessing
- flux_controlnet_pipeline = FluxControlNetPipeline(
- scheduler=noise_scheduler,
- vae=vae,
- text_encoder=text_encoder_one,
- tokenizer=tokenizer_one,
- text_encoder_2=text_encoder_two,
- tokenizer_2=tokenizer_two,
- transformer=flux_transformer,
- controlnet=flux_controlnet,
- )
- if args.enable_model_cpu_offload:
- flux_controlnet_pipeline.enable_model_cpu_offload()
- else:
- flux_controlnet_pipeline.to(accelerator.device)
- def unwrap_model(model):
- model = accelerator.unwrap_model(model)
- model = model._orig_mod if is_compiled_module(model) else model
- return model
- # `accelerate` 0.16.0 introduced better support for customized saving
- if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
- # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
- def save_model_hook(models, weights, output_dir):
- if accelerator.is_main_process:
- i = len(weights) - 1
- while len(weights) > 0:
- weights.pop()
- model = models[i]
- sub_dir = "flux_controlnet"
- model.save_pretrained(os.path.join(output_dir, sub_dir))
- i -= 1
- def load_model_hook(models, input_dir):
- while len(models) > 0:
- # pop models so that they are not loaded again
- model = models.pop()
- # load diffusers style into model
- load_model = FluxControlNetModel.from_pretrained(input_dir, subfolder="flux_controlnet")
- model.register_to_config(**load_model.config)
- model.load_state_dict(load_model.state_dict())
- del load_model
- accelerator.register_save_state_pre_hook(save_model_hook)
- accelerator.register_load_state_pre_hook(load_model_hook)
- if args.enable_npu_flash_attention:
- if is_torch_npu_available():
- logger.info("npu flash attention enabled.")
- flux_transformer.enable_npu_flash_attention()
- else:
- raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- import xformers
- xformers_version = version.parse(xformers.__version__)
- if xformers_version == version.parse("0.0.16"):
- logger.warning(
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
- )
- flux_transformer.enable_xformers_memory_efficient_attention()
- flux_controlnet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError("xformers is not available. Make sure it is installed correctly")
- if args.gradient_checkpointing:
- flux_transformer.enable_gradient_checkpointing()
- flux_controlnet.enable_gradient_checkpointing()
- # Check that all trainable models are in full precision
- low_precision_error_string = (
- " Please make sure to always have all model weights in full float32 precision when starting training - even if"
- " doing mixed precision training, copy of the weights should still be float32."
- )
- if unwrap_model(flux_controlnet).dtype != torch.float32:
- raise ValueError(
- f"Controlnet loaded as datatype {unwrap_model(flux_controlnet).dtype}. {low_precision_error_string}"
- )
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if args.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
- if args.scale_lr:
- args.learning_rate = (
- args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
- )
- # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
- )
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
- # Optimizer creation
- params_to_optimize = flux_controlnet.parameters()
- # use adafactor optimizer to save gpu memory
- if args.use_adafactor:
- from transformers import Adafactor
- optimizer = Adafactor(
- params_to_optimize,
- lr=args.learning_rate,
- scale_parameter=False,
- relative_step=False,
- # warmup_init=True,
- weight_decay=args.adam_weight_decay,
- )
- else:
- optimizer = optimizer_class(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
- # For mixed precision training we cast the text_encoder and vae weights to half-precision,
- # as these models are only used for inference; keeping weights in full precision is not required.
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
- vae.to(accelerator.device, dtype=weight_dtype)
- flux_transformer.to(accelerator.device, dtype=weight_dtype)
- def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline, weight_dtype, is_train=True):
- prompt_batch = batch[args.caption_column]
- captions = []
- for caption in prompt_batch:
- if random.random() < proportion_empty_prompts:
- captions.append("")
- elif isinstance(caption, str):
- captions.append(caption)
- elif isinstance(caption, (list, np.ndarray)):
- # take a random caption if there are multiple
- captions.append(random.choice(caption) if is_train else caption[0])
- prompt_batch = captions
- prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
- prompt_batch, prompt_2=prompt_batch
- )
- prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
- pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)
- text_ids = text_ids.to(dtype=weight_dtype)
- # text_ids [512,3] to [bs,512,3]
- text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)
- return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
- train_dataset = get_train_dataset(args, accelerator)
- text_encoders = [text_encoder_one, text_encoder_two]
- tokenizers = [tokenizer_one, tokenizer_two]
- compute_embeddings_fn = functools.partial(
- compute_embeddings,
- flux_controlnet_pipeline=flux_controlnet_pipeline,
- proportion_empty_prompts=args.proportion_empty_prompts,
- weight_dtype=weight_dtype,
- )
- dataset_args = {
- "dataset_name": args.dataset_name,
- "dataset_config_name": args.dataset_config_name,
- "image_column": args.image_column,
- "conditioning_image_column": args.conditioning_image_column,
- "caption_column": args.caption_column,
- "max_train_samples": args.max_train_samples,
- "proportion_empty_prompts": args.proportion_empty_prompts,
- "jsonl_for_train": args.jsonl_for_train,
- }
- logger.info("Start dataset caching...")
- with accelerator.main_process_first():
- from datasets.fingerprint import Hasher
- # fingerprint used by the cache for the other processes to load the result
- # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
- new_fingerprint = Hasher.hash(dataset_args)
- train_dataset = train_dataset.map(
- compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
- )
- del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
- free_memory()
- logger.info("Dataset preparation is done!")
- # With embeddings cached, get the training dataset ready to be passed to the dataloader.
- train_dataset = prepare_train_dataset(train_dataset, accelerator, args)
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- shuffle=True,
- collate_fn=collate_fn,
- batch_size=args.train_batch_size,
- num_workers=args.dataloader_num_workers,
- )
- # Scheduler and math around the number of training steps.
- # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
- if args.max_train_steps is None:
- len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
- num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
- num_training_steps_for_scheduler = (
- args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
- )
- else:
- num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
- lr_scheduler = get_scheduler(
- args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=num_training_steps_for_scheduler,
- num_cycles=args.lr_num_cycles,
- power=args.lr_power,
- )
- # Prepare everything with our `accelerator`.
- flux_controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- flux_controlnet, optimizer, train_dataloader, lr_scheduler
- )
- flux_controlnet.to(accelerator.device, dtype=weight_dtype)
- vae = accelerator.prepare(vae)
- flux_transformer = accelerator.prepare(flux_transformer)
- vae.to(accelerator.device, dtype=weight_dtype)
- flux_transformer.to(accelerator.device, dtype=weight_dtype)
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
- logger.warning(
- f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
- f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
- f"This inconsistency may result in the learning rate scheduler not functioning properly."
- )
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
- # We need to initialize the trackers we use, and also store our configuration.
- # The trackers initialize automatically on the main process.
- if accelerator.is_main_process:
- tracker_config = dict(vars(args))
- # tensorboard cannot handle list types for config
- tracker_config.pop("validation_prompt")
- tracker_config.pop("validation_image")
- accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
- # Train!
- total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num batches each epoch = {len(train_dataloader)}")
- logger.info(f" Num Epochs = {args.num_train_epochs}")
- logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {args.max_train_steps}")
- global_step = 0
- first_epoch = 0
- # Potentially load in the weights and states from a previous save
- if args.resume_from_checkpoint:
- if args.resume_from_checkpoint != "latest":
- path = os.path.basename(args.resume_from_checkpoint)
- else:
- # Get the most recent checkpoint
- dirs = os.listdir(args.output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
- if path is None:
- accelerator.print(
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
- )
- args.resume_from_checkpoint = None
- initial_global_step = 0
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(args.output_dir, path))
- global_step = int(path.split("-")[1])
- initial_global_step = global_step
- first_epoch = global_step // num_update_steps_per_epoch
- else:
- initial_global_step = 0
- progress_bar = tqdm(
- range(0, args.max_train_steps),
- initial=initial_global_step,
- desc="Steps",
- # Only show the progress bar once on each machine.
- disable=not accelerator.is_local_main_process,
- )
- def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
- sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
- schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
- timesteps = timesteps.to(accelerator.device)
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
- sigma = sigmas[step_indices].flatten()
- while len(sigma.shape) < n_dim:
- sigma = sigma.unsqueeze(-1)
- return sigma
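- # get_sigmas looks up the sigma the flow-matching scheduler assigns to each sampled timestep
- # and reshapes it for broadcasting: for packed latents of shape (B, L, C) it returns sigmas
- # of shape (B, 1, 1), so `(1.0 - sigmas) * latents + sigmas * noise` below broadcasts per sample.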
- image_logs = None
- for epoch in range(first_epoch, args.num_train_epochs):
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(flux_controlnet):
- # Convert images to latent space
- # vae encode
- pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
- pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
- pixel_latents_tmp = (pixel_latents_tmp - unwrap_model(vae).config.shift_factor) * unwrap_model(vae).config.scaling_factor
- pixel_latents = FluxControlNetPipeline._pack_latents(
- pixel_latents_tmp,
- pixel_values.shape[0],
- pixel_latents_tmp.shape[1],
- pixel_latents_tmp.shape[2],
- pixel_latents_tmp.shape[3],
- )
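- # _pack_latents follows the Flux pipeline's patchification convention: it rearranges VAE
- # latents of shape (B, C, H, W) into a sequence of 2x2 patches of shape (B, (H/2)*(W/2), C*4),
- # which is the token layout the transformer and controlnet expect.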
- control_values = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
- control_latents = vae.encode(control_values).latent_dist.sample()
- control_latents = (control_latents - unwrap_model(vae).config.shift_factor) * unwrap_model(vae).config.scaling_factor
- control_image = FluxControlNetPipeline._pack_latents(
- control_latents,
- control_values.shape[0],
- control_latents.shape[1],
- control_latents.shape[2],
- control_latents.shape[3],
- )
- latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
- batch_size=pixel_latents_tmp.shape[0],
- height=pixel_latents_tmp.shape[2] // 2,
- width=pixel_latents_tmp.shape[3] // 2,
- device=pixel_values.device,
- dtype=pixel_values.dtype,
- )
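- # latent_image_ids carries per-token positional ids (offset, row, col) for each packed
- # latent token; the transformer uses them for its rotary position embeddings.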
- bsz = pixel_latents.shape[0]
- noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)
- # Sample a random timestep for each image
- # for weighting schemes where we sample timesteps non-uniformly
- u = compute_density_for_timestep_sampling(
- weighting_scheme=args.weighting_scheme,
- batch_size=bsz,
- logit_mean=args.logit_mean,
- logit_std=args.logit_std,
- mode_scale=args.mode_scale,
- )
- indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
- timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)
- # Add noise according to flow matching.
- sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)
- noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise
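- # Flow matching interpolates linearly between data and noise: x_t = (1 - sigma) * x_0 + sigma * eps.
- # The model is trained to predict the constant velocity dx_t/dsigma = eps - x_0, which is
- # exactly the (noise - pixel_latents) regression target in the MSE loss below.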
- # handle guidance
- if unwrap_model(flux_transformer).config.guidance_embeds:
- guidance_vec = torch.full(
- (noisy_model_input.shape[0],),
- args.guidance_scale,
- device=noisy_model_input.device,
- dtype=weight_dtype,
- )
- else:
- guidance_vec = None
- controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
- hidden_states=noisy_model_input,
- controlnet_cond=control_image,
- timestep=timesteps / 1000,
- guidance=guidance_vec,
- pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
- encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
- txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
- img_ids=latent_image_ids,
- return_dict=False,
- )
- noise_pred = flux_transformer(
- hidden_states=noisy_model_input,
- timestep=timesteps / 1000,
- guidance=guidance_vec,
- pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
- encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
- controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]
- if controlnet_block_samples is not None
- else None,
- controlnet_single_block_samples=[
- sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples
- ]
- if controlnet_single_block_samples is not None
- else None,
- txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
- img_ids=latent_image_ids,
- return_dict=False,
- )[0]
- loss = F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction="mean")
- accelerator.backward(loss)
- # Check if the gradient of each model parameter contains NaN
- for name, param in flux_controlnet.named_parameters():
- if param.grad is not None and torch.isnan(param.grad).any():
- logger.error(f"Gradient for {name} contains NaN!")
- if accelerator.sync_gradients:
- params_to_clip = flux_controlnet.parameters()
- accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=args.set_grads_to_none)
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
- # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
- if accelerator.distributed_type in [DistributedType.DEEPSPEED, DistributedType.FSDP] or accelerator.is_main_process:
- if global_step % args.checkpointing_steps == 0:
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- if args.checkpoints_total_limit is not None:
- checkpoints = os.listdir(args.output_dir)
- checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= args.checkpoints_total_limit:
- num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
- removing_checkpoints = checkpoints[0:num_to_remove]
- logger.info(
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
- )
- logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
- for removing_checkpoint in removing_checkpoints:
- removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
- shutil.rmtree(removing_checkpoint)
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
- accelerator.save_state(save_path)
- logger.info(f"Saved state to {save_path}")
- if args.validation_prompt is not None and global_step % args.validation_steps == 0:
- image_logs = log_validation(
- vae=vae,
- flux_transformer=flux_transformer,
- flux_controlnet=flux_controlnet,
- args=args,
- accelerator=accelerator,
- weight_dtype=weight_dtype,
- step=global_step,
- )
- logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
- accelerator.log(logs, step=global_step)
- if global_step >= args.max_train_steps:
- break
- # Create the pipeline using the trained modules and save it.
- accelerator.wait_for_everyone()
- if accelerator.is_main_process:
- flux_controlnet = unwrap_model(flux_controlnet)
- save_weight_dtype = torch.float32
- if args.save_weight_dtype == "fp16":
- save_weight_dtype = torch.float16
- elif args.save_weight_dtype == "bf16":
- save_weight_dtype = torch.bfloat16
- flux_controlnet.to(save_weight_dtype)
- if args.save_weight_dtype != "fp32":
- flux_controlnet.save_pretrained(args.output_dir, variant=args.save_weight_dtype)
- else:
- flux_controlnet.save_pretrained(args.output_dir)
- # Run a final round of validation.
- # `flux_controlnet` is passed as None so the trained weights are loaded from `args.output_dir`.
- image_logs = None
- if args.validation_prompt is not None:
- image_logs = log_validation(
- vae=vae,
- flux_transformer=flux_transformer,
- flux_controlnet=None,
- args=args,
- accelerator=accelerator,
- weight_dtype=weight_dtype,
- step=global_step,
- is_final_validation=True,
- )
- if args.push_to_hub:
- save_model_card(
- repo_id,
- image_logs=image_logs,
- base_model=args.pretrained_model_name_or_path,
- repo_folder=args.output_dir,
- )
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message="End of training",
- ignore_patterns=["step_*", "epoch_*"],
- )
- accelerator.end_training()
- if __name__ == "__main__":
- args = parse_args()
- main(args)
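- # A minimal inference sketch for the trained weights (assumes the controlnet was saved to
- # "controlnet-model" with --save_weight_dtype=fp32 and that a conditioning image
- # "condition.png" exists; both paths are hypothetical):
- #
- #   import torch
- #   from diffusers import FluxControlNetPipeline
- #   from diffusers.models.controlnet_flux import FluxControlNetModel
- #   from diffusers.utils import load_image
- #
- #   controlnet = FluxControlNetModel.from_pretrained("controlnet-model", torch_dtype=torch.bfloat16)
- #   pipe = FluxControlNetPipeline.from_pretrained(
- #       "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
- #   ).to("cuda")
- #   image = pipe(
- #       "red circle with blue background",
- #       control_image=load_image("condition.png"),
- #       controlnet_conditioning_scale=0.7,
- #       num_inference_steps=28,
- #       guidance_scale=3.5,
- #   ).images[0]
- #   image.save("output.png")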