add controlnet training

This commit is contained in:
ddPn08
2023-05-31 14:13:15 +09:00
parent 4f8ce00477
commit 62d00b4520
4 changed files with 1075 additions and 10 deletions

View File

@@ -33,8 +33,10 @@ from . import train_util
from .train_util import (
DreamBoothSubset,
FineTuningSubset,
ControlNetSubset,
DreamBoothDataset,
FineTuningDataset,
ControlNetDataset,
DatasetGroup,
)
@@ -70,6 +72,11 @@ class DreamBoothSubsetParams(BaseSubsetParams):
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
@dataclass
class BaseDatasetParams:
tokenizer: CLIPTokenizer = None
@@ -96,6 +103,15 @@ class FineTuningDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
batch_size: int = 1
enable_bucket: bool = False
min_bucket_reso: int = 256
max_bucket_reso: int = 1024
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
@dataclass
class SubsetBlueprint:
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
@@ -103,6 +119,7 @@ class SubsetBlueprint:
@dataclass
class DatasetBlueprint:
is_dreambooth: bool
is_controlnet: bool
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
subsets: Sequence[SubsetBlueprint]
@@ -163,6 +180,13 @@ class ConfigSanitizer:
Required("metadata_file"): str,
"image_dir": str,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
Required("conditioning_data_dir"): str,
}
# datasets schema
DATASET_ASCENDABLE_SCHEMA = {
@@ -192,8 +216,8 @@ class ConfigSanitizer:
"dataset_repeats": "num_repeats",
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -208,6 +232,13 @@ class ConfigSanitizer:
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.cn_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_DISTINCT_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.db_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -223,13 +254,23 @@ class ConfigSanitizer:
{"subsets": [self.ft_subset_schema]},
)
if support_dreambooth and support_finetuning:
self.cn_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
{"subsets": [self.cn_subset_schema]},
)
if support_dreambooth and support_finetuning and support_controlnet:
def validate_flex_dataset(dataset_config: dict):
subsets_config = dataset_config.get("subsets", [])
if all(["conditioning_data_dir" in subset for subset in subsets_config]):
return Schema(self.cn_dataset_schema)(dataset_config)
# check dataset meets FT style
# NOTE: all FT subsets should have "metadata_file"
if all(["metadata_file" in subset for subset in subsets_config]):
elif all(["metadata_file" in subset for subset in subsets_config]):
return Schema(self.ft_dataset_schema)(dataset_config)
# check dataset meets DB style
# NOTE: all DB subsets should have no "metadata_file"
@@ -241,13 +282,16 @@ class ConfigSanitizer:
self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
else:
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
self.dataset_schema = self.cn_dataset_schema
self.general_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
@@ -318,7 +362,11 @@ class BlueprintGenerator:
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
subsets = dataset_config.get("subsets", [])
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
if is_dreambooth:
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
if is_controlnet:
subset_params_klass = ControlNetSubsetParams
dataset_params_klass = ControlNetDatasetParams
elif is_dreambooth:
subset_params_klass = DreamBoothSubsetParams
dataset_params_klass = DreamBoothDatasetParams
else:
@@ -333,7 +381,7 @@ class BlueprintGenerator:
params = self.generate_params_by_fallbacks(dataset_params_klass,
[dataset_config, general_config, argparse_config, runtime_params])
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
@@ -361,10 +409,13 @@ class BlueprintGenerator:
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.is_dreambooth:
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
@@ -379,6 +430,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
@@ -421,7 +473,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
else:
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")
@@ -479,6 +531,31 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
return subsets_config
def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
def generate(base_dir: Optional[str]):
if base_dir is None:
return []
base_dir: Path = Path(base_dir)
if not base_dir.is_dir():
return []
subsets_config = []
for subdir in base_dir.iterdir():
if not subdir.is_dir():
continue
subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
subsets_config.append(subset_config)
return subsets_config
subsets_config = []
subsets_config += generate(train_data_dir, False)
return subsets_config
def load_user_config(file: str) -> dict:
file: Path = Path(file)
if not file.is_file():

View File

@@ -732,6 +732,82 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
return new_state_dict
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
unet_conversion_map = [
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
]
unet_conversion_map_resnet = [
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
for i in range(4):
for j in range(2):
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
controlnet_cond_embedding_names = (
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
)
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
sd_prefix = f"input_hint_block.{i*2}."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
for i in range(12):
hf_prefix = f"controlnet_down_blocks.{i}."
sd_prefix = f"zero_convs.{i}.0."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
mapping = {k: k for k in controlnet_state_dict.keys()}
for sd_name, diffusers_name in unet_conversion_map:
mapping[diffusers_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, diffusers_part in unet_conversion_map_resnet:
v = v.replace(diffusers_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, diffusers_part in unet_conversion_map_layer:
v = v.replace(diffusers_part, sd_part)
mapping[k] = v
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#

View File

@@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset):
return self.metadata_file == other.metadata_file
class ControlNetSubset(BaseSubset):
def __init__(
self,
image_dir: str,
conditioning_data_dir: str,
caption_extension: str,
num_repeats,
shuffle_caption,
keep_tokens,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
super().__init__(
image_dir,
num_repeats,
shuffle_caption,
keep_tokens,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.conditioning_data_dir = conditioning_data_dir
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, ControlNetSubset):
return NotImplemented
return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
class ControlNetDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[ControlNetSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
self.conditioning_image_data: Dict[str, ImageInfo] = {}
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.latents_cache = None
self.num_reg_images = 0
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert (
min(resolution) >= min_bucket_reso
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert (
max(resolution) <= max_bucket_reso
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
break
return caption
def load_controlnet_dir(subset: ControlNetSubset):
if not os.path.isdir(subset.image_dir):
print(f"not directory: {subset.image_dir}")
return [], []
if not os.path.isdir(subset.conditioning_data_dir):
print(f"not directory: {subset.conditioning_data_dir}")
return [], []
img_paths = glob_images(subset.image_dir, "*")
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
img_paths = sorted(img_paths)
conditioning_img_paths = sorted(conditioning_img_paths)
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
img_basenames = [os.path.basename(img) for img in img_paths]
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
missing_imgs = []
extra_imgs = []
for img in img_basenames:
if img not in conditioning_img_basenames:
missing_imgs.append(img)
for img in conditioning_img_basenames:
if img not in img_basenames:
extra_imgs.append(img)
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
)
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption + f"... and {remaining_missing_captions} more")
break
print(missing_caption)
return img_paths, conditioning_img_paths, captions
print("prepare images.")
num_train_images = 0
for subset in subsets:
if subset.num_repeats < 1:
print(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
)
continue
if subset in self.subsets:
print(
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
)
continue
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
if len(img_paths) < 1:
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
continue
num_train_images += subset.num_repeats * len(img_paths)
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
setattr(info, "cond_img_path", cond_img_path)
self.register_image(info, subset)
subset.img_count = len(img_paths)
self.subsets.append(subset)
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
self.conditioning_image_transforms = transforms.Compose(
[
transforms.ToTensor(),
]
)
def __getitem__(self, index):
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
loss_weights = []
captions = []
input_ids_list = []
latents_list = []
images = []
conditioning_images = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(1.0)
# image/latentsを処理する
if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents)
image = None
else:
# 画像を読み込み、必要ならcropする
img = self.load_image(image_info.absolute_path)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
else:
im_h, im_w = img.shape[0:2]
assert (
im_h == self.height and im_w == self.width
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
if aug is not None:
img = aug(image=img)["image"]
latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
images.append(image)
latents_list.append(latents)
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
cond_img = self.load_image(image_info.cond_img_path)
if self.enable_bucket:
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)
conditioning_images = torch.stack(conditioning_images)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
if self.token_padding_disabled:
# padding=True means pad in the batch
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
# batch processing seems to be good
example["input_ids"] = torch.stack(input_ids_list)
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
else:
images = None
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
return example
# behave as Dataset mock
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
@@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str:
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
replace_attentions_for_hypernetwork()
# unet is not used currently, but it is here for future use
unet.enable_xformers_memory_efficient_attention()
return
if mem_eff_attn:
unet.set_attn_processor(FlashAttnProcessor())
elif xformers:

594
train_controlnet.py Normal file
View File

@@ -0,0 +1,594 @@
import argparse
import gc
import math
import os
import random
import time
from multiprocessing import Value
from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_controlnet_from_original_ckpt,
)
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"]
* lr_scheduler.optimizers[-1].param_groups[0]["lr"]
)
return logs
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
blueprint.dataset_group
)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = (
train_dataset_group if args.max_data_loader_n_workers == 0 else None
)
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator
)
if args.controlnet_model_name_or_path:
if os.path.isfile(args.controlnet_model_name_or_path):
controlnet = download_controlnet_from_original_ckpt(
args.controlnet_model_name_or_path
)
else:
controlnet = ControlNetModel.from_pretrained(
args.controlnet_model_name_or_path
)
else:
controlnet = ControlNetModel.from_unet(unet)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
controlnet.enable_gradient_checkpointing()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
trainable_params = controlnet.parameters()
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(
args, trainable_params
)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(
args.max_data_loader_n_workers, os.cpu_count() - 1
) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader)
/ accelerator.num_processes
/ args.gradient_accumulation_steps
)
if is_main_process:
print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(
args, optimizer, accelerator.num_processes
)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
controlnet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(accelerator.device)
text_encoder.to(accelerator.device)
# transform DDP after prepare
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
controlnet.train()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = (
math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
)
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
if is_main_process:
print("running training / 学習開始")
print(
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
)
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
)
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
if accelerator.is_main_process:
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name
)
loss_list = []
loss_total = 0.0
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}")
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(ckpt_file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, ckpt_file)
else:
torch.save(state_dict, ckpt_file)
if args.huggingface_repo_id is not None:
huggingface_util.upload(
args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload
)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(
batch["images"].to(dtype=weight_dtype)
).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.gethidden_states(
args, input_ids, tokenizer, text_encoder, weight_dtype
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = apply_noise_offset(
latents, noise, args.noise_offset, args.adaptive_noise_scale
)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(
noise,
latents.device,
args.multires_noise_iterations,
args.multires_noise_discount,
)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
down_block_res_samples, mid_block_res_sample = controlnet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_image,
return_dict=False,
)
# Predict the noise residual
noise_pred = unet(
noisy_latents,
timesteps,
encoder_hidden_states,
down_block_additional_residuals=[
sample.to(dtype=weight_dtype)
for sample in down_block_res_samples
],
mid_block_additional_residual=mid_block_res_sample.to(
dtype=weight_dtype
),
).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(
noise_pred.float(), target.float(), reduction="none"
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(
loss, timesteps, noise_scheduler, args.min_snr_gamma
)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
)
# 指定ステップごとにモデルを保存
if (
args.save_every_n_steps is not None
and global_step % args.save_every_n_steps == 0
):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(
args, "." + args.save_model_as, global_step
)
save_model(
ckpt_name, unwrap_model(controlnet), global_step, epoch
)
if args.save_state:
train_util.save_and_remove_state_stepwise(
args, accelerator, global_step
)
remove_step_no = train_util.get_remove_step_no(
args, global_step
)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(
args, "." + args.save_model_as, remove_step_no
)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (
epoch + 1
) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(
args, "." + args.save_model_as, epoch + 1
)
save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(
args, "." + args.save_model_as, remove_epoch_no
)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(
args, accelerator, epoch + 1
)
train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
)
# end of epoch
if is_main_process:
controlnet = unwrap_model(controlnet)
accelerator.end_training()
if is_main_process and args.save_state:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(
ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True
)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)