From faadc350a42b0b0efb39e0fb052b3d5ecd2dda4d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 8 Apr 2024 23:28:38 +0900 Subject: [PATCH] use skip_first_batches to skip, close pillow image --- library/train_util.py | 21 ++++++++++++++------- train_network.py | 22 ++++++++++++---------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b9050297..fac0db56 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -649,8 +649,15 @@ class BaseDataset(torch.utils.data.Dataset): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step @@ -941,7 +948,7 @@ class BaseDataset(torch.utils.data.Dataset): self._length = len(self.buckets_indices) def shuffle_buckets(self): - # set random seed for this epoch + # set random seed for this epoch: current_epoch is not incremented random.seed(self.seed + self.current_epoch) random.shuffle(self.buckets_indices) @@ -2346,10 +2353,10 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: def load_image(image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) + with Image.open(image_path) as image: + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) return img diff --git a/train_network.py b/train_network.py index 7528932f..7bde7421 100644 --- a/train_network.py +++ b/train_network.py @@ -793,7 +793,6 @@ class NetworkTrainer: f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" ) logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") - initial_step *= accelerator.num_processes * args.gradient_accumulation_steps else: # if not, only epoch no is skipped for informative purpose epoch_to_start = initial_step // math.ceil( @@ -865,23 +864,26 @@ class NetworkTrainer: accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 - if initial_step > len(train_dataloader): + steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + if initial_step > steps_per_epoch: logger.info(f"skipping epoch {epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) + initial_step -= steps_per_epoch continue metadata["ss_epoch"] = str(epoch + 1) accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - for step, batch in enumerate(train_dataloader): - current_step.value = global_step + active_dataloader = train_dataloader + if initial_step > 0: + logger.info(f"skipping {initial_step} batches in epoch {epoch+1}") + active_dataloader = accelerator.skip_first_batches( + train_dataloader, initial_step * args.gradient_accumulation_steps + ) + initial_step = 0 - if initial_step > 0: - # logger.info(f"skipping step {step+1} because initial_step (multiplied) is {initial_step}") - loss_recorder.add(epoch=epoch, step=step, loss=0) # add dummy loss - initial_step -= 1 - continue + for step, batch in enumerate(active_dataloader): + current_step.value = global_step with accelerator.accumulate(training_model): on_step_start(text_encoder, unet)