From 9aee793078ac3ab20bf1296756013458fabc78ad Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 14 Jun 2023 12:49:12 +0900 Subject: [PATCH 01/10] support arbitrary dataset for train_network.py --- library/train_util.py | 61 ++++++++++++++++++++++++++++++ train_network.py | 87 +++++++++++++++++++++++++++---------------- 2 files changed, 115 insertions(+), 33 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 844faca7..e1046d58 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1518,6 +1518,67 @@ def glob_images_pathlib(dir_path, recursive): return image_paths +class MinimalDataset(BaseDataset): + def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + + self.num_train_images = 0 # update in subclass + self.num_reg_images = 0 # update in subclass + self.datasets = [self] + self.batch_size = 1 # update in subclass + + self.subsets = [self] + self.num_repeats = 1 # update in subclass if needed + self.img_count = 1 # update in subclass if needed + self.bucket_info = {} + self.is_reg = False + self.image_dir = "dummy" # for metadata + + def is_latent_cacheable(self) -> bool: + return False + + def __len__(self): + raise NotImplementedError + + # override to avoid shuffling buckets + def set_current_epoch(self, epoch): + self.current_epoch = epoch + + def __getitem__(self, idx): + r""" + The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects. + + Returns: example like this: + + for i in range(batch_size): + image_key = ... # whatever hashable + image_keys.append(image_key) + + image = ... # PIL Image + img_tensor = self.image_transforms(img) + images.append(img_tensor) + + caption = ... 
# str + input_ids = self.get_input_ids(caption) + input_ids_list.append(input_ids) + + captions.append(caption) + + images = torch.stack(images, dim=0) + input_ids_list = torch.stack(input_ids_list, dim=0) + example = { + "images": images, + "input_ids": input_ids_list, + "captions": captions, # for debug_dataset + "latents": None, + "image_keys": image_keys, # for debug_dataset + "loss_weights": torch.ones(batch_size, dtype=torch.float32), + } + return example + """ + raise NotImplementedError + + # endregion # region モジュール入れ替え部 diff --git a/train_network.py b/train_network.py index b62aef7e..6f845b5a 100644 --- a/train_network.py +++ b/train_network.py @@ -92,42 +92,56 @@ def train(args): tokenizer = train_util.load_tokenizer(args) # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) - if use_user_config: - print(f"Loading dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + if use_user_config: + print(f"Loading dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - if use_dreambooth_method: - print("Using DreamBooth method.") - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, 
args.reg_data_dir)} - ] - } else: - print("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + if use_dreambooth_method: + print("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + print("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + # use arbitrary dataset class + module = ".".join(args.dataset_class.split(".")[:-1]) + dataset_class = args.dataset_class.split(".")[-1] + module = importlib.import_module(module) + dataset_class = getattr(module, dataset_class) + train_dataset_group: train_util.MinimalDataset = dataset_class( + tokenizer, args.max_token_length, args.resolution, args.debug_dataset + ) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -185,6 +199,7 @@ def train(args): module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") print(f"all weights merged: {', '.join(args.base_weights)}") + # 学習を準備する if cache_latents: vae.to(accelerator.device, dtype=weight_dtype) @@ -852,6 +867,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) + parser.add_argument( + "--dataset_class", + type=str, + default=None, + 
Subject: [PATCH 02/10] support dynamic prompt variants
print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separater)) + else: + # make random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separater)) + + # make each prompt + if not enumerating: + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0]) + prompts.append(current) + else: + prompts = [prompt] + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: # enumerating + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement)) + prompts = new_prompts + for found, replacer in zip(founds, replacers): + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) + + return prompts + + # endregion @@ -2776,6 +2873,7 @@ def main(args): # seed指定時はseedを決めておく if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう random.seed(args.seed) predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] 
if len(predefined_seeds) == 1: @@ -3058,121 +3156,134 @@ def main(args): while not valid: print("\nType prompt:") try: - prompt = input() + raw_prompt = input() except EOFError: break - valid = len(prompt.strip().split(" --")[0].strip()) > 0 + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 if not valid: # EOF, end app break else: - prompt = prompt_list[prompt_index] + raw_prompt = prompt_list[prompt_index] - # parse prompt - width = args.W - height = args.H - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None + # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - prompt_args = prompt.strip().split(" --") - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + # repeat prompt + for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0] - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue + if prompt_index == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") 
- m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") - continue + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - print(f"strength: {strength}") - continue + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue - m = re.match(r"c (.+)", parg, 
re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue - if seeds is not None: - # 数が足りないなら繰り返す - if len(seeds) < args.images_per_prompt: - seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds))) - seeds = seeds[: args.images_per_prompt] - else: - if predefined_seeds is not None: - seeds = predefined_seeds[-args.images_per_prompt :] - predefined_seeds = predefined_seeds[: -args.images_per_prompt] - elif args.iter_same_seed: - seeds = [iter_seed] * args.images_per_prompt + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) else: - seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)] + if predefined_seeds is not None: + if 
len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) if args.interactive: - print(f"seed: {seeds}") + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None - init_image = mask_image = guide_image = None - for seed in seeds: # images_per_promptの数だけ # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する if init_images is not None: init_image = init_images[global_step % len(init_images)] From 624fbadea2b742f2bf32d82efeb332f24695881c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 19:19:16 +0900 Subject: [PATCH 03/10] fix dynamic prompt with from_file --- gen_img_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 01001646..7b5cee1f 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -3170,10 +3170,10 @@ def main(args): raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) # repeat prompt - for prompt_index in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[prompt_index] if len(raw_prompts) > 1 else raw_prompts[0] + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - if prompt_index == 0 or len(raw_prompts) > 1: + if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing width = args.W height = args.H From f2989b36c2dfcd799460a22a90de5d7bdba8d2ec Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:37:01 +0900 Subject: [PATCH 04/10] fix typos, add comment --- gen_img_diffusers.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git 
a/gen_img_diffusers.py b/gen_img_diffusers.py index 7b5cee1f..acff1ea4 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2180,12 +2180,14 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): enumerating = False replacers = [] for found in founds: + # if "e$$" is found, enumerate all variants found_enumerating = found.group(2) is not None enumerating = enumerating or found_enumerating - separater = ", " if found.group(6) is None else found.group(6) + separator = ", " if found.group(6) is None else found.group(6) variants = found.group(7).split("|") + # parse count range count_range = found.group(4) if count_range is None: count_range = [1, 1] @@ -2206,7 +2208,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): count_range[1] = len(variants) if found_enumerating: - # make all combinations + # make function to enumerate all combinations def make_replacer_enum(vari, cr, sep): def replacer(): values = [] @@ -2217,9 +2219,9 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): return replacer - replacers.append(make_replacer_enum(variants, count_range, separater)) + replacers.append(make_replacer_enum(variants, count_range, separator)) else: - # make random combinations + # make function to choose random combinations def make_replacer_single(vari, cr, sep): def replacer(): count = random.randint(cr[0], cr[1]) @@ -2228,10 +2230,11 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): return replacer - replacers.append(make_replacer_single(variants, count_range, separater)) + replacers.append(make_replacer_single(variants, count_range, separator)) # make each prompt - if not enumerating: + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly prompts = [] for _ in range(repeat_count): current = prompt @@ -2239,16 +2242,21 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): current = current.replace(found.group(0), replacer()[0]) prompts.append(current) else: + # if enumerating, 
iterate all combinations for previous prompts prompts = [prompt] + for found, replacer in zip(founds, replacers): - if found.group(2) is not None: # enumerating + if found.group(2) is not None: + # make all combinations for existing prompts new_prompts = [] for current in prompts: replecements = replacer() for replecement in replecements: new_prompts.append(current.replace(found.group(0), replecement)) prompts = new_prompts + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts if found.group(2) is None: for i in range(len(prompts)): prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) @@ -3166,7 +3174,8 @@ def main(args): else: raw_prompt = prompt_list[prompt_index] - # sd-dynamic-prompts like variants: count is 1 or images_per_prompt or arbitrary + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) # repeat prompt From 9806b00f74d1ee6be4d792e107cbd1b59b7addbb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:39:39 +0900 Subject: [PATCH 05/10] add arbitrary dataset feature to each script --- fine_tune.py | 54 ++++++++++++++------------ library/train_util.py | 17 +++++++- train_db.py | 40 ++++++++++--------- train_network.py | 14 +------ train_textual_inversion.py | 71 ++++++++++++++++++---------------- train_textual_inversion_XTI.py | 11 +++++- 6 files changed, 115 insertions(+), 92 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 201d4952..308f90ef 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -42,33 +42,37 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = 
["train_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + else: + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/library/train_util.py b/library/train_util.py index e1046d58..4a25e00d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1579,6 +1579,15 @@ class MinimalDataset(BaseDataset): raise NotImplementedError +def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: + module = ".".join(args.dataset_class.split(".")[:-1]) + dataset_class = args.dataset_class.split(".")[-1] + module = 
importlib.import_module(module) + dataset_class = getattr(module, dataset_class) + train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset) + return train_dataset_group + + # endregion # region モジュール入れ替え部 @@ -2455,7 +2464,6 @@ def add_dataset_arguments( default=1, help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", ) - parser.add_argument( "--token_warmup_step", type=float, @@ -2463,6 +2471,13 @@ def add_dataset_arguments( help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--dataset_class", + type=str, + default=None, + help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)", + ) + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに diff --git a/train_db.py b/train_db.py index c81a092d..115855c1 100644 --- a/train_db.py +++ b/train_db.py @@ -46,26 +46,30 @@ def train(args): tokenizer = train_util.load_tokenizer(args) - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) 
+ ignored = ["train_data_dir", "reg_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } + else: + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_network.py b/train_network.py index 6f845b5a..abec3d41 100644 --- a/train_network.py +++ b/train_network.py @@ -135,13 +135,7 @@ def train(args): train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: # use arbitrary dataset class - module = ".".join(args.dataset_class.split(".")[:-1]) - dataset_class = args.dataset_class.split(".")[-1] - module = importlib.import_module(module) - dataset_class = getattr(module, dataset_class) - train_dataset_group: train_util.MinimalDataset = dataset_class( - tokenizer, args.max_token_length, args.resolution, args.debug_dataset - ) + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -867,12 +861,6 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="multiplier for network weights to merge into the model 
before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率", ) - parser.add_argument( - "--dataset_class", - type=str, - default=None, - help="dataset class for arbitrary dataset / 任意のデータセットのクラス名", - ) return parser diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 8be0703d..48713fc1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -153,43 +153,46 @@ def train(args): print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) - if args.dataset_config is not None: - print(f"Load dataset config from {args.dataset_config}") - user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "reg_data_dir", "in_json"] - if any(getattr(args, attr) is not None for attr in ignored): - print( - "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( - ", ".join(ignored) + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + if args.dataset_config is not None: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "reg_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) ) - ) - else: - use_dreambooth_method = args.in_json is None - if use_dreambooth_method: - print("Use DreamBooth method.") - user_config = { - "datasets": [ - {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} - ] - } else: - print("Train with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": 
args.in_json, - } - ] - } - ] - } + use_dreambooth_method = args.in_json is None + if use_dreambooth_method: + print("Use DreamBooth method.") + user_config = { + "datasets": [ + {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)} + ] + } + else: + print("Train with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer) current_epoch = Value("i", 0) current_step = Value("i", 0) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7b734f28..bf7d5bb0 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -20,7 +20,13 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction +from library.custom_train_functions import ( + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -88,6 +94,9 @@ def train(args): print( "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / 
sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません" ) + assert ( + args.dataset_class is None + ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません" cache_latents = args.cache_latents From f0bb3ae825efe6720f10301ee788072542b2e3ac Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 20:56:12 +0900 Subject: [PATCH 06/10] add an option to disable controlnet in 2nd stage --- gen_img_diffusers.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index acff1ea4..93a876ab 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -615,11 +615,15 @@ class PipelineLike: # ControlNet self.control_nets: List[ControlNetInfo] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion def add_token_replacement(self, target_token_id, rep_token_ids): self.token_replacements[target_token_id] = rep_token_ids + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + def replace_token(self, tokens, layer=None): new_tokens = [] for token in tokens: @@ -1112,7 +1116,7 @@ class PipelineLike: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - if self.control_nets: + if self.control_nets and self.control_net_enabled: if reginonal_network: num_sub_and_neg_prompts = len(text_embeddings) // batch_size text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt @@ -2233,7 +2237,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): replacers.append(make_replacer_single(variants, count_range, separator)) # make each prompt - if not enumerating: + if not enumerating: # if not enumerating, repeat the prompt, replace each variant randomly prompts = [] for _ in range(repeat_count): @@ -2254,7 +2258,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): for 
replecement in replecements: new_prompts.append(current.replace(found.group(0), replecement)) prompts = new_prompts - + for found, replacer in zip(founds, replacers): # make random selection for existing prompts if found.group(2) is None: @@ -2933,6 +2937,8 @@ def main(args): ext.num_sub_prompts, ) batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする images_1st = process_batch(batch_1st, True, True) # 2nd stageのバッチを作成して以下処理する @@ -2976,6 +2982,9 @@ def main(args): batch_2nd.append(bd_2nd) batch = batch_2nd + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + # このバッチの情報を取り出す ( return_latents, @@ -3574,6 +3583,11 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) parser.add_argument( "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" From e97d67a68121df2ec57270d131c76ec8cb2e312d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Thu, 15 Jun 2023 20:12:53 +0800 Subject: [PATCH 07/10] Support for Prodigy(Dadapt variety for Dylora) (#585) * Update train_util.py for DAdaptLion * Update train_README-zh.md for dadaptlion * Update train_README-ja.md for DAdaptLion * add DAdatpt V3 * Alignment * Update train_util.py for experimental * Update train_util.py V3 * Update train_README-zh.md * Update train_README-ja.md * Update train_util.py fix * Update train_util.py * support Prodigy * add lower --- docs/train_README-ja.md | 1 + docs/train_README-zh.md | 3 ++- fine_tune.py | 2 +- library/train_util.py | 32 ++++++++++++++++++++++++++++++++ train_db.py | 2 +- 
train_network.py | 4 ++-- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 8 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md index b64b1808..158363b3 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptAdanIP : 引数は同上 - DAdaptLion : 引数は同上 - DAdaptSGD : 引数は同上 + - Prodigy : https://github.com/konstmish/prodigy - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任意のオプティマイザ diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 678832d2..454d5456 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - DAdaptAdam : 参数同上 - DAdaptAdaGrad : 参数同上 - DAdaptAdan : 参数同上 - - DAdaptAdanIP : 引数は同上 + - DAdaptAdanIP : 参数同上 - DAdaptLion : 参数同上 - DAdaptSGD : 参数同上 + - Prodigy : https://github.com/konstmish/prodigy - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules) - 任何优化器 diff --git a/fine_tune.py b/fine_tune.py index 308f90ef..d0013d53 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -397,7 +397,7 @@ def train(args): current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/library/train_util.py b/library/train_util.py index 4a25e00d..5b5d99ac 100644 --- a/library/train_util.py 
+++ b/library/train_util.py @@ -2808,6 +2808,38 @@ def get_optimizer(args, trainable_params): optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "Prodigy".lower(): + # Prodigy + # check Prodigy is installed + try: + import prodigyopt + except ImportError: + raise ImportError("No Prodigy / Prodigy がインストールされていないようです") + + # check lr and lr_count, and print warning + actual_lr = lr + lr_count = 1 + if type(trainable_params) == list and type(trainable_params[0]) == dict: + lrs = set() + actual_lr = trainable_params[0].get("lr", actual_lr) + for group in trainable_params: + lrs.add(group.get("lr", actual_lr)) + lr_count = len(lrs) + + if actual_lr <= 0.1: + print( + f"learning rate is too low. If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + ) + print("recommend option: lr=1.0 / 推奨は1.0です") + if lr_count > 1: + print( + f"when multiple learning rates are specified with Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + ) + + print(f"use Prodigy optimizer | {optimizer_kwargs}") + optimizer_class = prodigyopt.Prodigy + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + elif optimizer_type == "Adafactor".lower(): # 引数を確認して適宜補正する if "relative_step" not in optimizer_kwargs: diff --git a/train_db.py b/train_db.py index 115855c1..927e79de 100644 --- a/train_db.py +++ b/train_db.py @@ -384,7 +384,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * 
lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_network.py b/train_network.py index abec3d41..da0ca1c9 100644 --- a/train_network.py +++ b/train_network.py @@ -57,7 +57,7 @@ def generate_step_logs( logs["lr/textencoder"] = float(lrs[0]) logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value of unet. + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet. logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] else: idx = 0 @@ -67,7 +67,7 @@ def generate_step_logs( for i in range(idx, len(lrs)): logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()): + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 48713fc1..d08251e1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -476,7 +476,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index bf7d5bb0..f44d565c 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -515,7 
+515,7 @@ def train(args): current_loss = loss.detach().item() if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) From 5845de7d7c6c9d8dd6123e7b29f39302e8a8140a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 21:47:37 +0900 Subject: [PATCH 08/10] common lr checking for dadaptation and prodigy --- library/train_util.py | 114 +++++++++++++++++------------------------- 1 file changed, 47 insertions(+), 67 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5b5d99ac..acfb503b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2752,15 +2752,7 @@ def get_optimizer(args, trainable_params): optimizer_class = torch.optim.SGD optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs) - elif optimizer_type.startswith("DAdapt".lower()): - # DAdaptation family - # check dadaptation is installed - try: - import dadaptation - import dadaptation.experimental as experimental - except ImportError: - raise ImportError("No dadaptation / dadaptation がインストールされていないようです") - + elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower(): # check lr and lr_count, and print warning actual_lr = lr lr_count = 1 @@ -2773,72 +2765,60 @@ def get_optimizer(args, trainable_params): if actual_lr <= 0.1: print( - f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" + f"learning rate is too low. 
If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}" ) print("recommend option: lr=1.0 / 推奨は1.0です") if lr_count > 1: print( - f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" + f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" ) - # set optimizer - if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): - optimizer_class = experimental.DAdaptAdamPreprint - print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdaGrad".lower(): - optimizer_class = dadaptation.DAdaptAdaGrad - print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdam".lower(): - optimizer_class = dadaptation.DAdaptAdam - print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdan".lower(): - optimizer_class = dadaptation.DAdaptAdan - print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptAdanIP".lower(): - optimizer_class = experimental.DAdaptAdanIP - print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptLion".lower(): - optimizer_class = dadaptation.DAdaptLion - print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") - elif optimizer_type == "DAdaptSGD".lower(): - optimizer_class = dadaptation.DAdaptSGD - print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + if optimizer_type.startswith("DAdapt".lower()): + # DAdaptation family + # check dadaptation is installed + try: + import dadaptation + import 
dadaptation.experimental as experimental + except ImportError: + raise ImportError("No dadaptation / dadaptation がインストールされていないようです") + + # set optimizer + if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower(): + optimizer_class = experimental.DAdaptAdamPreprint + print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdaGrad".lower(): + optimizer_class = dadaptation.DAdaptAdaGrad + print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdam".lower(): + optimizer_class = dadaptation.DAdaptAdam + print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdan".lower(): + optimizer_class = dadaptation.DAdaptAdan + print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptAdanIP".lower(): + optimizer_class = experimental.DAdaptAdanIP + print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptLion".lower(): + optimizer_class = dadaptation.DAdaptLion + print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}") + elif optimizer_type == "DAdaptSGD".lower(): + optimizer_class = dadaptation.DAdaptSGD + print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) else: - raise ValueError(f"Unknown optimizer type: {optimizer_type}") + # Prodigy + # check Prodigy is installed + try: + import prodigyopt + except ImportError: + raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - - elif optimizer_type == "Prodigy".lower(): - # Prodigy - # check Prodigy is installed - try: - import prodigyopt - except ImportError: - raise ImportError("No Prodigy / Prodigy がインストールされていないようです") - - # check lr and lr_count, and 
print warning - actual_lr = lr - lr_count = 1 - if type(trainable_params) == list and type(trainable_params[0]) == dict: - lrs = set() - actual_lr = trainable_params[0].get("lr", actual_lr) - for group in trainable_params: - lrs.add(group.get("lr", actual_lr)) - lr_count = len(lrs) - - if actual_lr <= 0.1: - print( - f"learning rate is too low. If using Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}" - ) - print("recommend option: lr=1.0 / 推奨は1.0です") - if lr_count > 1: - print( - f"when multiple learning rates are specified with Prodigy (e.g. for Text Encoder and U-Net), only the first one will take effect / Prodigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}" - ) - - print(f"use Prodigy optimizer | {optimizer_kwargs}") - optimizer_class = prodigyopt.Prodigy - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + print(f"use Prodigy optimizer | {optimizer_kwargs}") + optimizer_class = prodigyopt.Prodigy + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "Adafactor".lower(): # 引数を確認して適宜補正する From 18156bf2a18f29e56d8f7dbb9de71d09399dde1d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 22:22:12 +0900 Subject: [PATCH 09/10] fix same replacement multiple times in dyn prompt --- gen_img_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 93a876ab..ffb79aa3 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2243,7 +2243,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count): for _ in range(repeat_count): current = prompt for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0]) + current = current.replace(found.group(0), replacer()[0], 1) prompts.append(current) else: # if enumerating, iterate all combinations for previous prompts @@ -2256,14 +2256,14 @@ def 
handle_dynamic_prompt_variants(prompt, repeat_count): for current in prompts: replecements = replacer() for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement)) + new_prompts.append(current.replace(found.group(0), replecement, 1)) prompts = new_prompts for found, replacer in zip(founds, replacers): # make random selection for existing prompts if found.group(2) is None: for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0]) + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) return prompts From 5d1b54de45c142261d7d93467d94ef14e369188d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 15 Jun 2023 22:27:47 +0900 Subject: [PATCH 10/10] update readme --- README.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/README.md b/README.md index 8234a89e..e6202bae 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,42 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 15 Jun. 2023, 2023/06/15 + +- Prodigy optimizer is supported in each training script. It is a member of D-Adaptation and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds! + - Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`. +- Arbitrary Dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions. + - Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script. + - Please refer to `MinimalDataset` for implementation. I will prepare a sample later. +- The following features have been added to the generation script. 
+ - Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. Fix. Please try it if the image is disturbed by some ControlNet such as Canny. + - Added Variants similar to sd-dynamic-prompts in the prompt. + - If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected. + - If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected. + - If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `. + - You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`. + - It can also be specified for the prompt option. + - If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots. + - You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3. + - There is no weighting function. + +- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。 + - `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します。 +- 各学習スクリプトで任意のDatasetをサポートしました(XTIを除く)。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます。 + - Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。 + - 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します。 +- 生成スクリプトに以下の機能追加を行いました。 + - Highres.
Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。 + - プロンプトでsd-dynamic-promptsに似たVariantをサポートしました。 + - `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。 + - `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます。 + - `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。 + - 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません。 + - プロンプトオプションに対しても指定可能です。 + - `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます)。X/Y plotの作成に便利かもしれません。 + - `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます。 + - Weightingの機能はありません。 + ### 8 Jun. 2023, 2023/06/08 - Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.