From 2fcbfec17873eedd70cd728f1205836dfca9ceab Mon Sep 17 00:00:00 2001
From: ykume
Date: Wed, 3 May 2023 11:07:29 +0900
Subject: [PATCH] make transform_DDP more intuitive

---
 fine_tune.py                   | 2 +-
 library/train_util.py          | 6 +++---
 train_db.py                    | 2 +-
 train_network.py               | 2 +-
 train_textual_inversion.py     | 2 +-
 train_textual_inversion_XTI.py | 2 +-
 6 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index db1c8a23..9d42c873 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -229,7 +229,7 @@ def train(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # transform DDP after prepare
-    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

     # 実験的機能：勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
diff --git a/library/train_util.py b/library/train_util.py
index 1a3b2ed0..cac4cdc5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2897,9 +2897,9 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     return text_encoder, vae, unet, load_stable_diffusion_format


-def transform_DDP(text_encoder, unet, network=None):
+def transform_if_model_is_DDP(text_encoder, unet, network=None):
     # Transform text_encoder, unet and network from DistributedDataParallel
-    return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network])
+    return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)


 def load_target_model(args, weight_dtype, accelerator):
@@ -2922,7 +2922,7 @@ def load_target_model(args, weight_dtype, accelerator):
             torch.cuda.empty_cache()
         accelerator.wait_for_everyone()

-    text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None)
+    text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)

     return text_encoder, vae, unet, load_stable_diffusion_format

diff --git a/train_db.py b/train_db.py
index abe2ecdf..ad7a317e 100644
--- a/train_db.py
+++ b/train_db.py
@@ -197,7 +197,7 @@ def train(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # transform DDP after prepare
-    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error
diff --git a/train_network.py b/train_network.py
index c5ec0ebd..3f95c5f7 100644
--- a/train_network.py
+++ b/train_network.py
@@ -262,7 +262,7 @@ def train(args):
             network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)

     # transform DDP after prepare (train_network here only)
-    text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
+    text_encoder, unet, network = train_util.transform_if_model_is_DDP(text_encoder, unet, network)

     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index c13fcf9f..c11a199f 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -281,7 +281,7 @@ def train(args):
     )

     # transform DDP after prepare
-    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 67d48023..5342a695 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -315,7 +315,7 @@ def train(args):
     )

     # transform DDP after prepare
-    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
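
Note (not part of the patch): a standalone sketch of the renamed helper and its new unpacking behaviour, assuming the same `DistributedDataParallel` import used in `library/train_util.py`. The `nn.Linear` stand-ins below are placeholders for the real models, used only to make the example runnable.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def transform_if_model_is_DDP(text_encoder, unet, network=None):
    # Unwrap .module from any DDP-wrapped model; None arguments are filtered out,
    # so the generator yields exactly one item per model actually passed in.
    return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)


# Callers that prepare only text_encoder and unet now unpack two values
# instead of discarding a trailing placeholder with `_`:
text_encoder, unet = nn.Linear(4, 4), nn.Linear(4, 4)  # stand-ins for the real models
text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)

# train_network.py still passes (and unpacks) three models:
# text_encoder, unet, network = transform_if_model_is_DDP(text_encoder, unet, network)
```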