diff --git a/modules/shared.py b/modules/shared.py
index fb84afd..900f920 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -261,6 +261,7 @@ options_templates.update(options_section(('training', "Training"), {
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
     "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
 }))
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
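
The hunk above only registers a new boolean setting under the Training section; the ui.py hunk further down is what reads it. As a rough, self-contained sketch of the pattern (not the webui implementation; the Options class below is a hypothetical stand-in for the real settings machinery), an OptionInfo carries a default and a label, and consuming code reads the value as an attribute:

# Hypothetical sketch of the OptionInfo pattern used in shared.py above.
class OptionInfo:
    def __init__(self, default, label, component=None, component_args=None):
        self.default = default
        self.label = label
        self.component = component
        self.component_args = component_args

# A template maps setting keys to OptionInfo entries, as in the diff.
options_templates = {
    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
}

class Options:
    """Hypothetical stand-in: exposes each template key as an attribute."""
    def __init__(self, templates):
        self.data = {key: info.default for key, info in templates.items()}

    def __getattr__(self, item):
        return self.data[item]

opts = Options(options_templates)
print(opts.training_xattention_optimizations)  # False until toggled in the settings UI
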
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb22..b0a1d26 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -214,6 +214,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+    unload = shared.opts.unload_models_when_training
 
     if save_embedding_every > 0:
         embedding_dir = os.path.join(log_directory, "embeddings")
@@ -238,6 +239,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+    if unload:
+        shared.sd_model.first_stage_model.to(devices.cpu)
 
     hijack = sd_hijack.model_hijack
@@ -303,6 +306,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         if images_dir is not None and steps_done % create_image_every == 0:
             forced_filename = f'{embedding_name}-{steps_done}'
             last_saved_image = os.path.join(images_dir, forced_filename)
+
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 do_not_save_grid=True,
@@ -330,6 +336,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             processed = processing.process_images(p)
             image = processed.images[0]
 
+            if unload:
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image
 
             if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
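
The hunks above do the actual VRAM saving: when unload_models_when_training is enabled, the VAE (first_stage_model) is parked on the CPU while the embedding trains, moved back to the GPU only to decode preview images, and parked again afterwards. A minimal sketch of that move/restore pattern in plain PyTorch follows; the module and functions below are hypothetical stand-ins, not the webui code:

import torch

# Assumed devices; the webui routes these through modules.devices instead.
gpu = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cpu = torch.device("cpu")

# Hypothetical stand-in for shared.sd_model.first_stage_model (the VAE).
first_stage_model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Tanh())

def training_loop(unload: bool, steps: int = 3):
    if unload:
        first_stage_model.to(cpu)          # free VRAM while the embedding trains
    preview = None
    for step in range(1, steps + 1):
        # ... embedding optimization step would go here ...
        if step % steps == 0:              # time to render a preview image
            first_stage_model.to(gpu)      # bring the VAE back for decoding
            with torch.no_grad():
                preview = first_stage_model(torch.randn(1, 8, device=gpu))
            if unload:
                first_stage_model.to(cpu)  # park it on the CPU again afterwards
    return preview

training_loop(unload=True)
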
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index e712284..d679e6f 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -25,8 +25,10 @@ def train_embedding(*args):
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
 
+    apply_optimizations = shared.opts.training_xattention_optimizations
     try:
-        sd_hijack.undo_optimizations()
+        if not apply_optimizations:
+            sd_hijack.undo_optimizations()
 
         embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
-        sd_hijack.apply_optimizations()
+        if not apply_optimizations:
+            sd_hijack.apply_optimizations()
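
The ui.py hunks gate the existing behaviour on the new setting: cross-attention optimizations are undone before training only when the checkbox is off, and re-applied in the finally block under the same condition, so a failure inside training still restores the optimized path. A small sketch of that try/finally toggle, with stand-in functions in place of sd_hijack's:

def undo_optimizations():
    print("cross attention optimizations off")   # stand-in for sd_hijack.undo_optimizations

def apply_optimizations():
    print("cross attention optimizations on")    # stand-in for sd_hijack.apply_optimizations

def train_embedding_wrapped(train_fn, keep_optimizations: bool):
    """Mirror of the diff's pattern: only toggle when the new setting is off."""
    try:
        if not keep_optimizations:
            undo_optimizations()
        return train_fn()
    finally:
        # Runs whether train_fn returned or raised, restoring the optimized path.
        if not keep_optimizations:
            apply_optimizations()

train_embedding_wrapped(lambda: "embedding trained", keep_optimizations=False)
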