Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)

Commit: add controlnet training
@@ -33,8 +33,10 @@ from . import train_util
|
||||
from .train_util import (
|
||||
DreamBoothSubset,
|
||||
FineTuningSubset,
|
||||
ControlNetSubset,
|
||||
DreamBoothDataset,
|
||||
FineTuningDataset,
|
||||
ControlNetDataset,
|
||||
DatasetGroup,
|
||||
)
|
||||
|
||||
@@ -70,6 +72,11 @@ class DreamBoothSubsetParams(BaseSubsetParams):
|
||||
class FineTuningSubsetParams(BaseSubsetParams):
|
||||
metadata_file: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class ControlNetSubsetParams(BaseSubsetParams):
|
||||
conditioning_data_dir: str = None
|
||||
caption_extension: str = ".caption"
|
||||
|
||||
@dataclass
|
||||
class BaseDatasetParams:
|
||||
tokenizer: CLIPTokenizer = None
|
||||
@@ -96,6 +103,15 @@ class FineTuningDatasetParams(BaseDatasetParams):
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
|
||||
@dataclass
|
||||
class ControlNetDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
enable_bucket: bool = False
|
||||
min_bucket_reso: int = 256
|
||||
max_bucket_reso: int = 1024
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
|
||||
@dataclass
|
||||
class SubsetBlueprint:
|
||||
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
||||
@@ -103,6 +119,7 @@ class SubsetBlueprint:
|
||||
@dataclass
|
||||
class DatasetBlueprint:
|
||||
is_dreambooth: bool
|
||||
is_controlnet: bool
|
||||
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
||||
subsets: Sequence[SubsetBlueprint]
|
||||
|
||||
@@ -163,6 +180,13 @@ class ConfigSanitizer:
|
||||
Required("metadata_file"): str,
|
||||
"image_dir": str,
|
||||
}
|
||||
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
"caption_extension": str,
|
||||
}
|
||||
CN_SUBSET_DISTINCT_SCHEMA = {
|
||||
Required("image_dir"): str,
|
||||
Required("conditioning_data_dir"): str,
|
||||
}
|
||||
|
||||
# datasets schema
|
||||
DATASET_ASCENDABLE_SCHEMA = {
|
||||
@@ -192,8 +216,8 @@ class ConfigSanitizer:
|
||||
"dataset_repeats": "num_repeats",
|
||||
}
|
||||
|
||||
- def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
-     assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
+     assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
||||
|
||||
self.db_subset_schema = self.__merge_dict(
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
@@ -208,6 +232,13 @@ class ConfigSanitizer:
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
)
|
||||
|
||||
self.cn_subset_schema = self.__merge_dict(
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.CN_SUBSET_DISTINCT_SCHEMA,
|
||||
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
)
|
||||
|
||||
self.db_dataset_schema = self.__merge_dict(
|
||||
self.DATASET_ASCENDABLE_SCHEMA,
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
@@ -223,13 +254,23 @@ class ConfigSanitizer:
|
||||
{"subsets": [self.ft_subset_schema]},
|
||||
)
|
||||
|
||||
if support_dreambooth and support_finetuning:
|
||||
self.cn_dataset_schema = self.__merge_dict(
|
||||
self.DATASET_ASCENDABLE_SCHEMA,
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
{"subsets": [self.cn_subset_schema]},
|
||||
)
|
||||
|
||||
if support_dreambooth and support_finetuning and support_controlnet:
|
||||
def validate_flex_dataset(dataset_config: dict):
|
||||
subsets_config = dataset_config.get("subsets", [])
|
||||
|
||||
if all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
||||
return Schema(self.cn_dataset_schema)(dataset_config)
|
||||
# check dataset meets FT style
|
||||
# NOTE: all FT subsets should have "metadata_file"
|
||||
if all(["metadata_file" in subset for subset in subsets_config]):
|
||||
elif all(["metadata_file" in subset for subset in subsets_config]):
|
||||
return Schema(self.ft_dataset_schema)(dataset_config)
|
||||
# check dataset meets DB style
|
||||
# NOTE: all DB subsets should have no "metadata_file"
|
||||
@@ -241,13 +282,16 @@ class ConfigSanitizer:
|
||||
self.dataset_schema = validate_flex_dataset
|
||||
elif support_dreambooth:
|
||||
self.dataset_schema = self.db_dataset_schema
|
||||
- else:
+ elif support_finetuning:
|
||||
self.dataset_schema = self.ft_dataset_schema
|
||||
elif support_controlnet:
|
||||
self.dataset_schema = self.cn_dataset_schema
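
As a rough illustration of the routing above (an editor's sketch, not part of the commit; directory paths are hypothetical): the flexible validator picks a schema from the shape of the subsets, so a dataset whose subsets all carry "conditioning_data_dir" is validated with the ControlNet schema, one whose subsets all carry "metadata_file" with the fine-tuning schema, and anything else falls back to the DreamBooth schema.

# hypothetical dataset configs and the schema validate_flex_dataset would choose
controlnet_dataset = {
    "subsets": [{"image_dir": "train/images", "conditioning_data_dir": "train/cond"}],
}  # every subset has "conditioning_data_dir" -> cn_dataset_schema

finetuning_dataset = {
    "subsets": [{"image_dir": "train/images", "metadata_file": "meta.json"}],
}  # every subset has "metadata_file" -> ft_dataset_schema

dreambooth_dataset = {
    "subsets": [{"image_dir": "train/1_sks"}],
}  # neither key on every subset -> db_dataset_schema
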
|
||||
|
||||
self.general_schema = self.__merge_dict(
|
||||
self.DATASET_ASCENDABLE_SCHEMA,
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
||||
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
||||
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
||||
)
|
||||
|
||||
@@ -318,7 +362,11 @@ class BlueprintGenerator:
|
||||
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
||||
subsets = dataset_config.get("subsets", [])
|
||||
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
||||
- if is_dreambooth:
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
+ if is_controlnet:
|
||||
subset_params_klass = ControlNetSubsetParams
|
||||
dataset_params_klass = ControlNetDatasetParams
|
||||
elif is_dreambooth:
|
||||
subset_params_klass = DreamBoothSubsetParams
|
||||
dataset_params_klass = DreamBoothDatasetParams
|
||||
else:
|
||||
@@ -333,7 +381,7 @@ class BlueprintGenerator:
|
||||
|
||||
params = self.generate_params_by_fallbacks(dataset_params_klass,
|
||||
[dataset_config, general_config, argparse_config, runtime_params])
|
||||
- dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
||||
|
||||
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
||||
|
||||
@@ -361,10 +409,13 @@ class BlueprintGenerator:
|
||||
|
||||
|
||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
||||
- datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||
|
||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||
- if dataset_blueprint.is_dreambooth:
+ if dataset_blueprint.is_controlnet:
|
||||
subset_klass = ControlNetSubset
|
||||
dataset_klass = ControlNetDataset
|
||||
elif dataset_blueprint.is_dreambooth:
|
||||
subset_klass = DreamBoothSubset
|
||||
dataset_klass = DreamBoothDataset
|
||||
else:
|
||||
@@ -379,6 +430,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
info = ""
|
||||
for i, dataset in enumerate(datasets):
|
||||
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
||||
is_controlnet = isinstance(dataset, ControlNetDataset)
|
||||
info += dedent(f"""\
|
||||
[Dataset {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
@@ -421,7 +473,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
class_tokens: {subset.class_tokens}
|
||||
caption_extension: {subset.caption_extension}
|
||||
\n"""), " ")
|
||||
- else:
+ elif not is_controlnet:
|
||||
info += indent(dedent(f"""\
|
||||
metadata_file: {subset.metadata_file}
|
||||
\n"""), " ")
|
||||
@@ -479,6 +531,31 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
|
||||
return subsets_config
|
||||
|
||||
|
||||
def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
|
||||
def generate(base_dir: Optional[str]):
|
||||
if base_dir is None:
|
||||
return []
|
||||
|
||||
base_dir: Path = Path(base_dir)
|
||||
if not base_dir.is_dir():
|
||||
return []
|
||||
|
||||
subsets_config = []
|
||||
for subdir in base_dir.iterdir():
|
||||
if not subdir.is_dir():
|
||||
continue
|
||||
|
||||
subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
|
||||
subsets_config.append(subset_config)
|
||||
|
||||
return subsets_config
|
||||
|
||||
subsets_config = []
|
||||
subsets_config += generate(train_data_dir)
|
||||
|
||||
return subsets_config
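
A quick usage sketch for the helper above (editor's addition; paths and sub-folder names are hypothetical): each sub-folder of train_data_dir becomes one subset, and every subset points at the same conditioning_data_dir.

subsets = generate_controlnet_subsets_config_by_subdirs(
    train_data_dir="train/images",             # contains e.g. "person/" and "scenery/"
    conditioning_data_dir="train/conditioning",
    caption_extension=".txt",
)
# -> [{"image_dir": "train/images/person",  "conditioning_data_dir": "train/conditioning", "caption_extension": ".txt", "num_repeats": 1},
#     {"image_dir": "train/images/scenery", "conditioning_data_dir": "train/conditioning", "caption_extension": ".txt", "num_repeats": 1}]
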
|
||||
|
||||
|
||||
def load_user_config(file: str) -> dict:
|
||||
file: Path = Path(file)
|
||||
if not file.is_file():
|
||||
|
||||
@@ -732,6 +732,82 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
unet_conversion_map = [
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
||||
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
||||
("input_blocks.0.0.weight", "conv_in.weight"),
|
||||
("input_blocks.0.0.bias", "conv_in.bias"),
|
||||
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
||||
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
||||
]
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
("in_layers.0", "norm1"),
|
||||
("in_layers.2", "conv1"),
|
||||
("out_layers.0", "norm2"),
|
||||
("out_layers.3", "conv2"),
|
||||
("emb_layers.1", "time_emb_proj"),
|
||||
("skip_connection", "conv_shortcut"),
|
||||
]
|
||||
|
||||
unet_conversion_map_layer = []
|
||||
for i in range(4):
|
||||
for j in range(2):
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
controlnet_cond_embedding_names = (
|
||||
["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
||||
)
|
||||
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
||||
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
||||
sd_prefix = f"input_hint_block.{i*2}."
|
||||
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
||||
|
||||
for i in range(12):
|
||||
hf_prefix = f"controlnet_down_blocks.{i}."
|
||||
sd_prefix = f"zero_convs.{i}.0."
|
||||
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
||||
|
||||
mapping = {k: k for k in controlnet_state_dict.keys()}
|
||||
for sd_name, diffusers_name in unet_conversion_map:
|
||||
mapping[diffusers_name] = sd_name
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in k:
|
||||
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
||||
v = v.replace(diffusers_part, sd_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
for sd_part, diffusers_part in unet_conversion_map_layer:
|
||||
v = v.replace(diffusers_part, sd_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
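
A minimal sketch of how this conversion is used when saving (mirroring save_model in the new training script further below; the file name is hypothetical): the Diffusers-style ControlNet state dict is remapped to Stable Diffusion key names and then written out.

from safetensors.torch import save_file

sd_state_dict = convert_controlnet_state_dict_to_sd(controlnet.state_dict())
save_file(sd_state_dict, "my_controlnet.safetensors")
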
|
||||
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
@@ -403,6 +403,54 @@ class FineTuningSubset(BaseSubset):
|
||||
return self.metadata_file == other.metadata_file
|
||||
|
||||
|
||||
class ControlNetSubset(BaseSubset):
|
||||
def __init__(
|
||||
self,
|
||||
image_dir: str,
|
||||
conditioning_data_dir: str,
|
||||
caption_extension: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
keep_tokens,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
super().__init__(
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
keep_tokens,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
random_crop,
|
||||
caption_dropout_rate,
|
||||
caption_dropout_every_n_epochs,
|
||||
caption_tag_dropout_rate,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, ControlNetSubset):
|
||||
return NotImplemented
|
||||
return self.image_dir == other.image_dir and self.conditioning_data_dir == other.conditioning_data_dir
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, tokenizer: CLIPTokenizer, max_token_length: int, resolution: Optional[Tuple[int, int]], debug_dataset: bool
|
||||
@@ -1387,6 +1435,274 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[ControlNetSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
self.conditioning_image_data: Dict[str, ImageInfo] = {}
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.latents_cache = None
|
||||
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension):
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
tokens = base_name.split("_")
|
||||
if len(tokens) >= 5:
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
||||
|
||||
caption = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
try:
|
||||
lines = f.readlines()
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
caption = lines[0].strip()
|
||||
break
|
||||
return caption
|
||||
|
||||
def load_controlnet_dir(subset: ControlNetSubset):
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
print(f"not directory: {subset.image_dir}")
|
||||
return [], []
|
||||
if not os.path.isdir(subset.conditioning_data_dir):
|
||||
print(f"not directory: {subset.conditioning_data_dir}")
|
||||
return [], []
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
img_paths = sorted(img_paths)
|
||||
conditioning_img_paths = sorted(conditioning_img_paths)
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
print(f"found directory {subset.conditioning_data_dir} contains {len(conditioning_img_paths)} image files")
|
||||
|
||||
img_basenames = [os.path.basename(img) for img in img_paths]
|
||||
conditioning_img_basenames = [os.path.basename(img) for img in conditioning_img_paths]
|
||||
missing_imgs = []
|
||||
extra_imgs = []
|
||||
|
||||
for img in img_basenames:
|
||||
if img not in conditioning_img_basenames:
|
||||
missing_imgs.append(img)
|
||||
for img in conditioning_img_basenames:
|
||||
if img not in img_basenames:
|
||||
extra_imgs.append(img)
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}")
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
number_of_missing_captions_to_show = 5
|
||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
||||
|
||||
print(
|
||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
||||
)
|
||||
for i, missing_caption in enumerate(missing_captions):
|
||||
if i >= number_of_missing_captions_to_show:
|
||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
print(missing_caption)
|
||||
return img_paths, conditioning_img_paths, captions
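
In other words (editor's note; file names are hypothetical): load_controlnet_dir expects every training image to have a conditioning image with the same file name in conditioning_data_dir, and asserts if either side has unmatched files.

# image_dir/                  conditioning_data_dir/
#   0001.png                    0001.png
#   0001.txt  (caption)         0002.png
#   0002.png
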
|
||||
|
||||
print("prepare images.")
|
||||
num_train_images = 0
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
print(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
|
||||
)
|
||||
continue
|
||||
|
||||
if subset in self.subsets:
|
||||
print(
|
||||
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
|
||||
)
|
||||
continue
|
||||
|
||||
img_paths, conditioning_img_paths, captions = load_controlnet_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
print(f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します")
|
||||
continue
|
||||
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, cond_img_path, caption in zip(img_paths, conditioning_img_paths, captions):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, False, img_path)
|
||||
setattr(info, "cond_img_path", cond_img_path)
|
||||
self.register_image(info, subset)
|
||||
|
||||
subset.img_count = len(img_paths)
|
||||
self.subsets.append(subset)
|
||||
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
self.conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
||||
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
||||
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
||||
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
latents_list = []
|
||||
images = []
|
||||
conditioning_images = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(1.0)
|
||||
|
||||
# image/latentsを処理する
|
||||
if image_info.latents is not None: # cache_latents=Trueの場合
|
||||
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
|
||||
latents = torch.FloatTensor(latents)
|
||||
image = None
|
||||
else:
|
||||
# 画像を読み込み、必要ならcropする
|
||||
img = self.load_image(image_info.absolute_path)
|
||||
im_h, im_w = img.shape[0:2]
|
||||
|
||||
if self.enable_bucket:
|
||||
img = self.trim_and_resize_if_required(subset, img, image_info.bucket_reso, image_info.resized_size)
|
||||
else:
|
||||
im_h, im_w = img.shape[0:2]
|
||||
assert (
|
||||
im_h == self.height and im_w == self.width
|
||||
), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
|
||||
# augmentation
|
||||
aug = self.aug_helper.get_augmentor(subset.color_aug, subset.flip_aug)
|
||||
if aug is not None:
|
||||
img = aug(image=img)["image"]
|
||||
|
||||
latents = None
|
||||
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer)
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
assert hasattr(image_info, "cond_img_path"), f"conditioning image path is not found: {image_info.absolute_path}"
|
||||
|
||||
cond_img = self.load_image(image_info.cond_img_path)
|
||||
if self.enable_bucket:
|
||||
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
|
||||
cond_img = self.conditioning_image_transforms(cond_img)
|
||||
conditioning_images.append(cond_img)
|
||||
conditioning_images = torch.stack(conditioning_images)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
# batch processing seems to be good
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
|
||||
if images[0] is not None:
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
else:
|
||||
images = None
|
||||
example["images"] = images
|
||||
|
||||
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
|
||||
example["captions"] = captions
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
|
||||
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
return example
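
# Keys of the example dict built above (editor's summary):
#   "loss_weights", "input_ids", "images" (None when latents are cached),
#   "latents" (None when raw images are used), "captions", "conditioning_images",
#   and "image_keys" when debug_dataset is enabled.
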
|
||||
|
||||
# behave as Dataset mock
|
||||
class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def __init__(self, datasets: Sequence[Union[DreamBoothDataset, FineTuningDataset]]):
|
||||
@@ -1636,6 +1952,8 @@ def get_git_revision_hash() -> str:
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
replace_attentions_for_hypernetwork()
|
||||
# unet is not used currently, but it is here for future use
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
if mem_eff_attn:
|
||||
unet.set_attn_processor(FlashAttnProcessor())
|
||||
elif xformers:
|
||||
|
||||
train_controlnet.py (new file, 594 lines)
@@ -0,0 +1,594 @@
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_controlnet_from_original_ckpt,
|
||||
)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {
|
||||
"loss/current": current_loss,
|
||||
"loss/average": avr_loss,
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[0]["d"]
|
||||
* lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def train(args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||
args.train_data_dir,
|
||||
args.conditioning_data_dir,
|
||||
args.caption_extension,
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
|
||||
blueprint.dataset_group
|
||||
)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = (
|
||||
train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
)
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||
args, weight_dtype, accelerator
|
||||
)
|
||||
if args.controlnet_model_name_or_path:
|
||||
if os.path.isfile(args.controlnet_model_name_or_path):
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
args.controlnet_model_name_or_path
|
||||
)
|
||||
else:
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
args.controlnet_model_name_or_path
|
||||
)
|
||||
else:
|
||||
controlnet = ControlNetModel.from_unet(unet)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
vae,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
controlnet.enable_gradient_checkpointing()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = controlnet.parameters()
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(
|
||||
args, trainable_params
|
||||
)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(
|
||||
args.max_data_loader_n_workers, os.cpu_count() - 1
|
||||
) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collater,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader)
|
||||
/ accelerator.num_processes
|
||||
/ args.gradient_accumulation_steps
|
||||
)
|
||||
if is_main_process:
|
||||
print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(
|
||||
args, optimizer, accelerator.num_processes
|
||||
)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
controlnet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.to(accelerator.device)
|
||||
text_encoder.to(accelerator.device)
|
||||
|
||||
# transform DDP after prepare
|
||||
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
|
||||
|
||||
controlnet.train()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(train_dataloader) / args.gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = (
|
||||
math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
)
|
||||
|
||||
# 学習する
|
||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||
|
||||
if is_main_process:
|
||||
print("running training / 学習開始")
|
||||
print(
|
||||
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
|
||||
)
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(
|
||||
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
|
||||
)
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="steps",
|
||||
)
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
)
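# Editor's note: these are the standard SD v1 training settings. With the
# "scaled_linear" schedule the betas are linearly spaced in sqrt space, i.e.
#   beta_t = (sqrt(0.00085) + (t / (T - 1)) * (sqrt(0.012) - sqrt(0.00085)))**2
# for t = 0..T-1 with T = 1000.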
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
|
||||
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
|
||||
|
||||
if save_dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(ckpt_file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
|
||||
save_file(state_dict, ckpt_file)
|
||||
else:
|
||||
torch.save(state_dict, ckpt_file)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(
|
||||
args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload
|
||||
)
|
||||
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(controlnet):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(
|
||||
batch["images"].to(dtype=weight_dtype)
|
||||
).latent_dist.sample()
|
||||
latents = latents * 0.18215
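# editor's note: 0.18215 is the Stable Diffusion VAE latent scaling factor,
# the same constant used by the other trainers in this repository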
|
||||
b_size = latents.shape[0]
|
||||
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, weight_dtype
|
||||
)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
noise = apply_noise_offset(
|
||||
latents, noise, args.noise_offset, args.adaptive_noise_scale
|
||||
)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(
|
||||
noise,
|
||||
latents.device,
|
||||
args.multires_noise_iterations,
|
||||
args.multires_noise_discount,
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
(b_size,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
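# editor's note: add_noise applies the closed-form forward process
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise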
|
||||
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
with accelerator.autocast():
|
||||
down_block_res_samples, mid_block_res_sample = controlnet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
controlnet_cond=controlnet_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states,
|
||||
down_block_additional_residuals=[
|
||||
sample.to(dtype=weight_dtype)
|
||||
for sample in down_block_res_samples
|
||||
],
|
||||
mid_block_additional_residual=mid_block_res_sample.to(
|
||||
dtype=weight_dtype
|
||||
),
|
||||
).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
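# editor's note: with --v_parameterization the regression target is the velocity
#   v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
# (DDPMScheduler.get_velocity); otherwise it is the added noise (epsilon prediction).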
|
||||
|
||||
loss = torch.nn.functional.mse_loss(
|
||||
noise_pred.float(), target.float(), reduction="none"
|
||||
)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(
|
||||
loss, timesteps, noise_scheduler, args.min_snr_gamma
|
||||
)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = controlnet.parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if (
|
||||
args.save_every_n_steps is not None
|
||||
and global_step % args.save_every_n_steps == 0
|
||||
):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(
|
||||
args, "." + args.save_model_as, global_step
|
||||
)
|
||||
save_model(
|
||||
ckpt_name, unwrap_model(controlnet), global_step, epoch
|
||||
)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(
|
||||
args, accelerator, global_step
|
||||
)
|
||||
|
||||
remove_step_no = train_util.get_remove_step_no(
|
||||
args, global_step
|
||||
)
|
||||
if remove_step_no is not None:
|
||||
remove_ckpt_name = train_util.get_step_ckpt_name(
|
||||
args, "." + args.save_model_as, remove_step_no
|
||||
)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (
|
||||
epoch + 1
|
||||
) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(
|
||||
args, "." + args.save_model_as, epoch + 1
|
||||
)
|
||||
save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1)
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
remove_ckpt_name = train_util.get_epoch_ckpt_name(
|
||||
args, "." + args.save_model_as, remove_epoch_no
|
||||
)
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(
|
||||
args, accelerator, epoch + 1
|
||||
)
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
)
|
||||
|
||||
# end of epoch
|
||||
if is_main_process:
|
||||
controlnet = unwrap_model(controlnet)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(
|
||||
ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True
|
||||
)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
|
||||
return parser
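
Usage note (editor's addition): besides the two arguments defined above, the shared helpers add the usual sd-scripts options, so a run looks roughly like "accelerate launch train_controlnet.py --pretrained_model_name_or_path <model> --train_data_dir <images> --conditioning_data_dir <conditioning images> --output_dir <out> --resolution 512"; the exact flag set comes from train_util and config_util and is not shown in this diff.
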
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||