mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
@@ -14,6 +14,13 @@ The command to install PyTorch is as follows:
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Jan 25, 2025:
|
||||
|
||||
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
|
||||
- For details on how to set it up, please refer to the PR. The documentation will be updated as needed.
|
||||
- It will be added to other scripts as well.
|
||||
- As a current limitation, validation loss is not supported when `--block_to_swap` is specified.
|
||||
|
||||
Dec 15, 2024:
|
||||
|
||||
- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
|
||||
|
||||
@@ -91,9 +91,10 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -138,9 +138,10 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -126,9 +126,10 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
@@ -36,8 +36,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.fp8_base_unet:
|
||||
@@ -80,6 +80,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.blocks_to_swap = 18 # 18 is safe for most cases
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
@@ -339,6 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -375,7 +378,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
# normal forward
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
@@ -420,7 +423,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet):
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
"""
|
||||
|
||||
return model_pred
|
||||
|
||||
@@ -73,6 +73,8 @@ class BaseSubsetParams:
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -102,6 +104,8 @@ class BaseDatasetParams:
|
||||
resolution: Optional[Tuple[int, int]] = None
|
||||
network_multiplier: float = 1.0
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -113,8 +117,7 @@ class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class FineTuningDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
@@ -234,6 +237,8 @@ class ConfigSanitizer:
|
||||
"enable_bucket": bool,
|
||||
"max_bucket_reso": int,
|
||||
"min_bucket_reso": int,
|
||||
"validation_seed": int,
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
}
|
||||
@@ -462,119 +467,136 @@ class BlueprintGenerator:
|
||||
|
||||
return default_value
|
||||
|
||||
|
||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
|
||||
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
|
||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||
extra_dataset_params = {}
|
||||
|
||||
if dataset_blueprint.is_controlnet:
|
||||
subset_klass = ControlNetSubset
|
||||
dataset_klass = ControlNetDataset
|
||||
elif dataset_blueprint.is_dreambooth:
|
||||
subset_klass = DreamBoothSubset
|
||||
dataset_klass = DreamBoothDataset
|
||||
# DreamBooth datasets support splitting training and validation datasets
|
||||
extra_dataset_params = {"is_training_dataset": True}
|
||||
else:
|
||||
subset_klass = FineTuningSubset
|
||||
dataset_klass = FineTuningDataset
|
||||
|
||||
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
|
||||
datasets.append(dataset)
|
||||
|
||||
# print info
|
||||
info = ""
|
||||
for i, dataset in enumerate(datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(
|
||||
f"""\
|
||||
[Dataset {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
network_multiplier: {dataset.network_multiplier}
|
||||
"""
|
||||
)
|
||||
val_datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||
if dataset_blueprint.params.validation_split < 0.0 or dataset_blueprint.params.validation_split > 1.0:
|
||||
logging.warning(f"Dataset param `validation_split` ({dataset_blueprint.params.validation_split}) is not a valid number between 0.0 and 1.0, skipping validation split...")
|
||||
continue
|
||||
|
||||
if dataset.enable_bucket:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
min_bucket_reso: {dataset.min_bucket_reso}
|
||||
max_bucket_reso: {dataset.max_bucket_reso}
|
||||
bucket_reso_steps: {dataset.bucket_reso_steps}
|
||||
bucket_no_upscale: {dataset.bucket_no_upscale}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
# if the dataset isn't setting a validation split, there is no current validation dataset
|
||||
if dataset_blueprint.params.validation_split == 0.0:
|
||||
continue
|
||||
|
||||
extra_dataset_params = {}
|
||||
if dataset_blueprint.is_controlnet:
|
||||
subset_klass = ControlNetSubset
|
||||
dataset_klass = ControlNetDataset
|
||||
elif dataset_blueprint.is_dreambooth:
|
||||
subset_klass = DreamBoothSubset
|
||||
dataset_klass = DreamBoothDataset
|
||||
# DreamBooth datasets support splitting training and validation datasets
|
||||
extra_dataset_params = {"is_training_dataset": False}
|
||||
else:
|
||||
info += "\n"
|
||||
subset_klass = FineTuningSubset
|
||||
dataset_klass = FineTuningDataset
|
||||
|
||||
for j, subset in enumerate(dataset.subsets):
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
[Subset {j} of Dataset {i}]
|
||||
image_dir: "{subset.image_dir}"
|
||||
image_count: {subset.img_count}
|
||||
num_repeats: {subset.num_repeats}
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||
caption_separator: {subset.caption_separator}
|
||||
secondary_separator: {subset.secondary_separator}
|
||||
enable_wildcard: {subset.enable_wildcard}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
caption_prefix: {subset.caption_prefix}
|
||||
caption_suffix: {subset.caption_suffix}
|
||||
color_aug: {subset.color_aug}
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min}
|
||||
token_warmup_step: {subset.token_warmup_step}
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
||||
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params), **extra_dataset_params)
|
||||
val_datasets.append(dataset)
|
||||
|
||||
if is_dreambooth:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
is_reg: {subset.is_reg}
|
||||
class_tokens: {subset.class_tokens}
|
||||
caption_extension: {subset.caption_extension}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
elif not is_controlnet:
|
||||
info += indent(
|
||||
dedent(
|
||||
f"""\
|
||||
metadata_file: {subset.metadata_file}
|
||||
\n"""
|
||||
),
|
||||
" ",
|
||||
)
|
||||
def print_info(_datasets, dataset_type: str):
|
||||
info = ""
|
||||
for i, dataset in enumerate(_datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(f"""\
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
""")
|
||||
|
||||
logger.info(f"{info}")
|
||||
if dataset.enable_bucket:
|
||||
info += indent(dedent(f"""\
|
||||
min_bucket_reso: {dataset.min_bucket_reso}
|
||||
max_bucket_reso: {dataset.max_bucket_reso}
|
||||
bucket_reso_steps: {dataset.bucket_reso_steps}
|
||||
bucket_no_upscale: {dataset.bucket_no_upscale}
|
||||
\n"""), " ")
|
||||
else:
|
||||
info += "\n"
|
||||
|
||||
for j, subset in enumerate(dataset.subsets):
|
||||
info += indent(dedent(f"""\
|
||||
[Subset {j} of {dataset_type} {i}]
|
||||
image_dir: "{subset.image_dir}"
|
||||
image_count: {subset.img_count}
|
||||
num_repeats: {subset.num_repeats}
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
caption_prefix: {subset.caption_prefix}
|
||||
caption_suffix: {subset.caption_suffix}
|
||||
color_aug: {subset.color_aug}
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""), " ")
|
||||
|
||||
if is_dreambooth:
|
||||
info += indent(dedent(f"""\
|
||||
is_reg: {subset.is_reg}
|
||||
class_tokens: {subset.class_tokens}
|
||||
caption_extension: {subset.caption_extension}
|
||||
\n"""), " ")
|
||||
elif not is_controlnet:
|
||||
info += indent(dedent(f"""\
|
||||
metadata_file: {subset.metadata_file}
|
||||
\n"""), " ")
|
||||
|
||||
logger.info(info)
|
||||
|
||||
print_info(datasets, "Dataset")
|
||||
|
||||
if len(val_datasets) > 0:
|
||||
print_info(val_datasets, "Validation Dataset")
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
logger.info(f"[Prepare dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
return DatasetGroup(datasets)
|
||||
for i, dataset in enumerate(val_datasets):
|
||||
logger.info(f"[Prepare validation dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
return (
|
||||
DatasetGroup(datasets),
|
||||
DatasetGroup(val_datasets) if val_datasets else None
|
||||
)
|
||||
|
||||
|
||||
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
import torch
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
|
||||
@@ -63,7 +65,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
||||
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
@@ -74,13 +76,13 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
|
||||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
|
||||
def get_snr_scale(timesteps, noise_scheduler):
|
||||
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
@@ -89,14 +91,14 @@ def get_snr_scale(timesteps, noise_scheduler):
|
||||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
||||
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
||||
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
if v_prediction:
|
||||
@@ -453,7 +455,7 @@ def get_weighted_text_embeddings(
|
||||
|
||||
|
||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4) -> torch.FloatTensor:
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||
for i in range(iterations):
|
||||
@@ -466,7 +468,7 @@ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
|
||||
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale) -> torch.FloatTensor:
|
||||
if noise_offset is None:
|
||||
return noise
|
||||
if adaptive_noise_scale is not None:
|
||||
@@ -482,7 +484,7 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
return noise
|
||||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
if "conditioning_images" in batch:
|
||||
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
||||
@@ -40,7 +40,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
|
||||
@@ -12,6 +12,7 @@ import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import typing
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -145,6 +146,37 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
|
||||
Shuffle the dataset based on the validation_seed or the current random seed.
|
||||
For example if the split of 0.2 of 100 images.
|
||||
[0:80] = 80 training images
|
||||
[80:] = 20 validation images
|
||||
"""
|
||||
if validation_seed is not None:
|
||||
logging.info(f"Using validation seed: {validation_seed}")
|
||||
prevstate = random.getstate()
|
||||
random.seed(validation_seed)
|
||||
random.shuffle(paths)
|
||||
random.setstate(prevstate)
|
||||
else:
|
||||
random.shuffle(paths)
|
||||
|
||||
# Split the dataset between training and validation
|
||||
if is_training_dataset:
|
||||
# Training dataset we split to the first part
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
else:
|
||||
# Validation dataset we split to the second part
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
@@ -397,6 +429,8 @@ class BaseSubset:
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -424,6 +458,9 @@ class BaseSubset:
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
@@ -453,6 +490,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -478,6 +517,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -518,6 +559,8 @@ class FineTuningSubset(BaseSubset):
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
@@ -543,6 +586,8 @@ class FineTuningSubset(BaseSubset):
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -579,6 +624,8 @@ class ControlNetSubset(BaseSubset):
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -604,6 +651,8 @@ class ControlNetSubset(BaseSubset):
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
@@ -1786,9 +1835,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# if is_training_dataset is True -> training dataset
|
||||
# if is_training_dataset is False -> validation dataset
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
is_training_dataset: bool,
|
||||
batch_size: int,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
@@ -1799,6 +1852,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
prior_loss_weight: float,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
@@ -1808,6 +1863,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
self.latents_cache = None
|
||||
self.is_training_dataset = is_training_dataset
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
@@ -1915,6 +1973,30 @@ class DreamBoothDataset(BaseDataset):
|
||||
size_set_count += 1
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
#
|
||||
# We split the dataset for the subset based on if we are doing a validation split
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# if self.is_training_dataset is True -> training dataset
|
||||
# if self.is_training_dataset is False -> validation dataset
|
||||
if self.validation_split > 0.0:
|
||||
# For regularization images we do not want to split this dataset.
|
||||
if subset.is_reg is True:
|
||||
# Skip any validation dataset for regularization images
|
||||
if self.is_training_dataset is False:
|
||||
img_paths = []
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# required for training images dataset of regularization images
|
||||
else:
|
||||
img_paths = split_train_val(
|
||||
img_paths,
|
||||
self.is_training_dataset,
|
||||
self.validation_split,
|
||||
self.validation_seed
|
||||
)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
@@ -1973,9 +2055,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_reg_images = 0
|
||||
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
num_repeats = subset.num_repeats if self.is_training_dataset else 1
|
||||
if num_repeats < 1:
|
||||
logger.warning(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {num_repeats}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1993,12 +2076,12 @@ class DreamBoothDataset(BaseDataset):
|
||||
continue
|
||||
|
||||
if subset.is_reg:
|
||||
num_reg_images += subset.num_repeats * len(img_paths)
|
||||
num_reg_images += num_repeats * len(img_paths)
|
||||
else:
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
@@ -2009,10 +2092,12 @@ class DreamBoothDataset(BaseDataset):
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
logger.info(f"{num_train_images} train images with repeating.")
|
||||
images_split_name = "train" if self.is_training_dataset else "validation"
|
||||
logger.info(f"{num_train_images} {images_split_name} images with repeats.")
|
||||
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
logger.info(f"{num_reg_images} reg images.")
|
||||
logger.info(f"{num_reg_images} reg images with repeats.")
|
||||
if num_train_images < num_reg_images:
|
||||
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
@@ -2050,6 +2135,8 @@ class FineTuningDataset(BaseDataset):
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
@@ -2275,7 +2362,9 @@ class ControlNetDataset(BaseDataset):
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: float,
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
@@ -2314,6 +2403,7 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||
db_subsets,
|
||||
True,
|
||||
batch_size,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
@@ -2324,13 +2414,17 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale,
|
||||
1.0,
|
||||
debug_dataset,
|
||||
validation_split,
|
||||
validation_seed,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
self.image_data = self.dreambooth_dataset_delegate.image_data
|
||||
self.batch_size = batch_size
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
@@ -2800,6 +2894,9 @@ class MinimalDataset(BaseDataset):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_resolutions(self) -> List[Tuple[int, int]]:
|
||||
return []
|
||||
|
||||
|
||||
def load_arbitrary_dataset(args, tokenizer=None) -> MinimalDataset:
|
||||
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||
@@ -4544,7 +4641,6 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
config_args = argparse.Namespace(**ignore_nesting_dict)
|
||||
args = parser.parse_args(namespace=config_args)
|
||||
args.config_file = os.path.splitext(args.config_file)[0]
|
||||
logger.info(args.config_file)
|
||||
|
||||
return args
|
||||
|
||||
@@ -4887,7 +4983,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
|
||||
|
||||
if optimizer_type == "RAdamScheduleFree".lower():
|
||||
optimizer_class = sf.RAdamScheduleFree
|
||||
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")
|
||||
@@ -5838,13 +5934,13 @@ def save_sd_model_on_train_end_common(
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
def get_timesteps(min_timestep, max_timestep, b_size, device):
|
||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
timesteps = timesteps.long().to(device)
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
@@ -5905,11 +6001,16 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
|
||||
def conditional_loss(
|
||||
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
|
||||
):
|
||||
"""
|
||||
NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
|
||||
"""
|
||||
if loss_type == "l2":
|
||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "l1":
|
||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "huber":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
@@ -5917,6 +6018,8 @@ def conditional_loss(
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
elif loss_type == "smooth_l1":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
@@ -6329,6 +6432,30 @@ def sample_image_inference(
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
"""
|
||||
Initialize experiment trackers with tracker specific behaviors
|
||||
"""
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
default_tracker_name if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
import wandb
|
||||
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
||||
|
||||
# Define specific metrics to handle validation and epochs "steps"
|
||||
wandb_tracker.define_metric("epoch", hidden=True)
|
||||
wandb_tracker.define_metric("val_step", hidden=True)
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -6397,4 +6524,7 @@ class LossRecorder:
|
||||
|
||||
@property
|
||||
def moving_average(self) -> float:
|
||||
return self.loss_total / len(self.loss_list)
|
||||
losses = len(self.loss_list)
|
||||
if losses == 0:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
||||
@@ -20,6 +20,7 @@ voluptuous==0.13.1
|
||||
huggingface-hub==0.24.5
|
||||
# for Image utils
|
||||
imagesize==1.4.1
|
||||
numpy<=2.0
|
||||
# for BLIP captioning
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
|
||||
@@ -149,9 +149,10 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import copy
|
||||
import math
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
@@ -26,7 +26,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
@@ -56,9 +56,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
# enumerate resolutions from dataset for positional embeddings
|
||||
self.resolutions = train_dataset_group.get_resolutions()
|
||||
resolutions = train_dataset_group.get_resolutions()
|
||||
if val_dataset_group is not None:
|
||||
resolutions = resolutions + val_dataset_group.get_resolutions()
|
||||
self.resolutions = resolutions
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
@@ -312,6 +317,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -339,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
t5_attn_mask = None
|
||||
|
||||
# call model
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
# TODO support attention mask
|
||||
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||
|
||||
|
||||
@@ -176,9 +176,10 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -114,7 +114,7 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -123,7 +123,7 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -103,7 +103,7 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import argparse
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
@@ -23,8 +23,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -37,6 +37,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import regex
|
||||
|
||||
@@ -18,11 +19,13 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
(
|
||||
|
||||
@@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
|
||||
@@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
|
||||
@@ -100,7 +100,7 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -89,9 +89,10 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
473
train_network.py
473
train_network.py
@@ -2,17 +2,19 @@ import importlib
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import typing
|
||||
from typing import Any, List, Union, Optional
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from torch.types import Number
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -20,6 +22,7 @@ init_ipex()
|
||||
from accelerate.utils import set_seed
|
||||
from accelerate import Accelerator
|
||||
from diffusers import DDPMScheduler
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
|
||||
import library.train_util as train_util
|
||||
@@ -114,15 +117,17 @@ class NetworkTrainer:
|
||||
)
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
):
|
||||
):
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
@@ -196,10 +201,10 @@ class NetworkTrainer:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(images).latent_dist.sample()
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return latents * self.vae_scale_factor
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
@@ -214,6 +219,7 @@ class NetworkTrainer:
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
@@ -227,7 +233,7 @@ class NetworkTrainer:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
@@ -271,7 +277,7 @@ class NetworkTrainer:
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -308,6 +314,107 @@ class NetworkTrainer:
|
||||
|
||||
# endregion
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy: strategy_base.TextEncodingStrategy,
|
||||
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
# latentに変換
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
|
||||
text_encoder_conds = []
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids_list,
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=is_train
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
@@ -373,10 +480,11 @@ class NetworkTrainer:
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -384,8 +492,12 @@ class NetworkTrainer:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(val_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
@@ -397,8 +509,12 @@ class NetworkTrainer:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
if val_dataset_group is not None:
|
||||
assert (
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group) # may change some args
|
||||
self.assert_extra_args(args, train_dataset_group, val_dataset_group) # may change some args
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
@@ -444,6 +560,8 @@ class NetworkTrainer:
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.new_cache_latents(vae, accelerator)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
@@ -459,6 +577,8 @@ class NetworkTrainer:
|
||||
if text_encoder_outputs_caching_strategy is not None:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
|
||||
if val_dataset_group is not None:
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
@@ -567,6 +687,8 @@ class NetworkTrainer:
|
||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||
# some strategies can be None
|
||||
train_dataset_group.set_current_strategies()
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.set_current_strategies()
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
@@ -579,6 +701,15 @@ class NetworkTrainer:
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
@@ -654,8 +785,8 @@ class NetworkTrainer:
|
||||
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
||||
network=network,
|
||||
)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = ds_model
|
||||
else:
|
||||
@@ -676,8 +807,8 @@ class NetworkTrainer:
|
||||
else:
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
network, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, val_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = network
|
||||
|
||||
@@ -769,6 +900,7 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
@@ -788,6 +920,7 @@ class NetworkTrainer:
|
||||
"ss_text_encoder_lr": text_encoder_lr,
|
||||
"ss_unet_lr": args.unet_lr,
|
||||
"ss_num_train_images": train_dataset_group.num_train_images,
|
||||
"ss_num_validation_images": val_dataset_group.num_train_images if val_dataset_group is not None else 0,
|
||||
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
||||
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||
"ss_num_epochs": num_train_epochs,
|
||||
@@ -835,6 +968,11 @@ class NetworkTrainer:
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": bool(args.fp8_base),
|
||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||
"ss_validation_seed": args.validation_seed,
|
||||
"ss_validation_split": args.validation_split,
|
||||
"ss_max_validation_steps": args.max_validation_steps,
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1051,20 +1189,15 @@ class NetworkTrainer:
|
||||
|
||||
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
||||
config=train_util.get_sanitized_config_or_none(args),
|
||||
init_kwargs=init_kwargs,
|
||||
)
|
||||
train_util.init_trackers(accelerator, args, "network_train")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_step_loss_recorder = train_util.LossRecorder()
|
||||
val_epoch_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
del val_dataset_group
|
||||
|
||||
# callback for step start
|
||||
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
||||
@@ -1109,10 +1242,17 @@ class NetworkTrainer:
|
||||
optimizer_eval_fn()
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
is_tracking = len(accelerator.trackers) > 0
|
||||
if is_tracking:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
validation_steps = (
|
||||
min(args.max_validation_steps, len(val_dataloader))
|
||||
if args.max_validation_steps is not None
|
||||
else len(val_dataloader)
|
||||
)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
@@ -1132,13 +1272,14 @@ class NetworkTrainer:
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
# TRAINING
|
||||
skipped_dataloader = None
|
||||
if initial_step > 0:
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||
@@ -1156,98 +1297,24 @@ class NetworkTrainer:
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
|
||||
latents = latents.to(dtype=weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
|
||||
# get multiplier for each sample
|
||||
if network_has_multiplier:
|
||||
multipliers = batch["network_multipliers"]
|
||||
# if all multipliers are same, use single multiplier
|
||||
if torch.all(multipliers == multipliers[0]):
|
||||
multipliers = multipliers[0].item()
|
||||
else:
|
||||
raise NotImplementedError("multipliers for each sample is not supported yet")
|
||||
# print(f"set multiplier: {multipliers}")
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
text_encoder_conds = []
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids_list,
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
@@ -1302,19 +1369,148 @@ class NetworkTrainer:
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate_step = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step != 0 # Skip first step
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps"
|
||||
)
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||
if args.validate_every_n_epochs is not None
|
||||
else True
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="epoch validation steps"
|
||||
)
|
||||
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/epoch_current": current_loss,
|
||||
"epoch": epoch + 1,
|
||||
"val_step": (epoch * validation_steps) + val_step
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
|
||||
logs = {
|
||||
"loss/validation/epoch_average": avr_loss,
|
||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||
"epoch": epoch + 1
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
@@ -1496,9 +1692,36 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||
parser.add_argument(
|
||||
"--validation_seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_split",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_validation_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional, Union
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -99,9 +99,12 @@ class TextualInversionTrainer:
|
||||
self.vae_scale_factor = 0.18215
|
||||
self.is_sdxl = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
|
||||
@@ -320,11 +323,12 @@ class TextualInversionTrainer:
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
self.assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
@@ -239,7 +239,7 @@ def train(args):
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
|
||||
Reference in New Issue
Block a user