diff --git a/fine_tune.py b/fine_tune.py index 741e9c85..86260754 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -243,24 +243,19 @@ def train(args): text_encoder.to(weight_dtype) if args.deepspeed: - # wrapping model - import deepspeed - if args.offload_optimizer_device is not None: - accelerator.print('[DeepSpeed] start to manually build cpu_adam.') - deepspeed.ops.op_builder.CPUAdamBuilder().load() - accelerator.print('[DeepSpeed] building cpu_adam done.') - class DeepSpeedModel(torch.nn.Module): - def __init__(self, unet, text_encoder) -> None: - super().__init__() - self.unet = unet - self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder) - def get_models(self): - return self.unet, self.text_encoders - ds_model = DeepSpeedModel(unet, text_encoders) + training_models_dict = {} + training_models_dict["unet"] = unet + if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder + + ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - # Now, ds_model is an instance of DeepSpeedEngine. 
- unet, text_encoders = ds_model.get_models() # for compatiblility - text_encoder = text_encoders + + training_models = [] + unet = ds_model.models["unet"] + training_models.append(unet) + if args.train_text_encoder: + text_encoder = ds_model.models["text_encoder"] + training_models.append(text_encoder) else: # acceleratorがなんかよろしくやってくれるらしい if args.train_text_encoder: diff --git a/library/train_util.py b/library/train_util.py index 61c83624..334aaa21 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3959,27 +3959,7 @@ def prepare_accelerator(args: argparse.Namespace): else None, ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) - deepspeed_plugin = None - if args.deepspeed: - deepspeed_plugin = DeepSpeedPlugin( - zero_stage=args.zero_stage, - gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm, - offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, - offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path, - zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model, - ) - deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size - deepspeed_plugin.deepspeed_config['train_batch_size'] = \ - args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE']) - deepspeed_plugin.set_mixed_precision(args.mixed_precision) - if args.mixed_precision.lower() == "fp16": - deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0 - if args.full_fp16 or args.fp16_master_weights_and_gradients: - if args.offload_optimizer_device == "cpu": - deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True - print("[DeepSpeed] full fp16 enable.") - else: - print("full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam.") + 
deepspeed_plugin = prepare_deepspeed_plugin(args) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -3992,6 +3972,62 @@ def prepare_accelerator(args: argparse.Namespace): ) return accelerator +def prepare_deepspeed_plugin(args: argparse.Namespace): + if args.deepspeed is None: return None + try: + import deepspeed + except ImportError as e: + print("deepspeed is not installed. please install deepspeed in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed") + exit(1) + + deepspeed_plugin = DeepSpeedPlugin( + zero_stage=args.zero_stage, + gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm, + offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path, + offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path, + zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model, + ) + deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size + deepspeed_plugin.deepspeed_config['train_batch_size'] = \ + args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE']) + deepspeed_plugin.set_mixed_precision(args.mixed_precision) + if args.mixed_precision.lower() == "fp16": + deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0 # preventing overflow. 
+ if args.full_fp16 or args.fp16_master_weights_and_gradients: + if args.offload_optimizer_device == "cpu" and args.zero_stage == 2: + deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True + print("[DeepSpeed] full fp16 enable.") + else: + print("[DeepSpeed] full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.") + + if args.offload_optimizer_device is not None: + print('[DeepSpeed] start to manually build cpu_adam.') + deepspeed.ops.op_builder.CPUAdamBuilder().load() + print('[DeepSpeed] building cpu_adam done.') + + return deepspeed_plugin + +def prepare_deepspeed_model(args: argparse.Namespace, **models): + class DeepSpeedWrapper(torch.nn.Module): + def __init__(self, **kw_models) -> None: + super().__init__() + self.models = torch.nn.ModuleDict() + + for key, model in kw_models.items(): + if isinstance(model, list): + model = torch.nn.ModuleList(model) + assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}" + self.models.update( + torch.nn.ModuleDict( + {key: model} + ) + ) + + def get_models(self): + return self.models + + ds_model = DeepSpeedWrapper(**models) + return ds_model def prepare_dtype(args: argparse.Namespace): weight_dtype = torch.float32 diff --git a/sdxl_train.py b/sdxl_train.py index 6ffb1bba..2f1a5ce6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -391,28 +391,29 @@ def train(args): text_encoder2.to(weight_dtype) if args.deepspeed: - # Wrapping model for DeepSpeed - import deepspeed - if args.offload_optimizer_device is not None: - accelerator.print('[DeepSpeed] start to manually build cpu_adam.') - deepspeed.ops.op_builder.CPUAdamBuilder().load() - accelerator.print('[DeepSpeed] building cpu_adam done.') - - class DeepSpeedModel(torch.nn.Module): - def __init__(self, unet, text_encoder) -> None: - super().__init__() - self.unet = unet - self.text_encoders = self.text_encoder = 
torch.nn.ModuleList(text_encoder) - - def get_models(self): - return self.unet, self.text_encoders - text_encoders = [text_encoder1, text_encoder2] - ds_model = DeepSpeedModel(unet, text_encoders) + training_models_dict = {} + if train_unet: + training_models_dict["unet"] = unet + if train_text_encoder1: + text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) + text_encoder1.text_model.final_layer_norm.requires_grad_(False) + training_models_dict["text_encoder1"] = text_encoder1 + if train_text_encoder2: + training_models_dict["text_encoder2"] = text_encoder2 + ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - # Now, ds_model is an instance of DeepSpeedEngine. - unet, text_encoders = ds_model.get_models() # for compatiblility - text_encoder1, text_encoder2 = text_encoder = text_encoders - training_models = [unet, text_encoder1, text_encoder2] + + training_models = [] # override training_models + if train_unet: + unet = ds_model.models["unet"] + training_models.append(unet) + if train_text_encoder1: + text_encoder1 = ds_model.models["text_encoder1"] + training_models.append(text_encoder1) + if train_text_encoder2: + text_encoder2 = ds_model.models["text_encoder2"] + training_models.append(text_encoder2) + else: # acceleratorがなんかよろしくやってくれるらしい if train_unet: unet = accelerator.prepare(unet) diff --git a/train_db.py b/train_db.py index c336a1c1..f188d7bd 100644 --- a/train_db.py +++ b/train_db.py @@ -216,25 +216,20 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい if args.deepspeed: - # wrapping model - import deepspeed - if args.offload_optimizer_device is not None: - accelerator.print('[DeepSpeed] start to manually build cpu_adam.') - deepspeed.ops.op_builder.CPUAdamBuilder().load() - accelerator.print('[DeepSpeed] building cpu_adam done.') - class DeepSpeedModel(torch.nn.Module): - def __init__(self, 
unet, text_encoder) -> None: - super().__init__() - self.unet = unet - self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder) - - def get_models(self): - return self.unet, self.text_encoders - ds_model = DeepSpeedModel(unet, text_encoders) + training_models_dict = {} + training_models_dict["unet"] = unet + if train_text_encoder: training_models_dict["text_encoder"] = text_encoder + + ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - # Now, ds_model is an instance of DeepSpeedEngine. - unet, text_encoders = ds_model.get_models() # for compatiblility - text_encoder = text_encoders + + training_models = [] + unet = ds_model.models["unet"] + training_models.append(unet) + if train_text_encoder: + text_encoder = ds_model.models["text_encoder"] + training_models.append(text_encoder) + else: if train_text_encoder: unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( diff --git a/train_network.py b/train_network.py index cc445d39..dfa17eb3 100644 --- a/train_network.py +++ b/train_network.py @@ -410,26 +410,22 @@ class NetworkTrainer: # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: - # wrapping model - import deepspeed - if args.offload_optimizer_device is not None: - accelerator.print('[DeepSpeed] start to manually build cpu_adam.') - deepspeed.ops.op_builder.CPUAdamBuilder().load() - accelerator.print('[DeepSpeed] building cpu_adam done.') - class DeepSpeedModel(torch.nn.Module): - def __init__(self, unet, text_encoder, network) -> None: - super().__init__() - self.unet = unet - self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder) - self.network = network - - def get_models(self): - return self.unet, self.text_encoders, self.network - ds_model = DeepSpeedModel(unet, text_encoders, network) + 
training_models_dict = {} + if train_unet: training_models_dict["unet"] = unet + if train_text_encoder: training_models_dict["text_encoder"] = text_encoders + training_models_dict["network"] = network + + ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler) - # Now, ds_model is an instance of DeepSpeedEngine. - unet, text_encoders, network = ds_model.get_models() # for compatiblility - text_encoder = text_encoders + + if train_unet: unet = ds_model.models["unet"] + if train_text_encoder: + text_encoder = ds_model.models["text_encoder"] + if len(ds_model.models["text_encoder"]) > 1: + text_encoders = text_encoder + else: + text_encoders = [text_encoder] + else: if train_unet: unet = accelerator.prepare(unet)