From 1567ce1e1777c169bd59237f1f722b2e9722bbf3 Mon Sep 17 00:00:00 2001
From: DKnight54 <126916963+DKnight54@users.noreply.github.com>
Date: Sat, 3 Feb 2024 20:46:31 +0800
Subject: [PATCH 1/4] Enable distributed sample image generation on multi-GPU
 environment (#1061)

* Update train_util.py

Modifying to attempt enable multi GPU inference

* Update train_util.py

additional VRAM checking, refactor check_vram_usage to return string for use with accelerator.print

* Update train_network.py

* Update train_util.py (multiple iterative updates)

* Update train_util.py

remove sample image debug outputs

* Update train_util.py / train_network.py (multiple iterative updates)

* Cleanup of debugging outputs

* adopt more elegant coding

Co-authored-by: Aarni Koskela

* Update train_util.py

Fix leftover debugging code
attempt to refactor inference into separate function

* refactor in function generate_per_device_prompt_list() generation of distributed prompt list

* Clean up missing variables

* fix syntax error

* Update train_util.py (multiple iterative updates)

* true random sample image generation

update code to reinitialize random seed to true random if seed was set

* simplify per process prompt

* Update train_util.py / train_network.py (multiple iterative updates)

---------

Co-authored-by: Aarni Koskela
---
 library/train_util.py | 208 +++++++++++++++++++++++-------------------
 1 file changed, 115 insertions(+), 93 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index ba428e50..3e6125f0 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -19,7 +19,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import gc
 import glob
 import math
@@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict:
 
     return prompt_dict
 
-
 def sample_images_common(
     pipe_class,
     accelerator: Accelerator,
@@ -4654,6 +4653,7 @@
     """
     StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
     """
+
    if steps == 0:
        if not args.sample_at_first:
            return
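[Note] PartialState, newly imported above, is Accelerate's lightweight handle on the distributed environment: it behaves as a singleton, so constructing it anywhere returns the state the Accelerator already initialized, and each process sees its own rank, world size, and device. A minimal sketch of the two Accelerate APIs this series relies on (both are real APIs in recent accelerate releases; the toy prompt list is illustrative):

    from accelerate import PartialState

    state = PartialState()  # every call returns the same shared state
    print(state.process_index, state.num_processes, state.device)

    # Each process receives its own contiguous slice of the input list.
    with state.split_between_processes(["p0", "p1", "p2", "p3", "p4"]) as subset:
        for prompt in subset:
            ...  # run this process's share of the inference

Launched via `accelerate launch` on two GPUs, rank 0 would see ["p0", "p1", "p2"] and rank 1 ["p3", "p4"].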
@@ -4668,13 +4668,15 @@
     if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
         return
 
+    distributed_state = PartialState() #testing implementation of multi gpu distributed inference
+
     print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
     if not os.path.isfile(args.sample_prompts):
         print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
         return
 
     org_vae_device = vae.device  # CPUにいるはず
-    vae.to(device)
+    vae.to(distributed_state.device)
 
     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet)
@@ -4700,12 +4702,11 @@
         with open(args.sample_prompts, "r", encoding="utf-8") as f:
             prompts = json.load(f)
 
-    schedulers: dict = {}
+    # schedulers: dict = {} cannot find where this is used
     default_scheduler = get_my_scheduler(
         sample_sampler=args.sample_sampler,
         v_parameterization=args.v_parameterization,
     )
-    schedulers[args.sample_sampler] = default_scheduler
 
     pipeline = pipe_class(
         text_encoder=text_encoder,
@@ -4718,114 +4719,135 @@
         requires_safety_checker=False,
         clip_skip=args.clip_skip,
     )
-    pipeline.to(device)
-
+    pipeline.to(distributed_state.device)
     save_dir = args.output_dir + "/sample"
     os.makedirs(save_dir, exist_ok=True)
+
+    # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
+    # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
+    per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement)
 
     rng_state = torch.get_rng_state()
     cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    # True random sample image generation
+    torch.seed()
+    torch.cuda.seed()
 
     with torch.no_grad():
-        # with accelerator.autocast():
-        for i, prompt_dict in enumerate(prompts):
-            if not accelerator.is_main_process:
-                continue
+        with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+            for prompt_dict in prompt_dict_lists[0]:
+                sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)
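[Note] per_process_prompts is built with exactly num_processes sub-lists, so split_between_processes hands each rank a one-element list whose single element is that rank's own prompt list; that is why the loop indexes prompt_dict_lists[0]. A sketch of the shapes, assuming two processes and four prompts:

    per_process_prompts = [["p0", "p2"], ["p1", "p3"]]  # len == num_processes
    with distributed_state.split_between_processes(per_process_prompts) as chunks:
        # rank 0 sees chunks == [["p0", "p2"]], rank 1 sees chunks == [["p1", "p3"]]
        my_prompts = chunks[0]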
print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - print(f"sample_sampler: {sampler_name}") - if seed is not None: - print(f"seed: {seed}") - with accelerator.autocast(): - latents = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ) - - image = pipeline.latents_to_image(latents)[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) - - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") - try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") - - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() + torch.set_rng_state(rng_state) if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) +def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None): + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + per_process_prompts = [[] for i in range(num_of_processes)] + for i, prompt in enumerate(prompts): + if isinstance(prompt, str): + prompt = line_to_prompt_dict(prompt) + assert isinstance(prompt, dict) + prompt.pop("subset", None) # Clean up subset key + prompt["enum"] = i + # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. + if prompt_replacement is not None: + prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1]) + if prompt["negative_prompt"] is not None: + prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1]) + # Refactor prompt replacement to here in order to simplify sample_image_inference function. 
+def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
+    assert isinstance(prompt_dict, dict)
+    negative_prompt = prompt_dict.get("negative_prompt")
+    sample_steps = prompt_dict.get("sample_steps", 30)
+    width = prompt_dict.get("width", 512)
+    height = prompt_dict.get("height", 512)
+    scale = prompt_dict.get("scale", 7.5)
+    seed = prompt_dict.get("seed")
+    controlnet_image = prompt_dict.get("controlnet_image")
+    prompt: str = prompt_dict.get("prompt", "")
+    sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    scheduler = get_my_scheduler(
+        sample_sampler=sampler_name,
+        v_parameterization=args.v_parameterization,
+    )
+    pipeline.scheduler = scheduler
+
+    if controlnet_image is not None:
+        controlnet_image = Image.open(controlnet_image).convert("RGB")
+        controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8
+    print(f"\nprompt: {prompt}")
+    print(f"negative_prompt: {negative_prompt}")
+    print(f"height: {height}")
+    print(f"width: {width}")
+    print(f"sample_steps: {sample_steps}")
+    print(f"scale: {scale}")
+    print(f"sample_sampler: {sampler_name}")
+    if seed is not None:
+        print(f"seed: {seed}")
+    with accelerator.autocast():
+        latents = pipeline(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=sample_steps,
+            guidance_scale=scale,
+            negative_prompt=negative_prompt,
+            controlnet=controlnet,
+            controlnet_image=controlnet_image,
+        )
+    image = pipeline.latents_to_image(latents)[0]
+
+    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+    seed_suffix = "" if seed is None else f"_{seed}"
+    i: int = prompt_dict["enum"]
+    img_filename = (
+        f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+    )
+
+    image.save(os.path.join(save_dir, img_filename))
+    if seed is not None:
+        torch.seed()
+        torch.cuda.seed()
+    # wandb有効時のみログを送信
+    try:
+        wandb_tracker = accelerator.get_tracker("wandb")
+        try:
+            import wandb
+        except ImportError:  # 事前に一度確認するのでここはエラー出ないはず
+            raise ImportError("No wandb / wandb がインストールされていないようです")
+
+        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+    except:  # wandb 無効時
+        pass
 
 # endregion
+
+
+
 # region 前処理用
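[Note] The seed handling in this version follows a seed-sample-reseed pattern: a prompt with a fixed seed renders reproducibly, and the RNG is then re-randomized so later unseeded prompts do not continue that seed's stream. The pattern in isolation (the seed value is illustrative):

    import torch

    seed = 42  # from the prompt file, if given
    if seed is not None:
        torch.manual_seed(seed)       # deterministic CPU stream
        torch.cuda.manual_seed(seed)  # deterministic stream on the current GPU
    # ... run the pipeline ...
    if seed is not None:
        torch.seed()       # reseed CPU RNG from OS entropy
        torch.cuda.seed()  # reseed the current GPU RNG to a random seed

Patch 2 below moves the reseeding into the else branch of the seed check, serving the same purpose before generation instead of after it.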
From 11aced35005221c05920e33658ceba46fc0e4272 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 3 Feb 2024 22:25:29 +0900
Subject: [PATCH 2/4] simplify multi-GPU sample generation

---
 library/train_util.py | 94 ++++++++++++++++++++++---------------------
 1 file changed, 48 insertions(+), 46 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 3e6125f0..177fae55 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4668,13 +4668,13 @@
     if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
         return
 
-    distributed_state = PartialState() #testing implementation of multi gpu distributed inference
-
     print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
     if not os.path.isfile(args.sample_prompts):
         print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
         return
 
+    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+
     org_vae_device = vae.device  # CPUにいるはず
     vae.to(distributed_state.device)
 
@@ -4686,10 +4686,6 @@
         text_encoder = accelerator.unwrap_model(text_encoder)
 
     # read prompts
-
-    # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
-    #     prompts = f.readlines()
-
     if args.sample_prompts.endswith(".txt"):
         with open(args.sample_prompts, "r", encoding="utf-8") as f:
             lines = f.readlines()
@@ -4722,22 +4718,39 @@
     pipeline.to(distributed_state.device)
     save_dir = args.output_dir + "/sample"
     os.makedirs(save_dir, exist_ok=True)
-
-    # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
-    # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
-    per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement)
+
+    # preprocess prompts
+    for i in range(len(prompts)):
+        prompt_dict = prompts[i]
+        if isinstance(prompt_dict, str):
+            prompt_dict = line_to_prompt_dict(prompt_dict)
+            prompts[i] = prompt_dict
+        assert isinstance(prompt_dict, dict)
+
+        # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
+        prompt_dict["enum"] = i
+        prompt_dict.pop("subset", None)
+
+    # save random state to restore later
     rng_state = torch.get_rng_state()
-    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
-    # True random sample image generation
-    torch.seed()
-    torch.cuda.seed()
-
-    with torch.no_grad():
-        with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
-            for prompt_dict in prompt_dict_lists[0]:
-                sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)
-
+    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None  # TODO mps etc. support
+
+    if distributed_state.num_processes <= 1:
+        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
+        with torch.no_grad():
+            for prompt_dict in prompts:
+                sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
+    else:
+        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
+        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
+        per_process_prompts = []  # list of lists
+        for i in range(distributed_state.num_processes):
+            per_process_prompts.append(prompts[i::distributed_state.num_processes])
+
+        with torch.no_grad():
+            with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+                for prompt_dict in prompt_dict_lists[0]:
+                    sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
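[Note] prompts[i::n] takes every n-th element starting at offset i, reproducing the earlier modulo round-robin in a single slice expression. A toy check, assuming five prompts and two processes:

    prompts = ["p0", "p1", "p2", "p3", "p4"]
    n = 2
    per_process = [prompts[i::n] for i in range(n)]
    # per_process == [["p0", "p2", "p4"], ["p1", "p3"]]

The single-process branch skips the split entirely, so a one-GPU run behaves exactly as before.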
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
@@ -4750,27 +4763,7 @@
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)
 
-def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None):
-
-    # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
-    # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
-    per_process_prompts = [[] for i in range(num_of_processes)]
-    for i, prompt in enumerate(prompts):
-        if isinstance(prompt, str):
-            prompt = line_to_prompt_dict(prompt)
-        assert isinstance(prompt, dict)
-        prompt.pop("subset", None)  # Clean up subset key
-        prompt["enum"] = i
-        # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
-        if prompt_replacement is not None:
-            prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
-            if prompt["negative_prompt"] is not None:
-                prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
-        # Refactor prompt replacement to here in order to simplify sample_image_inference function.
-        per_process_prompts[i % num_of_processes].append(prompt)
-        return per_process_prompts
-
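[Note] The set_rng_state context lines above close the save/restore bracket opened before sampling, so training resumes with exactly the RNG streams it had before sample generation, regardless of any per-prompt seeding. The pattern in isolation (the CUDA branch is guarded for CPU-only runs, matching the TODO about mps support):

    import torch

    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    # ... seeded or true-random sampling happens here ...
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)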
-def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
+def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None):
     assert isinstance(prompt_dict, dict)
     negative_prompt = prompt_dict.get("negative_prompt")
     sample_steps = prompt_dict.get("sample_steps", 30)
@@ -4781,10 +4774,19 @@
     controlnet_image = prompt_dict.get("controlnet_image")
     prompt: str = prompt_dict.get("prompt", "")
     sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
-
+
+    if prompt_replacement is not None:
+        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+        if negative_prompt is not None:
+            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+
     if seed is not None:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
+    else:
+        # True random sample image generation
+        torch.seed()
+        torch.cuda.seed()
 
     scheduler = get_my_scheduler(
         sample_sampler=sampler_name,
@@ -4819,7 +4821,10 @@
             controlnet_image=controlnet_image,
         )
     image = pipeline.latents_to_image(latents)[0]
+
     # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+    # but adding 'enum' to the filename should be enough
+
     ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
     num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
     seed_suffix = "" if seed is None else f"_{seed}"
@@ -4827,11 +4832,8 @@
     img_filename = (
         f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
     )
-
     image.save(os.path.join(save_dir, img_filename))
-    if seed is not None:
-        torch.seed()
-        torch.cuda.seed()
+
     # wandb有効時のみログを送信
     try:
         wandb_tracker = accelerator.get_tracker("wandb")

From 2f9a34429729c8b44c6f04cb27656d578ecb9420 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 3 Feb 2024 23:26:57 +0900
Subject: [PATCH 3/4] fix typo

---
 library/train_util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index 177fae55..1377997c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4741,7 +4741,7 @@
             for prompt_dict in prompts:
                 sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
     else:
-        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
+        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
         # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
         per_process_prompts = []  # list of lists
         for i in range(distributed_state.num_processes):
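[Note] Patch 2 also reordered the filename fields from {ts_str}_{num_suffix}_{i:02d} to {num_suffix}_{i:02d}_{ts_str}, so files sort by step or epoch and then by prompt index ('enum') rather than by wall-clock save time, which varies per process. Hypothetical names under the new scheme:

    mymodel_000500_00_20240203204631.png  # step 500, prompt #0
    mymodel_000500_01_20240203204633.png  # step 500, prompt #1, saved 2 s later on another GPU

Lexicographic order now matches prompt order no matter which GPU finishes first.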
From e793d7780d779855f23210d1c88368fd9286666e Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 4 Feb 2024 17:31:01 +0900
Subject: [PATCH 4/4] reduce peak VRAM in sample gen

---
 library/train_util.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 1377997c..32198774 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4820,6 +4820,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
             controlnet=controlnet,
             controlnet_image=controlnet_image,
         )
+
+    with torch.cuda.device(torch.cuda.current_device()):
+        torch.cuda.empty_cache()
+
     image = pipeline.latents_to_image(latents)[0]
 
     # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
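[Note] Calling empty_cache between denoising and VAE decoding releases the caching allocator's unused blocks left over from the UNet pass before the decoder allocates its own buffers, which lowers the peak rather than the total VRAM use. A sketch of the pattern, with the device context making explicit which GPU's cache is cleared in a multi-GPU run:

    import torch

    if torch.cuda.is_available():
        with torch.cuda.device(torch.cuda.current_device()):
            torch.cuda.empty_cache()  # frees cached, unreferenced blocks only

empty_cache does not free tensors that are still referenced (such as the latents awaiting decode); it only returns unused cached memory to the allocator's free pool.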