diff --git a/flux_train_network.py b/flux_train_network.py index b3aebecc..5cd1b9d5 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -2,7 +2,7 @@ import argparse import copy import math import random -from typing import Any, Optional +from typing import Any, Optional, Union import torch from accelerate import Accelerator @@ -36,8 +36,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.is_schnell: Optional[bool] = None self.is_swapping_blocks: bool = False - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) if args.fp8_base_unet: @@ -80,6 +80,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): args.blocks_to_swap = 18 # 18 is safe for most cases train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models diff --git a/library/train_util.py b/library/train_util.py index 4d143c37..56fea4a8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2893,6 +2893,9 @@ class MinimalDataset(BaseDataset): """ raise NotImplementedError + def get_resolutions(self) -> List[Tuple[int, int]]: + return [] + def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset: module = ".".join(args.dataset_class.split(".")[:-1]) @@ -6520,4 +6523,7 @@ class LossRecorder: @property def moving_average(self) -> float: - return self.loss_total / len(self.loss_list) + losses = len(self.loss_list) + if losses == 0: + return 0 + return self.loss_total / losses diff --git 
a/requirements.txt b/requirements.txt index e0091749..de39f588 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ voluptuous==0.13.1 huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 +numpy<=2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/sd3_train_network.py b/sd3_train_network.py index c7417802..dcf497f5 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -2,7 +2,7 @@ import argparse import copy import math import random -from typing import Any, Optional +from typing import Any, Optional, Union import torch from accelerate import Accelerator @@ -26,7 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): super().__init__() self.sample_prompts_te_outputs = None - def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): # super().assert_extra_args(args, train_dataset_group) # sdxl_train_util.verify_sdxl_training_args(args) @@ -56,9 +56,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません" train_dataset_group.verify_bucket_reso_steps(32) # TODO check this + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) # TODO check this # enumerate resolutions from dataset for positional embeddings - self.resolutions = train_dataset_group.get_resolutions() + resolutions = train_dataset_group.get_resolutions() + if val_dataset_group is not None: + resolutions = resolutions + val_dataset_group.get_resolutions() + self.resolutions = resolutions def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 
d45df6e0..eb09831e 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,5 +1,5 @@ import argparse -from typing import List, Optional +from typing import List, Optional, Union import torch from accelerate import Accelerator @@ -23,8 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) sdxl_train_util.verify_sdxl_training_args(args) if args.cache_text_encoder_outputs: @@ -37,6 +37,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index 821a6955..bf56faf3 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -18,11 +18,12 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR self.is_sdxl = True - def assert_extra_args(self, args, train_dataset_group): - super().assert_extra_args(args, train_dataset_group) + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): + super().assert_extra_args(args, train_dataset_group, val_dataset_group) 
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) train_dataset_group.verify_bucket_reso_steps(32) + if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(32) def load_target_model(self, args, weight_dtype, accelerator): ( diff --git a/train_network.py b/train_network.py index e7d93a10..2c3bb2aa 100644 --- a/train_network.py +++ b/train_network.py @@ -3,7 +3,7 @@ import argparse import math import os import typing -from typing import Any, List +from typing import Any, List, Union, Optional import sys import random import time @@ -124,8 +124,10 @@ class NetworkTrainer: return logs - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) @@ -512,7 +514,7 @@ class NetworkTrainer: val_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" - self.assert_extra_args(args, train_dataset_group) # may change some args + self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args # acceleratorを準備する logger.info("preparing accelerator") @@ -1414,7 +1416,9 @@ class NetworkTrainer: args, text_encoding_strategy, tokenize_strategy, - is_train=False + is_train=False, + train_text_encoder=False, + train_unet=False ) current_loss = loss.detach().item() @@ -1474,7 +1478,9 @@ class NetworkTrainer: args, text_encoding_strategy, tokenize_strategy, - is_train=False + is_train=False, + train_text_encoder=False, + train_unet=False ) current_loss = loss.detach().item() diff --git 
a/train_textual_inversion.py b/train_textual_inversion.py index 113f3599..0c6568b0 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -2,7 +2,7 @@ import argparse import math import os from multiprocessing import Value -from typing import Any, List +from typing import Any, List, Optional, Union import toml from tqdm import tqdm @@ -99,9 +99,12 @@ class TextualInversionTrainer: self.vae_scale_factor = 0.18215 self.is_sdxl = False - def assert_extra_args(self, args, train_dataset_group): + def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): train_dataset_group.verify_bucket_reso_steps(64) + if val_dataset_group is not None: + val_dataset_group.verify_bucket_reso_steps(64) + def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet @@ -325,7 +328,7 @@ class TextualInversionTrainer: train_dataset_group = train_util.load_arbitrary_dataset(args) val_dataset_group = None - self.assert_extra_args(args, train_dataset_group) + self.assert_extra_args(args, train_dataset_group, val_dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0)