mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Use accelerator.skip_first_batches to skip the initial steps; close the Pillow image after loading
This commit is contained in:
@@ -649,8 +649,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
|
||||
self.shuffle_buckets()
|
||||
self.current_epoch = epoch
|
||||
if epoch > self.current_epoch:
|
||||
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
num_epochs = epoch - self.current_epoch
|
||||
for _ in range(num_epochs):
|
||||
self.current_epoch += 1
|
||||
self.shuffle_buckets()
|
||||
else:
|
||||
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_current_step(self, step):
|
||||
self.current_step = step
|
||||
@@ -941,7 +948,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
def shuffle_buckets(self):
|
||||
# set random seed for this epoch
|
||||
# set random seed for this epoch: current_epoch is not incremented
|
||||
random.seed(self.seed + self.current_epoch)
|
||||
|
||||
random.shuffle(self.buckets_indices)
|
||||
@@ -2346,10 +2353,10 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
image = Image.open(image_path)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
with Image.open(image_path) as image:
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
|
||||
@@ -793,7 +793,6 @@ class NetworkTrainer:
|
||||
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
|
||||
)
|
||||
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
||||
initial_step *= accelerator.num_processes * args.gradient_accumulation_steps
|
||||
else:
|
||||
# if not, only epoch no is skipped for informative purpose
|
||||
epoch_to_start = initial_step // math.ceil(
|
||||
@@ -865,23 +864,26 @@ class NetworkTrainer:
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
if initial_step > len(train_dataloader):
|
||||
steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
if initial_step > steps_per_epoch:
|
||||
logger.info(f"skipping epoch {epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= len(train_dataloader)
|
||||
initial_step -= steps_per_epoch
|
||||
continue
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
active_dataloader = train_dataloader
|
||||
if initial_step > 0:
|
||||
logger.info(f"skipping {initial_step} batches in epoch {epoch+1}")
|
||||
active_dataloader = accelerator.skip_first_batches(
|
||||
train_dataloader, initial_step * args.gradient_accumulation_steps
|
||||
)
|
||||
initial_step = 0
|
||||
|
||||
if initial_step > 0:
|
||||
# logger.info(f"skipping step {step+1} because initial_step (multiplied) is {initial_step}")
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=0) # add dummy loss
|
||||
initial_step -= 1
|
||||
continue
|
||||
for step, batch in enumerate(active_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
|
||||
Reference in New Issue
Block a user