Compare commits

...

3 Commits

Author   SHA1        Message                                               Date
Kohya S  faadc350a4  use skip_first_batches to skip, close pillow image    2024-04-08 23:28:38 +09:00
Kohya S  6d9338f8b5  Merge branch 'dev' into resume-step-and-epoch         2024-04-08 12:50:18 +09:00
Kohya S  5f0eebaa56  experimental impl to restore step/epoch in resuming   2024-04-07 18:18:03 +09:00
2 changed files with 119 additions and 11 deletions

library/train_util.py

@@ -649,8 +649,15 @@ class BaseDataset(torch.utils.data.Dataset):
     def set_current_epoch(self, epoch):
         if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
-            self.shuffle_buckets()
-        self.current_epoch = epoch
+            if epoch > self.current_epoch:
+                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                num_epochs = epoch - self.current_epoch
+                for _ in range(num_epochs):
+                    self.current_epoch += 1
+                    self.shuffle_buckets()
+            else:
+                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
+                self.current_epoch = epoch

     def set_current_step(self, step):
         self.current_step = step
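Why replay the shuffle once per skipped epoch instead of jumping straight to the target epoch: shuffle_buckets() reseeds from self.seed + self.current_epoch and permutes the current order in place, so the bucket order at epoch N is the composition of all earlier shuffles. A self-contained sketch of the idea, with a toy index list standing in for buckets_indices:

    import random

    def replayed_order(seed, target_epoch, n=8):
        # what set_current_epoch now does: one seeded shuffle per skipped epoch
        order = list(range(n))
        for epoch in range(1, target_epoch + 1):
            random.seed(seed + epoch)  # mirrors shuffle_buckets()
            random.shuffle(order)      # permutes the *current* order in place
        return order

    def jumped_order(seed, target_epoch, n=8):
        # what a naive resume would do: a single shuffle at the target epoch
        order = list(range(n))
        random.seed(seed + target_epoch)
        random.shuffle(order)
        return order

    print(replayed_order(42, 3))  # matches an uninterrupted run's bucket order
    print(jumped_order(42, 3))    # generally differs from it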
@@ -941,7 +948,7 @@ class BaseDataset(torch.utils.data.Dataset):
         self._length = len(self.buckets_indices)

     def shuffle_buckets(self):
-        # set random seed for this epoch
+        # set random seed for this epoch: current_epoch is not incremented
         random.seed(self.seed + self.current_epoch)
         random.shuffle(self.buckets_indices)
@@ -2346,10 +2353,10 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:

 def load_image(image_path):
-    image = Image.open(image_path)
-    if not image.mode == "RGB":
-        image = image.convert("RGB")
-    img = np.array(image, np.uint8)
+    with Image.open(image_path) as image:
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+        img = np.array(image, np.uint8)
     return img
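For context on the "close pillow image" part of the commit: Image.open is lazy and keeps the underlying file handle open until the image object is closed or garbage-collected, so the old pattern could accumulate open descriptors in a long-running dataloader. A minimal illustration (hypothetical path):

    from PIL import Image
    import numpy as np

    # old pattern: the file handle stays open until garbage collection
    image = Image.open("sample.png")   # hypothetical path
    img = np.array(image, np.uint8)    # handle may still be open here

    # new pattern: the handle is closed as soon as the block exits
    with Image.open("sample.png") as image:
        img = np.array(image, np.uint8)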
@@ -5387,7 +5394,7 @@ class LossRecorder:
         self.loss_total: float = 0.0

     def add(self, *, epoch: int, step: int, loss: float) -> None:
-        if epoch == 0:
+        if epoch == 0 or step >= len(self.loss_list):
             self.loss_list.append(loss)
         else:
             self.loss_total -= self.loss_list[step]
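The LossRecorder change matters when training resumes mid-run: on the first epoch after a resume, epoch is no longer 0 but loss_list is still empty (or shorter than step), so indexing loss_list[step] would raise IndexError. A stripped-down sketch of the guard (the real class also maintains loss_total):

    loss_list = []

    def add(epoch, step, loss):
        if epoch == 0 or step >= len(loss_list):
            loss_list.append(loss)   # first epoch, or resumed past the recorded end
        else:
            loss_list[step] = loss   # overwrite the slot recorded last epoch

    add(epoch=2, step=0, loss=0.12)  # resumed run: the old check would index an empty list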

train_network.py

@@ -483,6 +483,15 @@ class NetworkTrainer:
                     weights.pop(i)
                 # print(f"save model hook: {len(weights)} weights will be saved")

+            # save current epoch and step
+            train_state_file = os.path.join(output_dir, "train_state.json")
+            # +1 is needed because the state is saved before current_step is set from global_step
+            logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
+            with open(train_state_file, "w", encoding="utf-8") as f:
+                json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+        steps_from_state = None
+
         def load_model_hook(models, input_dir):
             # remove models except network
             remove_indices = []
@@ -493,6 +502,15 @@ class NetworkTrainer:
                 models.pop(i)
             # print(f"load model hook: {len(models)} models will be loaded")

+            # load current epoch and step to restore the training state
+            nonlocal steps_from_state
+            train_state_file = os.path.join(input_dir, "train_state.json")
+            if os.path.exists(train_state_file):
+                with open(train_state_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                steps_from_state = data["current_step"]  # no +1 here: it was already added when the state was saved
+                logger.info(f"load train state from {train_state_file}: {data}")
+
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
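These hooks use Accelerate's save/load-state pre-hook API: functions registered this way run inside accelerator.save_state / accelerator.load_state, so train_state.json is written into, and read back from, the same checkpoint directory as the model and optimizer state. A minimal standalone sketch (illustrative payload, not the trainer's real bookkeeping):

    import json, os
    from accelerate import Accelerator

    accelerator = Accelerator()

    def save_hook(models, weights, output_dir):
        # runs inside accelerator.save_state(), before the checkpoint is written
        with open(os.path.join(output_dir, "train_state.json"), "w") as f:
            json.dump({"current_step": 123}, f)  # illustrative payload

    def load_hook(models, input_dir):
        # runs inside accelerator.load_state(), before weights are restored
        path = os.path.join(input_dir, "train_state.json")
        if os.path.exists(path):
            with open(path, "r") as f:
                print("restored:", json.load(f))

    accelerator.register_save_state_pre_hook(save_hook)
    accelerator.register_load_state_pre_hook(load_hook)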
@@ -736,7 +754,52 @@ class NetworkTrainer:
             if key in metadata:
                 minimum_metadata[key] = metadata[key]

-        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+        # calculate steps to skip when resuming or starting from a specific step
+        initial_step = 0
+        if args.initial_epoch is not None or args.initial_step is not None:
+            # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
+            if steps_from_state is not None:
+                logger.warning(
+                    "steps from the state are ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
+                )
+            if args.initial_step is not None:
+                initial_step = args.initial_step
+            else:
+                # the number of steps per epoch is calculated from num_processes and gradient_accumulation_steps
+                initial_step = (args.initial_epoch - 1) * math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+        else:
+            # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
+            if steps_from_state is not None:
+                initial_step = steps_from_state
+            steps_from_state = None
+
+        if initial_step > 0:
+            assert (
+                args.max_train_steps > initial_step
+            ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
+
+        progress_bar = tqdm(
+            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+        )
+
+        epoch_to_start = 0
+        if initial_step > 0:
+            if args.skip_until_initial_step:
+                # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
+                if not args.resume:
+                    logger.info(
+                        f"initial_step is specified but not resuming. lr scheduler will start from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
+                    )
+                logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
+            else:
+                # if not, only the epoch number is skipped, for informational purposes
+                epoch_to_start = initial_step // math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+                initial_step = 0  # do not skip
+
         global_step = 0

         noise_scheduler = DDPMScheduler(
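To make the epoch-to-steps conversion above concrete with hypothetical numbers: with a dataloader of 1,000 batches, 2 processes, and gradient_accumulation_steps=4, one epoch is ceil(1000 / 2 / 4) = 125 optimizer steps, so --initial_epoch 3 resolves to initial_step = (3 - 1) * 125 = 250, and the progress bar is then built over max_train_steps - 250 steps.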
@@ -793,16 +856,35 @@ class NetworkTrainer:
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

         # training loop
-        for epoch in range(num_train_epochs):
+        if initial_step > 0:
+            # set the starting global_step calculated from initial_step, because skipping steps doesn't increment global_step
+            global_step = initial_step // (accelerator.num_processes * args.gradient_accumulation_steps)
+
+        for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1

+            steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
+            if initial_step > steps_per_epoch:
+                logger.info(f"skipping epoch {epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= steps_per_epoch
+                continue
+
             metadata["ss_epoch"] = str(epoch + 1)

             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)

-            for step, batch in enumerate(train_dataloader):
+            active_dataloader = train_dataloader
+            if initial_step > 0:
+                logger.info(f"skipping {initial_step} batches in epoch {epoch+1}")
+                active_dataloader = accelerator.skip_first_batches(
+                    train_dataloader, initial_step * args.gradient_accumulation_steps
+                )
+                initial_step = 0
+
+            for step, batch in enumerate(active_dataloader):
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
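The skip itself relies on Accelerate's skip_first_batches, which wraps a prepared dataloader so its first N batches are consumed without being yielded. Note the count is in batches, which is why the code above multiplies initial_step (optimizer steps) by gradient_accumulation_steps. A minimal sketch:

    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    accelerator = Accelerator()
    loader = accelerator.prepare(DataLoader(list(range(10)), batch_size=1))

    # resume mid-epoch: drop the first 4 batches, then iterate as usual
    for batch in accelerator.skip_first_batches(loader, num_batches=4):
        print(batch)  # tensor([4]) through tensor([9])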
@@ -1101,6 +1183,25 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--skip_until_initial_step",
+        action="store_true",
+        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
+    )
+    parser.add_argument(
+        "--initial_epoch",
+        type=int,
+        default=None,
+        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect the lr scheduler, which means the lr scheduler will start from 0 without `--resume`."
+        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意: initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
+    )
+    parser.add_argument(
+        "--initial_step",
+        type=int,
+        default=None,
+        help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+        + " / 初期ステップ数(全エポックを含むステップ数)、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
+    )
     return parser
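Taken together, the intended usage looks something like this (hypothetical paths; flag spellings from the diff above): `accelerate launch train_network.py --resume output/last-state --skip_until_initial_step ...` replays the data pipeline up to the step stored in train_state.json before real training resumes, while `--initial_step 250` without `--resume` skips the same amount of data but, as the help text warns, starts the lr scheduler from zero.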