Fix LossRecorder's moving average when zero losses have been recorded. Fix validation for cached runs. Run assert_extra_args checks on the validation dataset as well.

This commit is contained in:
rockerBOO
2025-01-23 09:57:24 -05:00
parent b489082495
commit c04e5dfe92
8 changed files with 46 additions and 20 deletions

View File

@@ -2,7 +2,7 @@ import argparse
import copy
import math
import random
from typing import Any, Optional
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
@@ -36,8 +36,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
@@ -80,6 +80,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
args.blocks_to_swap = 18 # 18 is safe for most cases
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models

View File

@@ -2893,6 +2893,9 @@ class MinimalDataset(BaseDataset):
"""
raise NotImplementedError
def get_resolutions(self) -> List[Tuple[int, int]]:
    """Return the bucket resolutions used by this dataset.

    A minimal dataset declares no resolutions of its own, so an
    empty list is returned.
    """
    resolutions: List[Tuple[int, int]] = []
    return resolutions
def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
module = ".".join(args.dataset_class.split(".")[:-1])
@@ -6520,4 +6523,7 @@ class LossRecorder:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
losses = len(self.loss_list)
if losses == 0:
return 0
return self.loss_total / losses

View File

@@ -20,6 +20,7 @@ voluptuous==0.13.1
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
numpy<=2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12

View File

@@ -2,7 +2,7 @@ import argparse
import copy
import math
import random
from typing import Any, Optional
from typing import Any, Optional, Union
import torch
from accelerate import Accelerator
@@ -26,7 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
@@ -56,9 +56,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
self.resolutions = train_dataset_group.get_resolutions()
resolutions = train_dataset_group.get_resolutions()
if val_dataset_group is not None:
resolutions = resolutions + val_dataset_group.get_resolutions()
self.resolutions = resolutions
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models

View File

@@ -1,5 +1,5 @@
import argparse
from typing import List, Optional
from typing import List, Optional, Union
import torch
from accelerate import Accelerator
@@ -23,8 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
@@ -37,6 +37,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(

View File

@@ -18,11 +18,12 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
    """Validate SDXL textual-inversion args and dataset bucket settings.

    val_dataset_group may be None — the TextualInversionTrainer caller
    passes None when no validation split is configured (e.g. arbitrary
    cached datasets) — so guard the validation-side check. The original
    added line called verify_bucket_reso_steps unconditionally, which
    would raise AttributeError on None; every other trainer in this
    change guards it the same way.
    """
    super().assert_extra_args(args, train_dataset_group, val_dataset_group)
    sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)

    train_dataset_group.verify_bucket_reso_steps(32)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(

View File

@@ -3,7 +3,7 @@ import argparse
import math
import os
import typing
from typing import Any, List
from typing import Any, List, Union, Optional
import sys
import random
import time
@@ -124,8 +124,10 @@ class NetworkTrainer:
return logs
def assert_extra_args(self, args, train_dataset_group):
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
    """Base validation hook for trainer-specific argument/dataset checks.

    Verifies that the training dataset's bucket resolutions are multiples
    of 64, and does the same for the validation dataset when one is
    configured. val_dataset_group may be None (no validation split).
    Subclasses override this to add model-specific checks and call super().
    """
    train_dataset_group.verify_bucket_reso_steps(64)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -512,7 +514,7 @@ class NetworkTrainer:
val_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group) # may change some args
self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args
# acceleratorを準備する
logger.info("preparing accelerator")
@@ -1414,7 +1416,9 @@ class NetworkTrainer:
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()
@@ -1474,7 +1478,9 @@ class NetworkTrainer:
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False
is_train=False,
train_text_encoder=False,
train_unet=False
)
current_loss = loss.detach().item()

View File

@@ -2,7 +2,7 @@ import argparse
import math
import os
from multiprocessing import Value
from typing import Any, List
from typing import Any, List, Optional, Union
import toml
from tqdm import tqdm
@@ -99,9 +99,12 @@ class TextualInversionTrainer:
self.vae_scale_factor = 0.18215
self.is_sdxl = False
def assert_extra_args(self, args, train_dataset_group):
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
    """Validate dataset bucket configuration for textual-inversion training.

    Checks bucket resolutions are multiples of 64 for the training dataset
    and, when present, the validation dataset. val_dataset_group may be
    None — the caller passes None for arbitrary/cached datasets.
    """
    train_dataset_group.verify_bucket_reso_steps(64)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
@@ -325,7 +328,7 @@ class TextualInversionTrainer:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
self.assert_extra_args(args, train_dataset_group)
self.assert_extra_args(args, train_dataset_group, val_dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)