diff --git a/library/train_util.py b/library/train_util.py
index 81dffb1d..b5e6aa3f 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -754,12 +754,14 @@ class BaseDataset(torch.utils.data.Dataset):
             img = np.array(image, np.uint8)
         return img

-    def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
+    def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size, cond_img=None):
         image_height, image_width = image.shape[0:2]

         if image_width != resized_size[0] or image_height != resized_size[1]:
             # リサイズする
             image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
+            if exists(cond_img):
+                cond_img = cv2.resize(cond_img, resized_size, interpolation=cv2.INTER_AREA)

         image_height, image_width = image.shape[0:2]
         if image_width > reso[0]:
@@ -767,15 +769,26 @@ class BaseDataset(torch.utils.data.Dataset):
             p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
             # print("w", trim_size, p)
             image = image[:, p : p + reso[0]]
+            if exists(cond_img):
+                cond_img = cond_img[:, p : p + reso[0]]
         if image_height > reso[1]:
             trim_size = image_height - reso[1]
             p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
             # print("h", trim_size, p)
             image = image[p : p + reso[1]]
+            if exists(cond_img):
+                cond_img = cond_img[p : p + reso[1]]

         assert (
             image.shape[0] == reso[1] and image.shape[1] == reso[0]
         ), f"internal error, illegal trimmed size: {image.shape}, {reso}"
+
+        if exists(cond_img):
+            assert (
+                cond_img.shape[0] == reso[1] and cond_img.shape[1] == reso[0]
+            ), f"internal error, illegal trimmed size: {cond_img.shape}, {reso}"
+            return image, cond_img
+
         return image

     def is_latent_cacheable(self):
@@ -1617,6 +1630,8 @@ class ControlNetDataset(BaseDataset):
             subset = self.image_to_subset[image_key]
             loss_weights.append(1.0)

+            assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
+
             # image/latentsを処理する
             if image_info.latents is not None:  # cache_latents=Trueの場合
                 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
@@ -1628,10 +1643,11 @@ class ControlNetDataset(BaseDataset):
             else:
                 # 画像を読み込み、必要ならcropする
                 img = self.load_image(image_info.absolute_path)
+                cond_img = self.load_image(image_info.cond_img_path)

                 im_h, im_w = img.shape[0:2]
                 if self.enable_bucket:
-                    img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
+                    img, cond_img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size, cond_img=cond_img)
                 else:
                     im_h, im_w = img.shape[0:2]
                     assert (
@@ -1649,41 +1665,18 @@ class ControlNetDataset(BaseDataset):
             images.append(image)
             latents_list.append(latents)

-            caption = self.process_caption(subset, image_info.caption)
-            if self.XTI_layers:
-                caption_layer = []
-                for layer in self.XTI_layers:
-                    token_strings_from = " ".join(self.token_strings)
-                    token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
-                    caption_ = caption.replace(token_strings_from, token_strings_to)
-                    caption_layer.append(caption_)
-                captions.append(caption_layer)
-            else:
-                captions.append(caption)
-
-            if not self.token_padding_disabled:  # this option might be omitted in future
-                if self.XTI_layers:
-                    token_caption = self.get_input_ids(caption_layer)
-                else:
-                    token_caption = self.get_input_ids(caption)
-                input_ids_list.append(token_caption)
-
-            assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
-
-            cond_img = self.load_image(image_info.cond_img_path)
-            if self.enable_bucket:
-                cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
             cond_img = self.conditioning_image_transforms(cond_img)
             conditioning_images.append(cond_img)

+            caption = self.process_caption(subset, image_info.caption)
+            captions.append(caption)
+            token_caption = self.get_input_ids(caption)
+            input_ids_list.append(token_caption)
+
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)

-        if self.token_padding_disabled:
-            # padding=True means pad in the batch
-            example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
-        else:
-            # batch processing seems to be good
-            example["input_ids"] = torch.stack(input_ids_list)
+        example["input_ids"] = torch.stack(input_ids_list)

         if images[0] is not None:
             images = torch.stack(images)
diff --git a/train_controlnet.py b/train_controlnet.py
index 263e8813..6e4e5bb8 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -141,7 +141,6 @@ def train(args):
         controlnet = ControlNetModel.from_pretrained(filename)

-
     # モデルに xformers とか memory efficient attention を組み込む
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -168,11 +167,11 @@ def train(args):
         controlnet.enable_gradient_checkpointing()

     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")

     trainable_params = controlnet.parameters()

-    optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(
+    _, _, optimizer = train_util.get_optimizer(
         args, trainable_params
     )
@@ -198,10 +197,9 @@ def train(args):
             / accelerator.num_processes
             / args.gradient_accumulation_steps
         )
-        if is_main_process:
-            print(
-                f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
-            )
+        accelerator.print(
+            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+        )

     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -216,7 +214,7 @@ def train(args):
         assert (
             args.mixed_precision == "fp16"
         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-        print("enable full fp16 training.")
+        accelerator.print("enable full fp16 training.")
         controlnet.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
@@ -258,23 +256,21 @@ def train(args):

     # 学習する
     # TODO: find a way to handle total batch size when there are multiple datasets
-
-    if is_main_process:
-        print("running training / 学習開始")
-        print(
-            f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
-        )
-        print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-        print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-        print(f"  num epochs / epoch数: {num_train_epochs}")
-        print(
-            f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
-        )
-        # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-        print(
-            f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
-        )
-        print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(
+        f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
+    )
+    accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(
+        f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+    )
+    # print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(
+        f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
+    )
+    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

     progress_bar = tqdm(
         range(args.max_train_steps),
@@ -303,11 +299,11 @@ def train(args):
     del train_dataset_group

     # function for saving/removing
-    def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False):
+    def save_model(ckpt_name, model, force_sync_upload=False):
         os.makedirs(args.output_dir, exist_ok=True)
         ckpt_file = os.path.join(args.output_dir, ckpt_name)

-        print(f"\nsaving checkpoint: {ckpt_file}")
+        accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
         state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
@@ -332,13 +328,13 @@ def train(args):
     def remove_model(old_ckpt_name):
         old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
         if os.path.exists(old_ckpt_file):
-            print(f"removing old checkpoint: {old_ckpt_file}")
+            accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
             os.remove(old_ckpt_file)

     # training loop
     for epoch in range(num_train_epochs):
         if is_main_process:
-            print(f"\nepoch {epoch+1}/{num_train_epochs}")
+            accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1

         for step, batch in enumerate(train_dataloader):
@@ -470,7 +466,7 @@ def train(args):
                             args, "." + args.save_model_as, global_step
                         )
                         save_model(
-                            ckpt_name, unwrap_model(controlnet), global_step, epoch
+                            ckpt_name, unwrap_model(controlnet),
                         )

                         if args.save_state:
@@ -520,7 +516,7 @@ def train(args):
                 ckpt_name = train_util.get_epoch_ckpt_name(
                     args, "." + args.save_model_as, epoch + 1
                 )
-                save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1)
+                save_model(ckpt_name, unwrap_model(controlnet))

                 remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                 if remove_epoch_no is not None:
@@ -561,7 +557,7 @@ def train(args):
     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
         save_model(
-            ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True
+            ckpt_name, controlnet, force_sync_upload=True
        )
        print("model saved.")
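For reference, a minimal usage sketch of the trim_and_resize_if_required contract changed above. This is illustrative only and not part of the patch; dataset, subset, and info are hypothetical stand-ins for a ControlNetDataset instance, its BaseSubset, and an ImageInfo entry that carries a cond_img_path attribute.

    # Load the training image and its conditioning image, then resize/crop both
    # identically in a single call so they stay pixel-aligned.
    img = dataset.load_image(info.absolute_path)
    cond_img = dataset.load_image(info.cond_img_path)

    # When a conditioning image is passed, the helper now returns a tuple.
    img, cond_img = dataset.trim_and_resize_if_required(
        subset, img, info.bucket_reso, info.resized_size, cond_img=cond_img
    )

    # Without cond_img, the original single-image behavior is unchanged.
    img = dataset.trim_and_resize_if_required(subset, img, info.bucket_reso, info.resized_size)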