Replace print with logger for log messages (#905)

* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
Author: Yuta Hayashibe
Date: 2024-02-04 16:14:34 +07:00
Committed by: GitHub
Parent: 7f948db158
Commit: 5f6bf29e52
62 changed files with 1195 additions and 961 deletions
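
Note: this excerpt does not include the library/utils.py hunk, so the
setup_logging() implementation itself is not shown. As a rough sketch of the
idiom every file below now follows (handler details are assumptions; rich is
added to requirements.txt by this commit, presumably for the console handler):

    # hypothetical sketch of library/utils.setup_logging(); the real
    # implementation is not part of this excerpt
    import logging

    def setup_logging(log_level=logging.INFO):
        if logging.root.handlers:
            return  # configure the root logger only once
        try:
            # assumption: rich provides the console handler
            from rich.logging import RichHandler
            handler = RichHandler()
        except ImportError:
            handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(message)s"))
        logging.root.setLevel(log_level)
        logging.root.addHandler(handler)

Each module then calls setup_logging() once and creates its own child logger
with logging.getLogger(__name__), which inherits the root handler.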

View File

@@ -18,6 +18,10 @@ init_ipex()
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 import library.train_util as train_util
 import library.config_util as config_util
 from library.config_util import (
@@ -49,11 +53,11 @@ def train(args):
     if args.dataset_class is None:
         blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
             ignored = ["train_data_dir", "in_json"]
             if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                     "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                         ", ".join(ignored)
                     )
@@ -86,7 +90,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group)
         return
     if len(train_dataset_group) == 0:
-        print(
+        logger.error(
             "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
         )
         return
@@ -97,7 +101,7 @@ def train(args):
     ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)

     # mixed precisionに対応した型を用意しておき適宜castする
@@ -461,7 +465,7 @@ def train(args):
         train_util.save_sd_model_on_train_end(
             args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
         )
-        print("model saved.")
+        logger.info("model saved.")


 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -21,6 +21,10 @@ import torch.nn.functional as F
 import os
 from urllib.parse import urlparse
 from timm.models.hub import download_cached_file
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 class BLIP_Base(nn.Module):
     def __init__(self,
@@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename):
                 del state_dict[key]

     msg = model.load_state_dict(state_dict,strict=False)
-    print('load checkpoint from %s'%url_or_filename)
+    logger.info('load checkpoint from %s'%url_or_filename)
     return model,msg
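
Note: the committed line keeps BLIP's original %-interpolation. An equivalent
form, sketched here and not part of this commit, defers the string formatting
to the logging framework so the message is only built if the record is
actually emitted:

    # lazy-interpolation alternative (sketch, not in the commit)
    logger.info('load checkpoint from %s', url_or_filename)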

View File

@@ -8,6 +8,10 @@ import json
 import re

 from tqdm import tqdm
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
 PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
     tokens = tags.split(", rating")
     if len(tokens) == 1:
         # WD14 taggerのときはこちらになるのでメッセージは出さない
-        # print("no rating:")
-        # print(f"{image_key} {tags}")
+        # logger.info("no rating:")
+        # logger.info(f"{image_key} {tags}")
         pass
     else:
         if len(tokens) > 2:
-            print("multiple ratings:")
-            print(f"{image_key} {tags}")
+            logger.info("multiple ratings:")
+            logger.info(f"{image_key} {tags}")
         tags = tokens[0]

     tags = ", " + tags.replace(", ", ", , ") + ", "  # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):

 def main(args):
     if os.path.exists(args.in_json):
-        print(f"loading existing metadata: {args.in_json}")
+        logger.info(f"loading existing metadata: {args.in_json}")
         with open(args.in_json, "rt", encoding='utf-8') as f:
             metadata = json.load(f)
     else:
-        print("no metadata / メタデータファイルがありません")
+        logger.error("no metadata / メタデータファイルがありません")
         return

-    print("cleaning captions and tags.")
+    logger.info("cleaning captions and tags.")
     image_keys = list(metadata.keys())
     for image_key in tqdm(image_keys):
         tags = metadata[image_key].get('tags')
         if tags is None:
-            print(f"image does not have tags / メタデータにタグがありません: {image_key}")
+            logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
         else:
             org = tags
             tags = clean_tags(image_key, tags)
             metadata[image_key]['tags'] = tags
             if args.debug and org != tags:
-                print("FROM: " + org)
-                print("TO: " + tags)
+                logger.info("FROM: " + org)
+                logger.info("TO: " + tags)

         caption = metadata[image_key].get('caption')
         if caption is None:
-            print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
+            logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
         else:
             org = caption
             caption = clean_caption(caption)
             metadata[image_key]['caption'] = caption
             if args.debug and org != caption:
-                print("FROM: " + org)
-                print("TO: " + caption)
+                logger.info("FROM: " + org)
+                logger.info("TO: " + caption)

     # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
+    logger.info(f"writing metadata: {args.out_json}")
     with open(args.out_json, "wt", encoding='utf-8') as f:
         json.dump(metadata, f, indent=2)
-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ if __name__ == '__main__':
     args, unknown = parser.parse_known_args()
     if len(unknown) == 1:
-        print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
-        print("All captions and tags in the metadata are processed.")
-        print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
-        print("メタデータ内のすべてのキャプションとタグが処理されます。")
+        logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
+        logger.warning("All captions and tags in the metadata are processed.")
+        logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
+        logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
         args.in_json = args.out_json
         args.out_json = unknown[0]
     elif len(unknown) > 0:

View File

@@ -15,6 +15,10 @@ from torchvision.transforms.functional import InterpolationMode
 sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder, is_url
 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -47,7 +51,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
             # convert to tensor temporarily so dataloader will accept it
             tensor = IMAGE_TRANSFORM(image)
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
             return None

         return (tensor, img_path)
@@ -74,21 +78,21 @@ def main(args):
         args.train_data_dir = os.path.abspath(args.train_data_dir)  # convert to absolute path

         cwd = os.getcwd()
-        print("Current Working Directory is: ", cwd)
+        logger.info(f"Current Working Directory is: {cwd}")
         os.chdir("finetune")

         if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
             args.caption_weights = os.path.join("..", args.caption_weights)

-    print(f"load images from {args.train_data_dir}")
+    logger.info(f"load images from {args.train_data_dir}")
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

-    print(f"loading BLIP caption: {args.caption_weights}")
+    logger.info(f"loading BLIP caption: {args.caption_weights}")
     model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
     model.eval()
     model = model.to(DEVICE)
-    print("BLIP loaded")
+    logger.info("BLIP loaded")
@@ -108,7 +112,7 @@ def main(args):
         with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
             f.write(caption + "\n")
             if args.debug:
-                print(image_path, caption)
+                logger.info(f'{image_path} {caption}')

     # 読み込みの高速化のためにDataLoaderを使うオプション
     if args.max_data_loader_n_workers is not None:
@@ -138,7 +142,7 @@ def main(args):
                 raw_image = raw_image.convert("RGB")
             img_tensor = IMAGE_TRANSFORM(raw_image)
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
             continue

         b_imgs.append((image_path, img_tensor))
@@ -148,7 +152,7 @@ def main(args):
     if len(b_imgs) > 0:
         run_batch(b_imgs)

-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -10,7 +10,10 @@ from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers.generation.utils import GenerationMixin

 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,8 +38,8 @@ def remove_words(captions, debug):
         for pat in PATTERN_REPLACE:
             cap = pat.sub("", cap)
         if debug and cap != caption:
-            print(caption)
-            print(cap)
+            logger.info(caption)
+            logger.info(cap)
         removed_caps.append(cap)
     return removed_caps
@@ -70,16 +73,16 @@ def main(args):
         GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
     """

-    print(f"load images from {args.train_data_dir}")
+    logger.info(f"load images from {args.train_data_dir}")
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     # できればcacheに依存せず明示的にダウンロードしたい
-    print(f"loading GIT: {args.model_id}")
+    logger.info(f"loading GIT: {args.model_id}")
     git_processor = AutoProcessor.from_pretrained(args.model_id)
     git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
-    print("GIT loaded")
+    logger.info("GIT loaded")
@@ -97,7 +100,7 @@ def main(args):
         with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
             f.write(caption + "\n")
             if args.debug:
-                print(image_path, caption)
+                logger.info(f"{image_path} {caption}")

     # 読み込みの高速化のためにDataLoaderを使うオプション
     if args.max_data_loader_n_workers is not None:
@@ -126,7 +129,7 @@ def main(args):
             if image.mode != "RGB":
                 image = image.convert("RGB")
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
             continue

         b_imgs.append((image_path, image))
@@ -137,7 +140,7 @@ def main(args):
     if len(b_imgs) > 0:
         run_batch(b_imgs)

-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -5,26 +5,30 @@ from typing import List
 from tqdm import tqdm

 import library.train_util as train_util
 import os
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)


 def main(args):
     assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

     train_data_dir_path = Path(args.train_data_dir)
     image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     if args.in_json is None and Path(args.out_json).is_file():
         args.in_json = args.out_json

     if args.in_json is not None:
-        print(f"loading existing metadata: {args.in_json}")
+        logger.info(f"loading existing metadata: {args.in_json}")
         metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
-        print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
+        logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
     else:
-        print("new metadata will be created / 新しいメタデータファイルが作成されます")
+        logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
         metadata = {}

-    print("merge caption texts to metadata json.")
+    logger.info("merge caption texts to metadata json.")
     for image_path in tqdm(image_paths):
         caption_path = image_path.with_suffix(args.caption_extension)
         caption = caption_path.read_text(encoding='utf-8').strip()
@@ -38,12 +42,12 @@ def main(args):
         metadata[image_key]['caption'] = caption
         if args.debug:
-            print(image_key, caption)
+            logger.info(f"{image_key} {caption}")

     # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
+    logger.info(f"writing metadata: {args.out_json}")
     Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -5,26 +5,30 @@ from typing import List
 from tqdm import tqdm

 import library.train_util as train_util
 import os
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)


 def main(args):
     assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"

     train_data_dir_path = Path(args.train_data_dir)
     image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     if args.in_json is None and Path(args.out_json).is_file():
         args.in_json = args.out_json

     if args.in_json is not None:
-        print(f"loading existing metadata: {args.in_json}")
+        logger.info(f"loading existing metadata: {args.in_json}")
         metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
-        print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
+        logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
     else:
-        print("new metadata will be created / 新しいメタデータファイルが作成されます")
+        logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
         metadata = {}

-    print("merge tags to metadata json.")
+    logger.info("merge tags to metadata json.")
     for image_path in tqdm(image_paths):
         tags_path = image_path.with_suffix(args.caption_extension)
         tags = tags_path.read_text(encoding='utf-8').strip()
@@ -38,13 +42,13 @@ def main(args):
         metadata[image_key]['tags'] = tags
         if args.debug:
-            print(image_key, tags)
+            logger.info(f"{image_key} {tags}")

     # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
+    logger.info(f"writing metadata: {args.out_json}")
     Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -13,6 +13,10 @@ from torchvision import transforms
 import library.model_util as model_util
 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -51,22 +55,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
 def main(args):
     # assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
     if args.bucket_reso_steps % 8 > 0:
-        print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
+        logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
     if args.bucket_reso_steps % 32 > 0:
-        print(
+        logger.warning(
             f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
         )

     train_data_dir_path = Path(args.train_data_dir)
     image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     if os.path.exists(args.in_json):
-        print(f"loading existing metadata: {args.in_json}")
+        logger.info(f"loading existing metadata: {args.in_json}")
         with open(args.in_json, "rt", encoding="utf-8") as f:
             metadata = json.load(f)
     else:
-        print(f"no metadata / メタデータファイルがありません: {args.in_json}")
+        logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
         return

     weight_dtype = torch.float32
@@ -89,7 +93,7 @@ def main(args):
     if not args.bucket_no_upscale:
         bucket_manager.make_buckets()
     else:
-        print(
+        logger.warning(
             "min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
         )
@@ -130,7 +134,7 @@ def main(args):
             if image.mode != "RGB":
                 image = image.convert("RGB")
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
             continue

         image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
@@ -183,15 +187,15 @@ def main(args):
     for i, reso in enumerate(bucket_manager.resos):
         count = bucket_counts.get(reso, 0)
         if count > 0:
-            print(f"bucket {i} {reso}: {count}")
+            logger.info(f"bucket {i} {reso}: {count}")
     img_ar_errors = np.array(img_ar_errors)
-    print(f"mean ar error: {np.mean(img_ar_errors)}")
+    logger.info(f"mean ar error: {np.mean(img_ar_errors)}")

     # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
+    logger.info(f"writing metadata: {args.out_json}")
     with open(args.out_json, "wt", encoding="utf-8") as f:
         json.dump(metadata, f, indent=2)
-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -11,6 +11,10 @@ from PIL import Image
 from tqdm import tqdm

 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 # from wd14 tagger
 IMAGE_SIZE = 448
@@ -58,7 +62,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
             image = preprocess_image(image)
             tensor = torch.tensor(image)
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
             return None

         return (tensor, img_path)
@@ -79,7 +83,7 @@ def main(args):
     # depreacatedの警告が出るけどなくなったらその時
     # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
     if not os.path.exists(args.model_dir) or args.force_download:
-        print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
+        logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
         files = FILES
         if args.onnx:
             files += FILES_ONNX
@@ -95,7 +99,7 @@ def main(args):
                 force_filename=file,
             )
     else:
-        print("using existing wd14 tagger model")
+        logger.info("using existing wd14 tagger model")

     # 画像を読み込む
     if args.onnx:
@@ -103,8 +107,8 @@ def main(args):
         import onnxruntime as ort

         onnx_path = f"{args.model_dir}/model.onnx"
-        print("Running wd14 tagger with onnx")
-        print(f"loading onnx model: {onnx_path}")
+        logger.info("Running wd14 tagger with onnx")
+        logger.info(f"loading onnx model: {onnx_path}")

         if not os.path.exists(onnx_path):
             raise Exception(
@@ -121,7 +125,7 @@ def main(args):
         if args.batch_size != batch_size and type(batch_size) != str:
             # some rebatch model may use 'N' as dynamic axes
-            print(
+            logger.warning(
                 f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
             )
             args.batch_size = batch_size
@@ -156,7 +160,7 @@ def main(args):
     train_data_dir_path = Path(args.train_data_dir)
     image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
-    print(f"found {len(image_paths)} images.")
+    logger.info(f"found {len(image_paths)} images.")

     tag_freq = {}
@@ -237,7 +241,10 @@ def main(args):
         with open(caption_file, "wt", encoding="utf-8") as f:
             f.write(tag_text + "\n")
             if args.debug:
-                print(f"\n{image_path}:\n  Character tags: {character_tag_text}\n  General tags: {general_tag_text}")
+                logger.info("")
+                logger.info(f"{image_path}:")
+                logger.info(f"\tCharacter tags: {character_tag_text}")
+                logger.info(f"\tGeneral tags: {general_tag_text}")

     # 読み込みの高速化のためにDataLoaderを使うオプション
     if args.max_data_loader_n_workers is not None:
@@ -269,7 +276,7 @@ def main(args):
                 image = image.convert("RGB")
             image = preprocess_image(image)
         except Exception as e:
-            print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
+            logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
             continue

         b_imgs.append((image_path, image))
@@ -284,11 +291,11 @@ def main(args):
     if args.frequency_tags:
         sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
-        print("\nTag frequencies:")
+        logger.info("Tag frequencies:")
         for tag, freq in sorted_tags:
-            print(f"{tag}: {freq}")
+            logger.info(f"{tag}: {freq}")

-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
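
Note: the tagger's debug output above is deliberately split into four
logger.info() calls so no message starts with a newline. A single-record
alternative, sketched here and not part of this commit, keeps the lines
together even if records from multiple workers interleave:

    # one record per image (sketch); stdlib handlers emit embedded
    # newlines verbatim, so the block stays contiguous per record
    logger.info(
        "%s:\n\tCharacter tags: %s\n\tGeneral tags: %s",
        image_path, character_tag_text, general_tag_text,
    )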

View File

@@ -104,6 +104,10 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
 from library.original_unet import FlashAttentionFunction
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 # scheduler:
 SCHEDULER_LINEAR_START = 0.00085
@@ -139,12 +143,12 @@ USE_CUTOUTS = False
 def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
     if mem_eff_attn:
-        print("Enable memory efficient attention for U-Net")
+        logger.info("Enable memory efficient attention for U-Net")

         # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
         unet.set_use_memory_efficient_attention(False, True)
     elif xformers:
-        print("Enable xformers for U-Net")
+        logger.info("Enable xformers for U-Net")
         try:
             import xformers.ops
         except ImportError:
@@ -152,7 +156,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
         unet.set_use_memory_efficient_attention(True, False)
     elif sdpa:
-        print("Enable SDPA for U-Net")
+        logger.info("Enable SDPA for U-Net")
         unet.set_use_memory_efficient_attention(False, False)
         unet.set_use_sdpa(True)
@@ -168,7 +172,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
 def replace_vae_attn_to_memory_efficient():
-    print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
+    logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
     flash_func = FlashAttentionFunction

     def forward_flash_attn(self, hidden_states, **kwargs):
@@ -224,7 +228,7 @@ def replace_vae_attn_to_memory_efficient():
 def replace_vae_attn_to_xformers():
-    print("VAE: Attention.forward has been replaced to xformers")
+    logger.info("VAE: Attention.forward has been replaced to xformers")
     import xformers.ops

     def forward_xformers(self, hidden_states, **kwargs):
@@ -280,7 +284,7 @@ def replace_vae_attn_to_xformers():
 def replace_vae_attn_to_sdpa():
-    print("VAE: Attention.forward has been replaced to sdpa")
+    logger.info("VAE: Attention.forward has been replaced to sdpa")

     def forward_sdpa(self, hidden_states, **kwargs):
         residual = hidden_states
@@ -684,7 +688,7 @@ class PipelineLike:
         do_classifier_free_guidance = guidance_scale > 1.0

         if not do_classifier_free_guidance and negative_scale is not None:
-            print(f"negative_scale is ignored if guidance scalle <= 1.0")
+            logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0")
             negative_scale = None

         # get unconditional embeddings for classifier free guidance
@@ -766,11 +770,11 @@ class PipelineLike:
             clip_text_input = prompt_tokens
             if clip_text_input.shape[1] > self.tokenizer.model_max_length:
                 # TODO 75文字を超えたら警告を出す
-                print("trim text input", clip_text_input.shape)
+                logger.info(f"trim text input {clip_text_input.shape}")
                 clip_text_input = torch.cat(
                     [clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1
                 )
-                print("trimmed", clip_text_input.shape)
+                logger.info(f"trimmed {clip_text_input.shape}")

             for i, clip_prompt in enumerate(clip_prompts):
                 if clip_prompt is not None:  # clip_promptがあれば上書きする
@@ -1699,7 +1703,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
             if word.strip() == "BREAK":
                 # pad until next multiple of tokenizer's max token length
                 pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
-                print(f"BREAK pad_len: {pad_len}")
+                logger.info(f"BREAK pad_len: {pad_len}")
                 for i in range(pad_len):
                     # v2のときEOSをつけるべきかどうかわからないぜ
                     # if i == 0:
@@ -1729,7 +1733,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
         tokens.append(text_token)
         weights.append(text_weight)
     if truncated:
-        print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
@@ -2041,7 +2045,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
         elif len(count_range) == 2:
             count_range = [int(count_range[0]), int(count_range[1])]
         else:
-            print(f"invalid count range: {count_range}")
+            logger.warning(f"invalid count range: {count_range}")
             count_range = [1, 1]
         if count_range[0] > count_range[1]:
             count_range = [count_range[1], count_range[0]]
@@ -2111,7 +2115,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
 # def load_clip_l14_336(dtype):
-#     print(f"loading CLIP: {CLIP_ID_L14_336}")
+#     logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
 #     text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
 #     return text_encoder
@@ -2158,9 +2162,9 @@ def main(args):
     # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"

     if args.v_parameterization and not args.v2:
-        print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
+        logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
     if args.v2 and args.clip_skip is not None:
-        print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
+        logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

     # モデルを読み込む
     if not os.path.isfile(args.ckpt):  # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
@@ -2170,10 +2174,10 @@ def main(args):
     use_stable_diffusion_format = os.path.isfile(args.ckpt)
     if use_stable_diffusion_format:
-        print("load StableDiffusion checkpoint")
+        logger.info("load StableDiffusion checkpoint")
         text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
     else:
-        print("load Diffusers pretrained models")
+        logger.info("load Diffusers pretrained models")
         loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
         text_encoder = loading_pipe.text_encoder
         vae = loading_pipe.vae
@@ -2196,21 +2200,21 @@ def main(args):
     # VAEを読み込む
     if args.vae is not None:
         vae = model_util.load_vae(args.vae, dtype)
-        print("additional VAE loaded")
+        logger.info("additional VAE loaded")

     # # 置換するCLIPを読み込む
     # if args.replace_clip_l14_336:
     #     text_encoder = load_clip_l14_336(dtype)
-    #     print(f"large clip {CLIP_ID_L14_336} is loaded")
+    #     logger.info(f"large clip {CLIP_ID_L14_336} is loaded")

     if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale:
-        print("prepare clip model")
+        logger.info("prepare clip model")
         clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype)
     else:
         clip_model = None

     if args.vgg16_guidance_scale > 0.0:
-        print("prepare resnet model")
+        logger.info("prepare resnet model")
         vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1)
     else:
         vgg16_model = None
@@ -2222,7 +2226,7 @@ def main(args):
         replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)

     # tokenizerを読み込む
-    print("loading tokenizer")
+    logger.info("loading tokenizer")
     if use_stable_diffusion_format:
         tokenizer = train_util.load_tokenizer(args)
@@ -2281,7 +2285,7 @@ def main(args):
             self.sampler_noises = noises

         def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
-            # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
+            # logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}")
             if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
                 noise = self.sampler_noises[self.sampler_noise_index]
                 if shape != noise.shape:
@@ -2290,7 +2294,7 @@ def main(args):
                 noise = None

             if noise == None:
-                print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
+                logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
                 noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)

             self.sampler_noise_index += 1
@@ -2321,7 +2325,7 @@ def main(args):
         # clip_sample=Trueにする
         if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-            print("set clip_sample to True")
+            logger.info("set clip_sample to True")
             scheduler.config.clip_sample = True

     # deviceを決定する
@@ -2378,7 +2382,7 @@ def main(args):
         network_merge = 0

     for i, network_module in enumerate(args.network_module):
-        print("import network module:", network_module)
+        logger.info(f"import network module: {network_module}")
         imported_module = importlib.import_module(network_module)

         network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -2396,7 +2400,7 @@ def main(args):
                 raise ValueError("No weight. Weight is required.")

             network_weight = args.network_weights[i]
-            print("load network weights from:", network_weight)
+            logger.info(f"load network weights from: {network_weight}")

             if model_util.is_safetensors(network_weight) and args.network_show_meta:
                 from safetensors.torch import safe_open
@@ -2404,7 +2408,7 @@ def main(args):
                 with safe_open(network_weight, framework="pt") as f:
                     metadata = f.metadata()
                 if metadata is not None:
-                    print(f"metadata for: {network_weight}: {metadata}")
+                    logger.info(f"metadata for: {network_weight}: {metadata}")

             network, weights_sd = imported_module.create_network_from_weights(
                 network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
@@ -2414,20 +2418,20 @@ def main(args):
             mergeable = network.is_mergeable()
             if network_merge and not mergeable:
-                print("network is not mergiable. ignore merge option.")
+                logger.warning("network is not mergiable. ignore merge option.")

             if not mergeable or i >= network_merge:
                 # not merging
                 network.apply_to(text_encoder, unet)
                 info = network.load_state_dict(weights_sd, False)  # network.load_weightsを使うようにするとよい
-                print(f"weights are loaded: {info}")
+                logger.info(f"weights are loaded: {info}")

                 if args.opt_channels_last:
                     network.to(memory_format=torch.channels_last)
                 network.to(dtype).to(device)

                 if network_pre_calc:
-                    print("backup original weights")
+                    logger.info("backup original weights")
                     network.backup_weights()

                 networks.append(network)
@@ -2441,7 +2445,7 @@ def main(args):
     # upscalerの指定があれば取得する
     upscaler = None
     if args.highres_fix_upscaler:
-        print("import upscaler module:", args.highres_fix_upscaler)
+        logger.info(f"import upscaler module {args.highres_fix_upscaler}")
         imported_module = importlib.import_module(args.highres_fix_upscaler)

         us_kwargs = {}
@@ -2450,7 +2454,7 @@ def main(args):
             key, value = net_arg.split("=")
             us_kwargs[key] = value

-        print("create upscaler")
+        logger.info("create upscaler")
         upscaler = imported_module.create_upscaler(**us_kwargs)
         upscaler.to(dtype).to(device)
@@ -2467,7 +2471,7 @@ def main(args):
             control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))

     if args.opt_channels_last:
-        print(f"set optimizing: channels last")
+        logger.info(f"set optimizing: channels last")
         text_encoder.to(memory_format=torch.channels_last)
         vae.to(memory_format=torch.channels_last)
         unet.to(memory_format=torch.channels_last)
@@ -2499,7 +2503,7 @@ def main(args):
         args.vgg16_guidance_layer,
     )
     pipe.set_control_nets(control_nets)
-    print("pipeline is ready.")
+    logger.info("pipeline is ready.")

     if args.diffusers_xformers:
         pipe.enable_xformers_memory_efficient_attention()
@@ -2542,7 +2546,7 @@ def main(args):
             ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"

             token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-            print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
+            logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
             assert (
                 min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
             ), f"token ids is not ordered"
@@ -2601,7 +2605,7 @@ def main(args):
             ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"

             token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-            print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
+            logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")

             # if num_vectors_per_token > 1:
             pipe.add_token_replacement(token_ids[0], token_ids)
@@ -2626,7 +2630,7 @@ def main(args):
     # promptを取得する
     if args.from_file is not None:
-        print(f"reading prompts from {args.from_file}")
+        logger.info(f"reading prompts from {args.from_file}")
         with open(args.from_file, "r", encoding="utf-8") as f:
             prompt_list = f.read().splitlines()
             prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
@@ -2655,7 +2659,7 @@ def main(args):
         for p in paths:
             image = Image.open(p)
             if image.mode != "RGB":
-                print(f"convert image to RGB from {image.mode}: {p}")
+                logger.info(f"convert image to RGB from {image.mode}: {p}")
                 image = image.convert("RGB")
             images.append(image)
@@ -2671,24 +2675,24 @@ def main(args):
         return resized

     if args.image_path is not None:
-        print(f"load image for img2img: {args.image_path}")
+        logger.info(f"load image for img2img: {args.image_path}")
         init_images = load_images(args.image_path)
         assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
-        print(f"loaded {len(init_images)} images for img2img")
+        logger.info(f"loaded {len(init_images)} images for img2img")
     else:
         init_images = None

     if args.mask_path is not None:
-        print(f"load mask for inpainting: {args.mask_path}")
+        logger.info(f"load mask for inpainting: {args.mask_path}")
         mask_images = load_images(args.mask_path)
         assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
-        print(f"loaded {len(mask_images)} mask images for inpainting")
+        logger.info(f"loaded {len(mask_images)} mask images for inpainting")
     else:
         mask_images = None

     # promptがないとき、画像のPngInfoから取得する
     if init_images is not None and len(prompt_list) == 0 and not args.interactive:
-        print("get prompts from images' meta data")
+        logger.info("get prompts from images' meta data")
         for img in init_images:
             if "prompt" in img.text:
                 prompt = img.text["prompt"]
@@ -2717,17 +2721,17 @@ def main(args):
             h = int(h * args.highres_fix_scale + 0.5)

         if init_images is not None:
-            print(f"resize img2img source images to {w}*{h}")
+            logger.info(f"resize img2img source images to {w}*{h}")
             init_images = resize_images(init_images, (w, h))
         if mask_images is not None:
-            print(f"resize img2img mask images to {w}*{h}")
+            logger.info(f"resize img2img mask images to {w}*{h}")
             mask_images = resize_images(mask_images, (w, h))

     regional_network = False
     if networks and mask_images:
         # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
         regional_network = True
-        print("use mask as region")
+        logger.info("use mask as region")

         size = None
         for i, network in enumerate(networks):
@@ -2752,14 +2756,14 @@ def main(args):
     prev_image = None  # for VGG16 guided
     if args.guide_image_path is not None:
-        print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+        logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
         guide_images = []
         for p in args.guide_image_path:
             guide_images.extend(load_images(p))

-        print(f"loaded {len(guide_images)} guide images for guidance")
+        logger.info(f"loaded {len(guide_images)} guide images for guidance")
         if len(guide_images) == 0:
-            print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
+            logger.info(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
             guide_images = None
     else:
         guide_images = None
@@ -2785,7 +2789,7 @@ def main(args):
     max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples

     for gen_iter in range(args.n_iter):
-        print(f"iteration {gen_iter+1}/{args.n_iter}")
+        logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
         iter_seed = random.randint(0, 0x7FFFFFFF)

         # shuffle prompt list
@@ -2801,7 +2805,7 @@ def main(args):
             # 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
             is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling

-            print("process 1st stage")
+            logger.info("process 1st stage")
             batch_1st = []
             for _, base, ext in batch:
                 width_1st = int(ext.width * args.highres_fix_scale + 0.5)
@@ -2827,7 +2831,7 @@ def main(args):
             images_1st = process_batch(batch_1st, True, True)

             # 2nd stageのバッチを作成して以下処理する
-            print("process 2nd stage")
+            logger.info("process 2nd stage")
             width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height

             if upscaler:
@@ -2978,7 +2982,7 @@ def main(args):
                     n.restore_weights()
                 for n in networks:
                     n.pre_calculation()
-                print("pre-calculation... done")
+                logger.info("pre-calculation... done")

             images = pipe(
                 prompts,
@@ -3045,7 +3049,7 @@ def main(args):
                 cv2.waitKey()
                 cv2.destroyAllWindows()
             except ImportError:
-                print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
+                logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")

         return images
@@ -3058,7 +3062,8 @@ def main(args):
         # interactive
         valid = False
         while not valid:
-            print("\nType prompt:")
+            logger.info("")
+            logger.info("Type prompt:")
             try:
                 raw_prompt = input()
             except EOFError:
@@ -3101,38 +3106,38 @@ def main(args):
            prompt_args = raw_prompt.strip().split(" --")
            prompt = prompt_args[0]
-           print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
+           logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")

            for parg in prompt_args[1:]:
                try:
                    m = re.match(r"w (\d+)", parg, re.IGNORECASE)
                    if m:
                        width = int(m.group(1))
-                       print(f"width: {width}")
+                       logger.info(f"width: {width}")
                        continue

                    m = re.match(r"h (\d+)", parg, re.IGNORECASE)
                    if m:
                        height = int(m.group(1))
-                       print(f"height: {height}")
+                       logger.info(f"height: {height}")
                        continue

                    m = re.match(r"s (\d+)", parg, re.IGNORECASE)
                    if m:  # steps
                        steps = max(1, min(1000, int(m.group(1))))
-                       print(f"steps: {steps}")
+                       logger.info(f"steps: {steps}")
                        continue

                    m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
                    if m:  # seed
                        seeds = [int(d) for d in m.group(1).split(",")]
-                       print(f"seeds: {seeds}")
+                       logger.info(f"seeds: {seeds}")
                        continue

                    m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # scale
                        scale = float(m.group(1))
-                       print(f"scale: {scale}")
+                       logger.info(f"scale: {scale}")
                        continue

                    m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
@@ -3141,25 +3146,25 @@ def main(args):
                            negative_scale = None
                        else:
                            negative_scale = float(m.group(1))
-                       print(f"negative scale: {negative_scale}")
+                       logger.info(f"negative scale: {negative_scale}")
                        continue

                    m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # strength
                        strength = float(m.group(1))
-                       print(f"strength: {strength}")
+                       logger.info(f"strength: {strength}")
                        continue

                    m = re.match(r"n (.+)", parg, re.IGNORECASE)
                    if m:  # negative prompt
                        negative_prompt = m.group(1)
-                       print(f"negative prompt: {negative_prompt}")
+                       logger.info(f"negative prompt: {negative_prompt}")
                        continue

                    m = re.match(r"c (.+)", parg, re.IGNORECASE)
                    if m:  # clip prompt
                        clip_prompt = m.group(1)
-                       print(f"clip prompt: {clip_prompt}")
+                       logger.info(f"clip prompt: {clip_prompt}")
                        continue

                    m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -3167,47 +3172,47 @@ def main(args):
                        network_muls = [float(v) for v in m.group(1).split(",")]
                        while len(network_muls) < len(networks):
network_muls.append(network_muls[-1]) network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}") logger.info(f"network mul: {network_muls}")
continue continue
# Deep Shrink # Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1 if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1)) ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}") logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1 if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1)) ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}") logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2 if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1)) ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}") logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2 if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1)) ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}") logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio if m: # deep shrink ratio
ds_ratio = float(m.group(1)) ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}") logger.info(f"deep shrink ratio: {ds_ratio}")
continue continue
except ValueError as ex: except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}") logger.info(f"Exception in parsing / 解析エラー: {parg}")
print(ex) logger.info(ex)
# override Deep Shrink # override Deep Shrink
if ds_depth_1 is not None: if ds_depth_1 is not None:
@@ -3225,7 +3230,7 @@ def main(args):
if len(predefined_seeds) > 0: if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0) seed = predefined_seeds.pop(0)
else: else:
print("predefined seeds are exhausted") logger.info("predefined seeds are exhausted")
seed = None seed = None
elif args.iter_same_seed: elif args.iter_same_seed:
seed = iter_seed seed = iter_seed
@@ -3235,7 +3240,7 @@ def main(args):
if seed is None: if seed is None:
seed = random.randint(0, 0x7FFFFFFF) seed = random.randint(0, 0x7FFFFFFF)
if args.interactive: if args.interactive:
print(f"seed: {seed}") logger.info(f"seed: {seed}")
# prepare init image, guide image and mask # prepare init image, guide image and mask
init_image = mask_image = guide_image = None init_image = mask_image = guide_image = None
@@ -3251,7 +3256,7 @@ def main(args):
width = width - width % 32 width = width - width % 32
height = height - height % 32 height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]: if width != init_image.size[0] or height != init_image.size[1]:
print( logger.info(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
) )
@@ -3267,9 +3272,9 @@ def main(args):
guide_image = guide_images[global_step % len(guide_images)] guide_image = guide_images[global_step % len(guide_images)]
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0: elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
if prev_image is None: if prev_image is None:
print("Generate 1st image without guide image.") logger.info("Generate 1st image without guide image.")
else: else:
print("Use previous image as guide image.") logger.info("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
if regional_network: if regional_network:
@@ -3311,7 +3316,7 @@ def main(args):
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
print("done!") logger.info("done!")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
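The interactive loop above recognizes ` --key value` options by splitting the raw prompt on ` --` and testing each fragment against a regex table. A minimal self-contained sketch of the same pattern (the function and option names here are illustrative, not the script's own API):

```python
import re

def parse_prompt_options(raw_prompt: str):
    # split "a cat --w 512 --s 30" into the prompt and its typed options
    parts = raw_prompt.strip().split(" --")
    prompt, options = parts[0], {}
    for parg in parts[1:]:
        if m := re.match(r"w (\d+)", parg, re.IGNORECASE):
            options["width"] = int(m.group(1))
        elif m := re.match(r"s (\d+)", parg, re.IGNORECASE):
            options["steps"] = max(1, min(1000, int(m.group(1))))  # clamp as above
        elif m := re.match(r"l ([\d\.]+)", parg, re.IGNORECASE):
            options["scale"] = float(m.group(1))
    return prompt, options

print(parse_prompt_options("a cat --w 512 --s 30"))  # ('a cat', {'width': 512, 'steps': 30})
```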
View File
@@ -40,7 +40,10 @@ from .train_util import (
     ControlNetDataset,
     DatasetGroup,
 )
+from .utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 def add_config_arguments(parser: argparse.ArgumentParser):
     parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
@@ -345,7 +348,7 @@ class ConfigSanitizer:
             return self.user_config_validator(user_config)
         except MultipleInvalid:
             # TODO: エラー発生時のメッセージをわかりやすくする
-            print("Invalid user config / ユーザ設定の形式が正しくないようです")
+            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
             raise

     # NOTE: In nature, argument parser result is not needed to be sanitize
@@ -355,7 +358,7 @@ class ConfigSanitizer:
             return self.argparse_config_validator(argparse_namespace)
         except MultipleInvalid:
             # XXX: this should be a bug
-            print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
+            logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
             raise

     # NOTE: value would be overwritten by latter dict if there is already the same key
@@ -538,13 +541,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         "  ",
     )
-    print(info)
+    logger.info(f'{info}')

     # make buckets first because it determines the length of dataset
     # and set the same seed for all datasets
     seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
     for i, dataset in enumerate(datasets):
-        print(f"[Dataset {i}]")
+        logger.info(f"[Dataset {i}]")
         dataset.make_buckets()
         dataset.set_seed(seed)
@@ -557,7 +560,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
     try:
         n_repeats = int(tokens[0])
     except ValueError as e:
-        print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
+        logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
         return 0, ""
     caption_by_folder = "_".join(tokens[1:])
     return n_repeats, caption_by_folder
@@ -629,17 +632,13 @@ def load_user_config(file: str) -> dict:
             with open(file, "r") as f:
                 config = json.load(f)
         except Exception:
-            print(
-                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
-            )
+            logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
            raise
     elif file.name.lower().endswith(".toml"):
         try:
             config = toml.load(file)
         except Exception:
-            print(
-                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
-            )
+            logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
            raise
     else:
         raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -665,23 +664,26 @@ if __name__ == "__main__":
     argparse_namespace = parser.parse_args(remain)
     train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

-    print("[argparse_namespace]")
-    print(vars(argparse_namespace))
+    logger.info("[argparse_namespace]")
+    logger.info(f'{vars(argparse_namespace)}')

     user_config = load_user_config(config_args.dataset_config)

-    print("\n[user_config]")
-    print(user_config)
+    logger.info("")
+    logger.info("[user_config]")
+    logger.info(f'{user_config}')

     sanitizer = ConfigSanitizer(
         config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
     )
     sanitized_user_config = sanitizer.sanitize_user_config(user_config)

-    print("\n[sanitized_user_config]")
-    print(sanitized_user_config)
+    logger.info("")
+    logger.info("[sanitized_user_config]")
+    logger.info(f'{sanitized_user_config}')

     blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

-    print("\n[blueprint]")
-    print(blueprint)
+    logger.info("")
+    logger.info("[blueprint]")
+    logger.info(f'{blueprint}')
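Every module touched by this commit repeats the same logging bootstrap before its first log call. In isolation the pattern looks like this (a minimal sketch; the importing module is arbitrary):

```python
# per-module bootstrap introduced by this commit
from library.utils import setup_logging

setup_logging()  # safe to call from every module: it returns early once configured

import logging

logger = logging.getLogger(__name__)
logger.info("logger ready")
```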
View File
@@ -3,7 +3,10 @@ import argparse
 import random
 import re
 from typing import List, Optional, Union
+from .utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 def prepare_scheduler_for_custom_training(noise_scheduler, device):
     if hasattr(noise_scheduler, "all_snr"):
@@ -21,7 +24,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
 def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
     # fix beta: zero terminal SNR
-    print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
+    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

     def enforce_zero_terminal_snr(betas):
         # Convert betas to alphas_bar_sqrt
@@ -49,8 +52,8 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)

-    # print("original:", noise_scheduler.betas)
-    # print("fixed:", betas)
+    # logger.info(f"original: {noise_scheduler.betas}")
+    # logger.info(f"fixed: {betas}")
     noise_scheduler.betas = betas
     noise_scheduler.alphas = alphas
@@ -79,13 +82,13 @@ def get_snr_scale(timesteps, noise_scheduler):
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
     scale = snr_t / (snr_t + 1)
     # # show debug info
-    # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
+    # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
     return scale

 def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
     scale = get_snr_scale(timesteps, noise_scheduler)
-    # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
+    # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
     loss = loss + loss / scale * v_pred_like_loss
     return loss
@@ -268,7 +271,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
         tokens.append(text_token)
         weights.append(text_weight)
     if truncated:
-        print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
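The `get_snr_scale` hunk above clamps the per-timestep SNR and maps it to `snr / (snr + 1)`. That arithmetic in isolation (assuming `snr_t` has already been gathered for the batch timesteps):

```python
import torch

def snr_scale(snr_t: torch.Tensor) -> torch.Tensor:
    # at timestep 0 the SNR is infinite, so clamp to 1000 as the hunk above does
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
    return snr_t / (snr_t + 1)

print(snr_scale(torch.tensor([0.5, 10.0, 1e9])))  # tensor([0.3333, 0.9091, 0.9990])
```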
View File
@@ -4,7 +4,10 @@ from pathlib import Path
 import argparse
 import os
 from library.utils import fire_in_thread
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
     api = HfApi(
@@ -33,9 +36,9 @@ def upload(
     try:
         api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
     except Exception as e:  # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
-        print("===========================================")
-        print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
-        print("===========================================")
+        logger.error("===========================================")
+        logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
+        logger.error("===========================================")

     is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
@@ -56,9 +59,9 @@ def upload(
             path_in_repo=path_in_repo,
         )
     except Exception as e:  # RuntimeErrorを確認済みだが他にあると困るので
-        print("===========================================")
-        print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
-        print("===========================================")
+        logger.error("===========================================")
+        logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
+        logger.error("===========================================")

     if args.async_upload and not force_sync_upload:
         fire_in_thread(uploader)
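`fire_in_thread` (its one-line definition appears in the `library/utils.py` diff later on this page) simply runs the callable on a fresh `threading.Thread`, which is what makes the `async_upload` path above return immediately. A hedged sketch:

```python
import threading

def fire_in_thread(f, *args, **kwargs):
    # same shape as library/utils.py: run f in the background, don't wait for it
    threading.Thread(target=f, args=args, kwargs=kwargs).start()

def uploader():  # stand-in for the HfApi upload closure above
    print("uploading...")

fire_in_thread(uploader)  # async path; the sync path just calls uploader()
```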
View File
@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()

 class DummyDataParallel(torch.nn.Module):  # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
     def __new__(cls, module, device_ids=None, output_device=None, dim=0):  # pylint: disable=unused-argument
         if isinstance(device_ids, list) and len(device_ids) > 1:
-            print("IPEX backend doesn't support DataParallel on multiple XPU devices")
+            logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
         return module.to("xpu")

 def return_null_context(*args, **kwargs):  # pylint: disable=unused-argument
View File
@@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging

 try:
     from diffusers.utils import PIL_INTERPOLATION
 except ImportError:
@@ -626,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
         if height % 8 != 0 or width % 8 != 0:
-            print(height, width)
+            logger.info(f'{height} {width}')
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
         if (callback_steps is None) or (
View File
@@ -13,6 +13,10 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
 from safetensors.torch import load_file, save_file
 from library.original_unet import UNet2DConditionModel
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 # DiffUsers版StableDiffusionのモデルパラメータ
 NUM_TRAIN_TIMESTEPS = 1000
@@ -944,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict):
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
-                # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
+                # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
                 new_state_dict[k] = reshape_weight_for_sd(v)

     return new_state_dict
@@ -1002,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
     unet = UNet2DConditionModel(**unet_config).to(device)
     info = unet.load_state_dict(converted_unet_checkpoint)
-    print("loading u-net:", info)
+    logger.info(f"loading u-net: {info}")

     # Convert the VAE model.
     vae_config = create_vae_diffusers_config()
@@ -1010,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
     vae = AutoencoderKL(**vae_config).to(device)
     info = vae.load_state_dict(converted_vae_checkpoint)
-    print("loading vae:", info)
+    logger.info(f"loading vae: {info}")

     # convert text_model
     if v2:
@@ -1044,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
         # logging.set_verbosity_error()  # don't show annoying warning
         # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
         # logging.set_verbosity_warning()
-        # print(f"config: {text_model.config}")
+        # logger.info(f"config: {text_model.config}")
         cfg = CLIPTextConfig(
             vocab_size=49408,
             hidden_size=768,
@@ -1067,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
         )
         text_model = CLIPTextModel._from_config(cfg)
         info = text_model.load_state_dict(converted_text_encoder_checkpoint)
-    print("loading text encoder:", info)
+    logger.info(f"loading text encoder: {info}")

     return text_model, vae, unet
@@ -1142,7 +1146,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
     # 最後の層などを捏造するか
     if make_dummy_weights:
-        print("make dummy weights for resblock.23, text_projection and logit scale.")
+        logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
         keys = list(new_sd.keys())
         for key in keys:
             if key.startswith("transformer.resblocks.22."):
@@ -1261,14 +1265,14 @@ VAE_PREFIX = "first_stage_model."

 def load_vae(vae_id, dtype):
-    print(f"load VAE: {vae_id}")
+    logger.info(f"load VAE: {vae_id}")
     if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
         # Diffusers local/remote
         try:
             vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
         except EnvironmentError as e:
-            print(f"exception occurs in loading vae: {e}")
-            print("retry with subfolder='vae'")
+            logger.error(f"exception occurs in loading vae: {e}")
+            logger.error("retry with subfolder='vae'")
             vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
         return vae
@@ -1340,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)

 if __name__ == "__main__":
     resos = make_bucket_resolutions((512, 768))
-    print(len(resos))
-    print(resos)
+    logger.info(f"{len(resos)}")
+    logger.info(f"{resos}")
     aspect_ratios = [w / h for w, h in resos]
-    print(aspect_ratios)
+    logger.info(f"{aspect_ratios}")

     ars = set()
     for ar in aspect_ratios:
         if ar in ars:
-            print("error! duplicate ar:", ar)
+            logger.error(f"error! duplicate ar: {ar}")
         ars.add(ar)
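Several hunks in this file log the return value of `load_state_dict`; in PyTorch that value is a named tuple of `missing_keys` and `unexpected_keys`, which is why interpolating it into an f-string yields a readable mismatch report. A minimal illustration:

```python
import torch.nn as nn

model = nn.Linear(4, 4)
state = {"weight": model.weight.detach().clone()}  # deliberately omit "bias"
info = model.load_state_dict(state, strict=False)
print(info)  # _IncompatibleKeys(missing_keys=['bias'], unexpected_keys=[])
```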
View File
@@ -113,6 +113,10 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 from einops import rearrange
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
 TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
@@ -1380,7 +1384,7 @@ class UNet2DConditionModel(nn.Module):
     ):
         super().__init__()
         assert sample_size is not None, "sample_size must be specified"
-        print(
+        logger.info(
             f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
         )
@@ -1514,7 +1518,7 @@ class UNet2DConditionModel(nn.Module):
     def set_gradient_checkpointing(self, value=False):
         modules = self.down_blocks + [self.mid_block] + self.up_blocks
         for module in modules:
-            print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
+            logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
             module.gradient_checkpointing = value

     # endregion
@@ -1709,14 +1713,14 @@ class InferUNet2DConditionModel:
     def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
         if ds_depth_1 is None:
-            print("Deep Shrink is disabled.")
+            logger.info("Deep Shrink is disabled.")
             self.ds_depth_1 = None
             self.ds_timesteps_1 = None
             self.ds_depth_2 = None
             self.ds_timesteps_2 = None
             self.ds_ratio = None
         else:
-            print(
+            logger.info(
                 f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
             )
             self.ds_depth_1 = ds_depth_1
View File
@@ -5,6 +5,10 @@ from io import BytesIO
 import os
 from typing import List, Optional, Tuple, Union

 import safetensors
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 r"""
 # Metadata Example
@@ -231,7 +235,7 @@ def build_metadata(
     # # assert all values are filled
     # assert all([v is not None for v in metadata.values()]), metadata
     if not all([v is not None for v in metadata.values()]):
-        print(f"Internal error: some metadata values are None: {metadata}")
+        logger.error(f"Internal error: some metadata values are None: {metadata}")

     return metadata
View File
@@ -7,7 +7,10 @@ from typing import List
 from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
 from library import model_util
 from library import sdxl_original_unet
+from .utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 VAE_SCALE_FACTOR = 0.13025
 MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
@@ -131,7 +134,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     # temporary workaround for text_projection.weight.weight for Playground-v2
     if "text_projection.weight.weight" in new_sd:
-        print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
+        logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
         new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
         del new_sd["text_projection.weight.weight"]
@@ -186,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
         checkpoint = None

     # U-Net
-    print("building U-Net")
+    logger.info("building U-Net")
     with init_empty_weights():
         unet = sdxl_original_unet.SdxlUNet2DConditionModel()

-    print("loading U-Net from checkpoint")
+    logger.info("loading U-Net from checkpoint")
     unet_sd = {}
     for k in list(state_dict.keys()):
         if k.startswith("model.diffusion_model."):
             unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
     info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
-    print("U-Net: ", info)
+    logger.info(f"U-Net: {info}")

     # Text Encoders
-    print("building text encoders")
+    logger.info("building text encoders")

     # Text Encoder 1 is same to Stability AI's SDXL
     text_model1_cfg = CLIPTextConfig(
@@ -252,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
     with init_empty_weights():
         text_model2 = CLIPTextModelWithProjection(text_model2_cfg)

-    print("loading text encoders from checkpoint")
+    logger.info("loading text encoders from checkpoint")
     te1_sd = {}
     te2_sd = {}
     for k in list(state_dict.keys()):
@@ -266,22 +269,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
         te1_sd.pop("text_model.embeddings.position_ids")
     info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location)  # remain fp32
-    print("text encoder 1:", info1)
+    logger.info(f"text encoder 1: {info1}")

     converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
     info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location)  # remain fp32
-    print("text encoder 2:", info2)
+    logger.info(f"text encoder 2: {info2}")

     # prepare vae
-    print("building VAE")
+    logger.info("building VAE")
     vae_config = model_util.create_vae_diffusers_config()
     with init_empty_weights():
         vae = AutoencoderKL(**vae_config)

-    print("loading VAE from checkpoint")
+    logger.info("loading VAE from checkpoint")
     converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
     info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
-    print("VAE:", info)
+    logger.info(f"VAE: {info}")

     ckpt_info = (epoch, global_step) if epoch is not None else None
     return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
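The loader above constructs each SDXL module under `accelerate`'s `init_empty_weights()` so that building the graph allocates no real tensors; weights are materialized afterwards from the checkpoint. The idea in miniature (an `nn.Linear` stands in for the SDXL modules):

```python
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    model = nn.Linear(1024, 1024)  # parameters land on the "meta" device: no memory used

print(model.weight.device)  # meta
# Real weights must then be loaded/assigned explicitly, as _load_state_dict_on_device
# does above; running the module before that would fail.
```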
View File
@@ -30,7 +30,10 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import functional as F
 from einops import rearrange
+from .utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 IN_CHANNELS: int = 4
 OUT_CHANNELS: int = 4
@@ -332,7 +335,7 @@ class ResnetBlock2D(nn.Module):
     def forward(self, x, emb):
         if self.training and self.gradient_checkpointing:
-            # print("ResnetBlock2D: gradient_checkpointing")
+            # logger.info("ResnetBlock2D: gradient_checkpointing")

             def create_custom_forward(func):
                 def custom_forward(*inputs):
@@ -366,7 +369,7 @@ class Downsample2D(nn.Module):
     def forward(self, hidden_states):
         if self.training and self.gradient_checkpointing:
-            # print("Downsample2D: gradient_checkpointing")
+            # logger.info("Downsample2D: gradient_checkpointing")

             def create_custom_forward(func):
                 def custom_forward(*inputs):
@@ -653,7 +656,7 @@ class BasicTransformerBlock(nn.Module):
     def forward(self, hidden_states, context=None, timestep=None):
         if self.training and self.gradient_checkpointing:
-            # print("BasicTransformerBlock: checkpointing")
+            # logger.info("BasicTransformerBlock: checkpointing")

             def create_custom_forward(func):
                 def custom_forward(*inputs):
@@ -796,7 +799,7 @@ class Upsample2D(nn.Module):
     def forward(self, hidden_states, output_size=None):
         if self.training and self.gradient_checkpointing:
-            # print("Upsample2D: gradient_checkpointing")
+            # logger.info("Upsample2D: gradient_checkpointing")

             def create_custom_forward(func):
                 def custom_forward(*inputs):
@@ -1046,7 +1049,7 @@ class SdxlUNet2DConditionModel(nn.Module):
         for block in blocks:
             for module in block:
                 if hasattr(module, "set_use_memory_efficient_attention"):
-                    # print(module.__class__.__name__)
+                    # logger.info(module.__class__.__name__)
                     module.set_use_memory_efficient_attention(xformers, mem_eff)

     def set_use_sdpa(self, sdpa: bool) -> None:
@@ -1061,7 +1064,7 @@ class SdxlUNet2DConditionModel(nn.Module):
         for block in blocks:
             for module in block.modules():
                 if hasattr(module, "gradient_checkpointing"):
-                    # print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
+                    # logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
                     module.gradient_checkpointing = value

     # endregion
@@ -1083,7 +1086,7 @@ class SdxlUNet2DConditionModel(nn.Module):
         def call_module(module, h, emb, context):
             x = h
             for layer in module:
-                # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
+                # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
                 if isinstance(layer, ResnetBlock2D):
                     x = layer(x, emb)
                 elif isinstance(layer, Transformer2DModel):
@@ -1135,14 +1138,14 @@ class InferSdxlUNet2DConditionModel:
     def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
         if ds_depth_1 is None:
-            print("Deep Shrink is disabled.")
+            logger.info("Deep Shrink is disabled.")
             self.ds_depth_1 = None
             self.ds_timesteps_1 = None
             self.ds_depth_2 = None
             self.ds_timesteps_2 = None
             self.ds_ratio = None
         else:
-            print(
+            logger.info(
                 f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
             )
             self.ds_depth_1 = ds_depth_1
@@ -1229,7 +1232,7 @@ class InferSdxlUNet2DConditionModel:

 if __name__ == "__main__":
     import time

-    print("create unet")
+    logger.info("create unet")
     unet = SdxlUNet2DConditionModel()
     unet.to("cuda")
@@ -1238,7 +1241,7 @@ if __name__ == "__main__":
     unet.train()

     # 使用メモリ量確認用の疑似学習ループ
-    print("preparing optimizer")
+    logger.info("preparing optimizer")

     # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
@@ -1253,12 +1256,12 @@ if __name__ == "__main__":
     scaler = torch.cuda.amp.GradScaler(enabled=True)

-    print("start training")
+    logger.info("start training")
     steps = 10
     batch_size = 1

     for step in range(steps):
-        print(f"step {step}")
+        logger.info(f"step {step}")
         if step == 1:
             time_start = time.perf_counter()
@@ -1278,4 +1281,4 @@ if __name__ == "__main__":
         optimizer.zero_grad(set_to_none=True)

     time_end = time.perf_counter()
-    print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
+    logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
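The pseudo-training loop above starts its timer at step 1, so the first step's one-off warm-up cost is excluded from the reported average. The same pattern in isolation:

```python
import time

def do_step() -> None:
    time.sleep(0.01)  # stand-in for one optimizer step

steps = 10
time_start = None
for step in range(steps):
    if step == 1:  # skip step 0 so warm-up does not skew the figure
        time_start = time.perf_counter()
    do_step()
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start:.3f} [sec] for last {steps - 1} steps")
```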
View File
@@ -9,6 +9,10 @@ from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+from .utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
 TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -21,7 +25,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
     model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
-            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+            logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

             (
                 load_stable_diffusion_format,
@@ -62,7 +66,7 @@ def _load_target_model(
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
     if load_stable_diffusion_format:
-        print(f"load StableDiffusion checkpoint: {name_or_path}")
+        logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
         (
             text_encoder1,
             text_encoder2,
@@ -76,7 +80,7 @@ def _load_target_model(
         from diffusers import StableDiffusionXLPipeline

         variant = "fp16" if weight_dtype == torch.float16 else None
-        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
+        logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
         try:
             try:
                 pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -84,12 +88,12 @@ def _load_target_model(
                 )
             except EnvironmentError as ex:
                 if variant is not None:
-                    print("try to load fp32 model")
+                    logger.info("try to load fp32 model")
                     pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
                 else:
                     raise ex
         except EnvironmentError as ex:
-            print(
+            logger.error(
                 f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
             )
             raise ex
@@ -112,7 +116,7 @@ def _load_target_model(
             with init_empty_weights():
                 unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
             sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
-            print("U-Net converted to original U-Net")
+            logger.info("U-Net converted to original U-Net")

     logit_scale = None
     ckpt_info = None
@@ -120,13 +124,13 @@ def _load_target_model(
     # VAEを読み込む
     if vae_path is not None:
         vae = model_util.load_vae(vae_path, weight_dtype)
-        print("additional VAE loaded")
+        logger.info("additional VAE loaded")

     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

 def load_tokenizers(args: argparse.Namespace):
-    print("prepare tokenizers")
+    logger.info("prepare tokenizers")

     original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
     tokeniers = []
@@ -135,14 +139,14 @@ def load_tokenizers(args: argparse.Namespace):
         if args.tokenizer_cache_dir:
             local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
             if os.path.exists(local_tokenizer_path):
-                print(f"load tokenizer from cache: {local_tokenizer_path}")
+                logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
                 tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)

         if tokenizer is None:
             tokenizer = CLIPTokenizer.from_pretrained(original_path)

         if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
-            print(f"save Tokenizer to cache: {local_tokenizer_path}")
+            logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
             tokenizer.save_pretrained(local_tokenizer_path)

         if i == 1:
@@ -151,7 +155,7 @@ def load_tokenizers(args: argparse.Namespace):
         tokeniers.append(tokenizer)

     if hasattr(args, "max_token_length") and args.max_token_length is not None:
-        print(f"update token length: {args.max_token_length}")
+        logger.info(f"update token length: {args.max_token_length}")

     return tokeniers
@@ -332,23 +336,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
 def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
     assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
     if args.v_parameterization:
-        print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
+        logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")

     if args.clip_skip is not None:
-        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
+        logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

     # if args.multires_noise_iterations:
-    #     print(
+    #     logger.info(
     #         f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
     #     )
     # else:
     #     if args.noise_offset is None:
     #         args.noise_offset = DEFAULT_NOISE_OFFSET
     #     elif args.noise_offset != DEFAULT_NOISE_OFFSET:
-    #         print(
+    #         logger.info(
     #             f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
     #         )
-    #     print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
+    #     logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")

     assert (
         not hasattr(args, "weighted_captions") or not args.weighted_captions
@@ -357,7 +361,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
     if supportTextEncoderCaching:
         if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
             args.cache_text_encoder_outputs = True
-            print(
+            logger.warning(
                 "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
                 + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
             )
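The tokenizer-cache logic in `load_tokenizers` above boils down to: use the local copy if present, otherwise download and save it for next time. A hedged standalone sketch (the helper name is illustrative):

```python
import os
from transformers import CLIPTokenizer

def load_cached_tokenizer(original_path: str, cache_dir: str) -> CLIPTokenizer:
    local_path = os.path.join(cache_dir, original_path.replace("/", "_"))
    if os.path.exists(local_path):
        return CLIPTokenizer.from_pretrained(local_path)  # offline-friendly cache hit
    tokenizer = CLIPTokenizer.from_pretrained(original_path)
    tokenizer.save_pretrained(local_path)  # prime the cache for the next run
    return tokenizer
```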
View File
@@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def slice_h(x, num_slices): def slice_h(x, num_slices):
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
@@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
# sliced_tensor = torch.chunk(x, num_div, dim=1) # sliced_tensor = torch.chunk(x, num_div, dim=1)
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0) # sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0) # sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# normed_tensor = [] # normed_tensor = []
# for i in range(num_div): # for i in range(num_div):
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
@@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module):
self.num_slices = num_slices self.num_slices = num_slices
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
# print(f"initial divisor: {div}") # logger.info(f"initial divisor: {div}")
if div >= 2: if div >= 2:
div = int(div) div = int(div)
for resnet in self.mid_block.resnets: for resnet in self.mid_block.resnets:
@@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module):
for i, down_block in enumerate(self.down_blocks[::-1]): for i, down_block in enumerate(self.down_blocks[::-1]):
if div >= 2: if div >= 2:
div = int(div) div = int(div)
# print(f"down block: {i} divisor: {div}") # logger.info(f"down block: {i} divisor: {div}")
for resnet in down_block.resnets: for resnet in down_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div) resnet.forward = wrapper(resblock_forward, resnet, div)
if down_block.downsamplers is not None: if down_block.downsamplers is not None:
# print("has downsample") # logger.info("has downsample")
for downsample in down_block.downsamplers: for downsample in down_block.downsamplers:
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
div *= 2 div *= 2
@@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module):
def downsample_forward(self, _self, num_slices, hidden_states): def downsample_forward(self, _self, num_slices, hidden_states):
assert hidden_states.shape[1] == _self.channels assert hidden_states.shape[1] == _self.channels
assert _self.use_conv and _self.padding == 0 assert _self.use_conv and _self.padding == 0
print("downsample forward", num_slices, hidden_states.shape) logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
org_device = hidden_states.device org_device = hidden_states.device
cpu_device = torch.device("cpu") cpu_device = torch.device("cpu")
@@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module):
hidden_states = torch.cat([hidden_states, x], dim=2) hidden_states = torch.cat([hidden_states, x], dim=2)
hidden_states = hidden_states.to(org_device) hidden_states = hidden_states.to(org_device)
# print("downsample forward done", hidden_states.shape) # logger.info(f"downsample forward done {hidden_states.shape}")
return hidden_states return hidden_states
@@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module):
self.num_slices = num_slices self.num_slices = num_slices
div = num_slices / (2 ** (len(self.up_blocks) - 1)) div = num_slices / (2 ** (len(self.up_blocks) - 1))
print(f"initial divisor: {div}") logger.info(f"initial divisor: {div}")
if div >= 2: if div >= 2:
div = int(div) div = int(div)
for resnet in self.mid_block.resnets: for resnet in self.mid_block.resnets:
@@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module):
for i, up_block in enumerate(self.up_blocks): for i, up_block in enumerate(self.up_blocks):
if div >= 2: if div >= 2:
div = int(div) div = int(div)
# print(f"up block: {i} divisor: {div}") # logger.info(f"up block: {i} divisor: {div}")
for resnet in up_block.resnets: for resnet in up_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div) resnet.forward = wrapper(resblock_forward, resnet, div)
if up_block.upsamplers is not None: if up_block.upsamplers is not None:
# print("has upsample") # logger.info("has upsample")
for upsample in up_block.upsamplers: for upsample in up_block.upsamplers:
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
div *= 2 div *= 2
@@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module):
del x del x
hidden_states = torch.cat(sliced, dim=2) hidden_states = torch.cat(sliced, dim=2)
# print("us hidden_states", hidden_states.shape) # logger.info(f"us hidden_states {hidden_states.shape}")
del sliced del sliced
hidden_states = hidden_states.to(org_device) hidden_states = hidden_states.to(org_device)
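The slicing itself works along the height axis (dim=2): tensors are chunked, processed per slice, and re-concatenated, so the round trip must preserve shape. A minimal shape check under assumed sizes:

import torch

x = torch.randn(1, 4, 64, 64)                      # assumed latent-like tensor
slices = torch.chunk(x, 4, dim=2)                  # four (1, 4, 16, 64) height slices
y = torch.cat([s.clone() for s in slices], dim=2)  # clone() stands in for per-slice work
assert y.shape == x.shape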

File diff suppressed because it is too large

View File

@@ -1,6 +1,23 @@
import logging
import threading import threading
from typing import * from typing import *
def fire_in_thread(f, *args, **kwargs): def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start() threading.Thread(target=f, args=args, kwargs=kwargs).start()
def setup_logging(log_level=logging.INFO):
if logging.root.handlers: # Already configured
return
from rich.logging import RichHandler
handler = RichHandler()
formatter = logging.Formatter(
fmt="%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logging.root.setLevel(log_level)
logging.root.addHandler(handler)
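Every converted file in this commit wraps this helper in the same four-line preamble. A minimal sketch of the intended pattern (the function and message below are illustrative, not from the diff):

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def load_model(path):
    logger.info(f"loading: {path}")  # INFO for progress, WARNING/ERROR for problems

Because setup_logging() returns early when logging.root.handlers is non-empty, the call is idempotent: whichever module runs first installs the RichHandler, and subsequent calls are no-ops.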

View File

@@ -2,10 +2,13 @@ import argparse
import os import os
import torch import torch
from safetensors.torch import load_file from safetensors.torch import load_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(file): def main(file):
print(f"loading: {file}") logger.info(f"loading: {file}")
if os.path.splitext(file)[1] == ".safetensors": if os.path.splitext(file)[1] == ".safetensors":
sd = load_file(file) sd = load_file(file)
else: else:
@@ -17,16 +20,16 @@ def main(file):
for key in keys: for key in keys:
if "lora_up" in key or "lora_down" in key: if "lora_up" in key or "lora_down" in key:
values.append((key, sd[key])) values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}") logger.info(f"number of LoRA modules: {len(values)}")
if args.show_all_keys: if args.show_all_keys:
for key in [k for k in keys if k not in values]: for key in [k for k in keys if k not in values]:
values.append((key, sd[key])) values.append((key, sd[key]))
print(f"number of all modules: {len(values)}") logger.info(f"number of all modules: {len(values)}")
for key, value in values: for key, value in values:
value = value.to(torch.float32) value = value.to(torch.float32)
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:

View File

@@ -2,7 +2,10 @@ import os
from typing import Optional, List, Type from typing import Optional, List, Type
import torch import torch
from library import sdxl_original_unet from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied # input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False SKIP_INPUT_BLOCKS = False
@@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module):
return return
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
cx = self.conditioning1(cond_image) cx = self.conditioning1(cond_image)
if not self.is_conv2d: if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c # reshape / b,c,h,w -> b,h*w,c
@@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module):
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond: if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero cx[0::2] = 0.0 # uncond is zero
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# downで入力の次元数を削減し、conditioning image embeddingと結合する / reduce the input dimensionality with down and concatenate it with the conditioning image embedding # downで入力の次元数を削減し、conditioning image embeddingと結合する / reduce the input dimensionality with down and concatenate it with the conditioning image embedding
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している / concatenate along the channel dimension instead of adding, hoping the network mixes them well # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している / concatenate along the channel dimension instead of adding, hoping the network mixes them well
@@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module):
# create module instances # create module instances
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
def forward(self, x): def forward(self, x):
return x # dummy return x # dummy
@@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module):
return info return info
def apply_to(self): def apply_to(self):
print("applying LLLite for U-Net...") logger.info("applying LLLite for U-Net...")
for module in self.unet_modules: for module in self.unet_modules:
module.apply_to() module.apply_to()
self.add_module(module.lllite_name, module) self.add_module(module.lllite_name, module)
@@ -374,19 +377,19 @@ if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False # sdxl_original_unet.USE_REENTRANT = False
# test shape etc # test shape etc
print("create unet") logger.info("create unet")
unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16) unet.to("cuda").to(torch.float16)
print("create ControlNet-LLLite") logger.info("create ControlNet-LLLite")
control_net = ControlNetLLLite(unet, 32, 64) control_net = ControlNetLLLite(unet, 32, 64)
control_net.apply_to() control_net.apply_to()
control_net.to("cuda") control_net.to("cuda")
print(control_net) logger.info(control_net)
# print number of parameters # log number of parameters
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
input() input()
@@ -398,12 +401,12 @@ if __name__ == "__main__":
# # visualize # # visualize
# import torchviz # import torchviz
# print("run visualize") # logger.info("run visualize")
# controlnet.set_control(conditioning_image) # controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y) # output = unet(x, t, ctx, y)
# print("make_dot") # logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render") # logger.info("render")
# image.format = "svg" # "png" # image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input() # input()
@@ -414,12 +417,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True) scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training") logger.info("start training")
steps = 10 steps = 10
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
for step in range(steps): for step in range(steps):
print(f"step {step}") logger.info(f"step {step}")
batch_size = 1 batch_size = 1
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
@@ -439,7 +442,7 @@ if __name__ == "__main__":
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
print(sample_param) logger.info(sample_param)
# from safetensors.torch import save_file # from safetensors.torch import save_file

View File

@@ -6,7 +6,10 @@ import re
from typing import Optional, List, Type from typing import Optional, List, Type
import torch import torch
from library import sdxl_original_unet from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied # input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False SKIP_INPUT_BLOCKS = False
@@ -270,7 +273,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
# create module instances # create module instances
self.lllite_modules = apply_to_modules(self, target_modules) self.lllite_modules = apply_to_modules(self, target_modules)
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.") logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
# def prepare_optimizer_params(self): # def prepare_optimizer_params(self):
def prepare_params(self): def prepare_params(self):
@@ -281,8 +284,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
train_params.append(p) train_params.append(p)
else: else:
non_train_params.append(p) non_train_params.append(p)
print(f"count of trainable parameters: {len(train_params)}") logger.info(f"count of trainable parameters: {len(train_params)}")
print(f"count of non-trainable parameters: {len(non_train_params)}") logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
for p in non_train_params: for p in non_train_params:
p.requires_grad_(False) p.requires_grad_(False)
@@ -388,7 +391,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
matches = pattern.findall(module_name) matches = pattern.findall(module_name)
if matches is not None: if matches is not None:
for m in matches: for m in matches:
print(module_name, m) logger.info(f"{module_name} {m}")
module_name = module_name.replace(m, m.replace("_", "@")) module_name = module_name.replace(m, m.replace("_", "@"))
module_name = module_name.replace("_", ".") module_name = module_name.replace("_", ".")
module_name = module_name.replace("@", "_") module_name = module_name.replace("@", "_")
@@ -407,7 +410,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
def replace_unet_linear_and_conv2d(): def replace_unet_linear_and_conv2d():
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net") logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
sdxl_original_unet.torch.nn.Linear = LLLiteLinear sdxl_original_unet.torch.nn.Linear = LLLiteLinear
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
@@ -419,10 +422,10 @@ if __name__ == "__main__":
replace_unet_linear_and_conv2d() replace_unet_linear_and_conv2d()
# test shape etc # test shape etc
print("create unet") logger.info("create unet")
unet = SdxlUNet2DConditionModelControlNetLLLite() unet = SdxlUNet2DConditionModelControlNetLLLite()
print("enable ControlNet-LLLite") logger.info("enable ControlNet-LLLite")
unet.apply_lllite(32, 64, None, False, 1.0) unet.apply_lllite(32, 64, None, False, 1.0)
unet.to("cuda") # .to(torch.float16) unet.to("cuda") # .to(torch.float16)
@@ -439,14 +442,14 @@ if __name__ == "__main__":
# unet_sd[converted_key] = model_sd[key] # unet_sd[converted_key] = model_sd[key]
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd) # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
# print(info) # logger.info(info)
# print(unet) # logger.info(unet)
# print number of parameters # log number of parameters
params = unet.prepare_params() params = unet.prepare_params()
print("number of parameters", sum(p.numel() for p in params)) logger.info(f"number of parameters {sum(p.numel() for p in params)}")
# print("type any key to continue") # logger.info("type any key to continue")
# input() # input()
unet.set_use_memory_efficient_attention(True, False) unet.set_use_memory_efficient_attention(True, False)
@@ -455,12 +458,12 @@ if __name__ == "__main__":
# # visualize # # visualize
# import torchviz # import torchviz
# print("run visualize") # logger.info("run visualize")
# controlnet.set_control(conditioning_image) # controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y) # output = unet(x, t, ctx, y)
# print("make_dot") # logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render") # logger.info("render")
# image.format = "svg" # "png" # image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input() # input()
@@ -471,13 +474,13 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True) scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training") logger.info("start training")
steps = 10 steps = 10
batch_size = 1 batch_size = 1
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0] sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
for step in range(steps): for step in range(steps):
print(f"step {step}") logger.info(f"step {step}")
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda() x = torch.randn(batch_size, 4, 128, 128).cuda()
@@ -494,9 +497,9 @@ if __name__ == "__main__":
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
print(sample_param) logger.info(sample_param)
# from safetensors.torch import save_file # from safetensors.torch import save_file
# print("save weights") # logger.info("save weights")
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None) # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)

View File

@@ -15,7 +15,10 @@ import random
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module): class DyLoRAModule(torch.nn.Module):
""" """
@@ -223,7 +226,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key: elif "lora_down" in key:
dim = value.size()[0] dim = value.size()[0]
modules_dim[lora_name] = dim modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim) # logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha # support old LoRA without alpha
for key in modules_dim.keys(): for key in modules_dim.keys():
@@ -267,11 +270,11 @@ class DyLoRANetwork(torch.nn.Module):
self.apply_to_conv = apply_to_conv self.apply_to_conv = apply_to_conv
if modules_dim is not None: if modules_dim is not None:
print(f"create LoRA network from weights") logger.info("create LoRA network from weights")
else: else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}") logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv: if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).") logger.info("apply LoRA to Conv2d with kernel size (3,3).")
# create module instances # create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]: def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -308,7 +311,7 @@ class DyLoRANetwork(torch.nn.Module):
return loras return loras
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -316,7 +319,7 @@ class DyLoRANetwork(torch.nn.Module):
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules) self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_multiplier(self, multiplier): def set_multiplier(self, multiplier):
self.multiplier = multiplier self.multiplier = multiplier
@@ -336,12 +339,12 @@ class DyLoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
else: else:
self.text_encoder_loras = [] self.text_encoder_loras = []
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
else: else:
self.unet_loras = [] self.unet_loras = []
@@ -359,12 +362,12 @@ class DyLoRANetwork(torch.nn.Module):
apply_unet = True apply_unet = True
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
else: else:
self.text_encoder_loras = [] self.text_encoder_loras = []
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
else: else:
self.unet_loras = [] self.unet_loras = []
@@ -375,7 +378,7 @@ class DyLoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device) lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged") logger.info(f"weights are merged")
""" """
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):

View File

@@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm from tqdm import tqdm
from library import train_util, model_util from library import train_util, model_util
import numpy as np import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name): def load_state_dict(file_name):
if model_util.is_safetensors(file_name): if model_util.is_safetensors(file_name):
@@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit):
rank = value.size()[0] rank = value.size()[0]
if rank > max_rank: if rank > max_rank:
max_rank = rank max_rank = rank
print(f"Max rank: {max_rank}") logger.info(f"Max rank: {max_rank}")
rank = unit rank = unit
split_models = [] split_models = []
new_alpha = None new_alpha = None
while rank < max_rank: while rank < max_rank:
print(f"Splitting rank {rank}") logger.info(f"Splitting rank {rank}")
new_sd = {} new_sd = {}
for key, value in lora_sd.items(): for key, value in lora_sd.items():
if "lora_down" in key: if "lora_down" in key:
@@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit):
# なぜかscaleするとおかしくなる…… / for some reason, scaling makes the result wrong…… # なぜかscaleするとおかしくなる…… / for some reason, scaling makes the result wrong……
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0] # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
# scale = math.sqrt(this_rank / rank) # rank is > unit # scale = math.sqrt(this_rank / rank) # rank is > unit
# print(key, value.size(), this_rank, rank, value, scale) # logger.info(f"{key} {value.size()} {this_rank} {rank} {value} {scale}")
# new_alpha = value * scale # always same # new_alpha = value * scale # always same
# new_sd[key] = new_alpha # new_sd[key] = new_alpha
new_sd[key] = value new_sd[key] = value
@@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit):
def split(args): def split(args):
print("loading Model...") logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model) lora_sd, metadata = load_state_dict(args.model)
print("Splitting Model...") logger.info("Splitting Model...")
original_rank, split_models = split_lora_model(lora_sd, args.unit) original_rank, split_models = split_lora_model(lora_sd, args.unit)
comment = metadata.get("ss_training_comment", "") comment = metadata.get("ss_training_comment", "")
@@ -94,7 +97,7 @@ def split(args):
filename, ext = os.path.splitext(args.save_to) filename, ext = os.path.splitext(args.save_to)
model_file_name = filename + f"-{new_rank:04d}{ext}" model_file_name = filename + f"-{new_rank:04d}{ext}"
print(f"saving model to: {model_file_name}") logger.info(f"saving model to: {model_file_name}")
save_to_file(model_file_name, state_dict, new_metadata) save_to_file(model_file_name, state_dict, new_metadata)
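Assuming the loop above advances rank by unit on each iteration (the increment falls outside the visible hunks), the split yields one model per multiple of unit strictly below the maximum rank. A quick sketch with assumed values:

unit, max_rank = 8, 32    # assumed example values
rank, ranks = unit, []
while rank < max_rank:    # mirrors the visible loop condition
    ranks.append(rank)
    rank += unit          # assumed increment
# ranks == [8, 16, 24]; each is saved as "<filename>-0008<ext>" etc. by the code above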

View File

@@ -11,7 +11,10 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util from library import sai_model_spec, model_util, sdxl_model_util
import lora import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# CLAMP_QUANTILE = 0.99 # CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1 # MIN_DIFF = 1e-1
@@ -66,14 +69,14 @@ def svd(
# load models # load models
if not sdxl: if not sdxl:
print(f"loading original SD model : {model_org}") logger.info(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o] text_encoders_o = [text_encoder_o]
if load_dtype is not None: if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype) text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype) unet_o = unet_o.to(load_dtype)
print(f"loading tuned SD model : {model_tuned}") logger.info(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t] text_encoders_t = [text_encoder_t]
if load_dtype is not None: if load_dtype is not None:
@@ -85,7 +88,7 @@ def svd(
device_org = load_original_model_to if load_original_model_to else "cpu" device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu" device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
print(f"loading original SDXL model : {model_org}") logger.info(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
) )
@@ -95,7 +98,7 @@ def svd(
text_encoder_o2 = text_encoder_o2.to(load_dtype) text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype) unet_o = unet_o.to(load_dtype)
print(f"loading original SDXL model : {model_tuned}") logger.info(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
) )
@@ -135,7 +138,7 @@ def svd(
# Text Encoder might be same # Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diffs[lora_name] = diff diffs[lora_name] = diff
@@ -144,7 +147,7 @@ def svd(
del text_encoder del text_encoder
if not text_encoder_different: if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.") logger.warning("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = [] lora_network_o.text_encoder_loras = []
diffs = {} # clear diffs diffs = {} # clear diffs
@@ -166,7 +169,7 @@ def svd(
del unet_t del unet_t
# make LoRA with svd # make LoRA with svd
print("calculating by svd") logger.info("calculating by svd")
lora_weights = {} lora_weights = {}
with torch.no_grad(): with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())): for lora_name, mat in tqdm(list(diffs.items())):
@@ -185,7 +188,7 @@ def svd(
if device: if device:
mat = mat.to(device) mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) # logger.info(f"{lora_name} {mat.size()} {mat.device} {rank} {in_dim} {out_dim}")
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
if conv2d: if conv2d:
@@ -230,7 +233,7 @@ def svd(
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd) info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}") logger.info(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(save_to) dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
@@ -257,7 +260,7 @@ def svd(
metadata.update(sai_metadata) metadata.update(sai_metadata)
lora_network_save.save_weights(save_to, save_dtype, metadata) lora_network_save.save_weights(save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {save_to}") logger.info(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:

View File

@@ -11,7 +11,10 @@ from transformers import CLIPTextModel
import numpy as np import numpy as np
import torch import torch
import re import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -46,7 +49,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank: # if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim) # self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim: # if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else: # else:
self.lora_dim = lora_dim self.lora_dim = lora_dim
@@ -177,7 +180,7 @@ class LoRAInfModule(LoRAModule):
else: else:
# conv2d 3x3 # conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding) # logger.info(f"{conved.size()} {weight.size()} {module.stride} {module.padding}")
weight = weight + self.multiplier * conved * self.scale weight = weight + self.multiplier * conved * self.scale
# set weight to org_module # set weight to org_module
@@ -216,7 +219,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None self.region_mask = None
def default_forward(self, x): def default_forward(self, x):
# print("default_forward", self.lora_name, x.size()) # logger.info(f"default_forward {self.lora_name} {x.size()}")
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x): def forward(self, x):
@@ -245,7 +248,7 @@ class LoRAInfModule(LoRAModule):
if mask is None: if mask is None:
# raise ValueError(f"mask is None for resolution {area}") # raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask # emb_layers in SDXL doesn't have mask
# print(f"mask is None for resolution {area}, {x.size()}") # logger.info(f"mask is None for resolution {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4: if len(x.size()) != 4:
@@ -262,7 +265,7 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result # apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx) mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) # logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}")
lx = lx * mask lx = lx * mask
x = self.org_forward(x) x = self.org_forward(x)
@@ -291,7 +294,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond: if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :] query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
return query return query
def sub_prompt_forward(self, x): def sub_prompt_forward(self, x):
@@ -306,7 +309,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts] lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
x = self.org_forward(x) x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -314,7 +317,7 @@ class LoRAInfModule(LoRAModule):
return x return x
def to_out_forward(self, x): def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
if self.network.is_last_network: if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts masks = [None] * self.network.num_sub_prompts
@@ -332,7 +335,7 @@ class LoRAInfModule(LoRAModule):
) )
self.network.shared[self.lora_name] = (lx, masks) self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -351,7 +354,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond: if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
# if num_sub_prompts > num of LoRAs, fill with zero # if num_sub_prompts > num of LoRAs, fill with zero
for i in range(len(masks)): for i in range(len(masks)):
if masks[i] is None: if masks[i] is None:
@@ -374,7 +377,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1 x1 = x1 + lx1
out[self.network.batch_size + i] = x1 out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond) # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
return out return out
@@ -511,7 +514,7 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else: else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
block_dims = [network_dim] * num_total_blocks block_dims = [network_dim] * num_total_blocks
if block_alphas is not None: if block_alphas is not None:
@@ -520,7 +523,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else: else:
print( logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
) )
block_alphas = [network_alpha] * num_total_blocks block_alphas = [network_alpha] * num_total_blocks
@@ -540,13 +543,13 @@ def get_block_dims_and_alphas(
else: else:
if conv_alpha is None: if conv_alpha is None:
conv_alpha = 1.0 conv_alpha = 1.0
print( logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
) )
conv_block_alphas = [conv_alpha] * num_total_blocks conv_block_alphas = [conv_alpha] * num_total_blocks
else: else:
if conv_dim is not None: if conv_dim is not None:
print( logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
) )
conv_block_dims = [conv_dim] * num_total_blocks conv_block_dims = [conv_dim] * num_total_blocks
@@ -586,7 +589,7 @@ def get_block_lr_weight(
elif name == "zeros": elif name == "zeros":
return [0.0 + base_lr] * max_len return [0.0 + base_lr] * max_len
else: else:
print( logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name) % (name, name)
) )
@@ -598,14 +601,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight) up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len] up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len: if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -613,24 +616,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。") logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None: if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else: else:
print("down_lr_weight: all 1.0, すべて1.0") logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None: if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight) logger.info(f"mid_lr_weight: {mid_lr_weight}")
else: else:
print("mid_lr_weight: 1.0") logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None: if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else: else:
print("up_lr_weight: all 1.0, すべて1.0") logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight return down_lr_weight, mid_lr_weight, up_lr_weight
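The normalization above is easy to misread in diff form: lists longer than the block count are truncated, shorter ones are right-padded with 1.0, and entries at or below the threshold are zeroed. A compact restatement of those rules (max_len, zero_threshold, and the input are assumed values):

max_len, zero_threshold = 12, 0.0       # assumed example values
down_lr_weight = [0.0, 0.5, 1.0]        # user-supplied, too short
if len(down_lr_weight) > max_len:
    down_lr_weight = down_lr_weight[:max_len]
elif len(down_lr_weight) < max_len:
    down_lr_weight += [1.0] * (max_len - len(down_lr_weight))
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
# -> [0, 0.5, 1.0, 1.0, ..., 1.0] with twelve entries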
@@ -711,7 +714,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key: elif "lora_down" in key:
dim = value.size()[0] dim = value.size()[0]
modules_dim[lora_name] = dim modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim) # logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha # support old LoRA without alpha
for key in modules_dim.keys(): for key in modules_dim.keys():
@@ -786,20 +789,20 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout self.module_dropout = module_dropout
if modules_dim is not None: if modules_dim is not None:
print(f"create LoRA network from weights") logger.info(f"create LoRA network from weights")
elif block_dims is not None: elif block_dims is not None:
print(f"create LoRA network from block_dims") logger.info(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}") logger.info(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}") logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None: if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}") logger.info(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}") logger.info(f"conv_block_alphas: {conv_block_alphas}")
else: else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None: if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances # create module instances
def create_modules( def create_modules(
@@ -884,15 +887,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders): for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1: if len(text_encoders) > 1:
index = i + 1 index = i + 1
print(f"create LoRA for Text Encoder {index}:") logger.info(f"create LoRA for Text Encoder {index}:")
else: else:
index = None index = None
print(f"create LoRA for Text Encoder:") logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras) self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -900,15 +903,15 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0: if varbose and len(skipped) > 0:
print( logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
) )
for name in skipped: for name in skipped:
print(f"\t{name}") logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None self.down_lr_weight: List[float] = None
@@ -939,12 +942,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
else: else:
self.text_encoder_loras = [] self.text_encoder_loras = []
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
else: else:
self.unet_loras = [] self.unet_loras = []
@@ -966,12 +969,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True apply_unet = True
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
else: else:
self.text_encoder_loras = [] self.text_encoder_loras = []
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
else: else:
self.unet_loras = [] self.unet_loras = []
@@ -982,7 +985,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device) lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged") logger.info(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない / define per-block multipliers for the layer-wise learning rates; the argument order is reversed, but ignore that for now # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない / define per-block multipliers for the layer-wise learning rates; the argument order is reversed, but ignore that for now
def set_block_lr_weight( def set_block_lr_weight(
@@ -1128,7 +1131,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device device = ref_weight.device
def resize_add(mh, mw): def resize_add(mh, mw):
# print(mh, mw, mh * mw) # logger.info(f"{mh} {mw} {mh * mw}")
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype) m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m mask_dic[mh * mw] = m
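resize_add keys each interpolated mask by its flattened spatial size (mh * mw), so a later lookup needs only an activation's h*w rather than its exact (h, w) pair. A standalone sketch with assumed sizes:

import torch

mask = torch.rand(1, 1, 64, 64)  # assumed base region mask
mask_dic = {}
for mh, mw in [(64, 64), (32, 32), (16, 16)]:
    m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")
    mask_dic[mh * mw] = m        # keyed by flattened spatial size
# a hooked module can then select mask_dic[h * w] from an activation's shape alone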

View File

@@ -10,7 +10,10 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from transformers import CLIPTextModel from transformers import CLIPTextModel
import torch import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def make_unet_conversion_map() -> Dict[str, str]: def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = [] unet_conversion_map_layer = []
@@ -248,7 +251,7 @@ def create_network_from_weights(
elif "lora_down" in key: elif "lora_down" in key:
dim = value.size()[0] dim = value.size()[0]
modules_dim[lora_name] = dim modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim) # logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha # support old LoRA without alpha
for key in modules_dim.keys(): for key in modules_dim.keys():
@@ -291,12 +294,12 @@ class LoRANetwork(torch.nn.Module):
super().__init__() super().__init__()
self.multiplier = multiplier self.multiplier = multiplier
print(f"create LoRA network from weights") logger.info("create LoRA network from weights")
# convert SDXL Stability AI's U-Net modules to Diffusers # convert SDXL Stability AI's U-Net modules to Diffusers
converted = self.convert_unet_modules(modules_dim, modules_alpha) converted = self.convert_unet_modules(modules_dim, modules_alpha)
if converted: if converted:
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)") logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
# create module instances # create module instances
def create_modules( def create_modules(
@@ -331,7 +334,7 @@ class LoRANetwork(torch.nn.Module):
lora_name = lora_name.replace(".", "_") lora_name = lora_name.replace(".", "_")
if lora_name not in modules_dim: if lora_name not in modules_dim:
# print(f"skipped {lora_name} (not found in modules_dim)") # logger.info(f"skipped {lora_name} (not found in modules_dim)")
skipped.append(lora_name) skipped.append(lora_name)
continue continue
@@ -362,18 +365,18 @@ class LoRANetwork(torch.nn.Module):
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras) self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0: if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.") logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
# extend U-Net target modules to include Conv2d 3x3 # extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras: List[LoRAModule] self.unet_loras: List[LoRAModule]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0: if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.") logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
# assertion # assertion
names = set() names = set()
@@ -420,11 +423,11 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True): def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
for lora in self.text_encoder_loras: for lora in self.text_encoder_loras:
lora.apply_to(multiplier) lora.apply_to(multiplier)
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
for lora in self.unet_loras: for lora in self.unet_loras:
lora.apply_to(multiplier) lora.apply_to(multiplier)
@@ -433,16 +436,16 @@ class LoRANetwork(torch.nn.Module):
lora.unapply_to() lora.unapply_to()
def merge_to(self, multiplier=1.0): def merge_to(self, multiplier=1.0):
print("merge LoRA weights to original weights") logger.info("merge LoRA weights to original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras): for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.merge_to(multiplier) lora.merge_to(multiplier)
print(f"weights are merged") logger.info(f"weights are merged")
def restore_from(self, multiplier=1.0): def restore_from(self, multiplier=1.0):
print("restore LoRA weights from original weights") logger.info("restore LoRA weights from original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras): for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.restore_from(multiplier) lora.restore_from(multiplier)
print(f"weights are restored") logger.info(f"weights are restored")
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
# convert SDXL Stability AI's state dict to Diffusers' based state dict # convert SDXL Stability AI's state dict to Diffusers' based state dict
@@ -463,7 +466,7 @@ class LoRANetwork(torch.nn.Module):
my_state_dict = self.state_dict() my_state_dict = self.state_dict()
for key in state_dict.keys(): for key in state_dict.keys():
if state_dict[key].size() != my_state_dict[key].size(): if state_dict[key].size() != my_state_dict[key].size():
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}") # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
state_dict[key] = state_dict[key].view(my_state_dict[key].size()) state_dict[key] = state_dict[key].view(my_state_dict[key].size())
return super().load_state_dict(state_dict, strict) return super().load_state_dict(state_dict, strict)
@@ -490,7 +493,7 @@ if __name__ == "__main__":
image_prefix = args.model_id.replace("/", "_") + "_" image_prefix = args.model_id.replace("/", "_") + "_"
# load Diffusers model # load Diffusers model
print(f"load model from {args.model_id}") logger.info(f"load model from {args.model_id}")
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline] pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
if args.sdxl: if args.sdxl:
# use_safetensors=True does not work with 0.18.2 # use_safetensors=True does not work with 0.18.2
@@ -503,7 +506,7 @@ if __name__ == "__main__":
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder] text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
# load LoRA weights # load LoRA weights
print(f"load LoRA weights from {args.lora_weights}") logger.info(f"load LoRA weights from {args.lora_weights}")
if os.path.splitext(args.lora_weights)[1] == ".safetensors": if os.path.splitext(args.lora_weights)[1] == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file
@@ -512,10 +515,10 @@ if __name__ == "__main__":
lora_sd = torch.load(args.lora_weights) lora_sd = torch.load(args.lora_weights)
# create by LoRA weights and load weights # create by LoRA weights and load weights
print(f"create LoRA network") logger.info(f"create LoRA network")
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0) lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
print(f"load LoRA network weights") logger.info(f"load LoRA network weights")
lora_network.load_state_dict(lora_sd) lora_network.load_state_dict(lora_sd)
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
@@ -544,34 +547,34 @@ if __name__ == "__main__":
random.seed(seed) random.seed(seed)
# create image with original weights # create image with original weights
print(f"create image with original weights") logger.info(f"create image with original weights")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "original.png") image.save(image_prefix + "original.png")
# apply LoRA network to the model: slower than merge_to, but can be reverted easily # apply LoRA network to the model: slower than merge_to, but can be reverted easily
print(f"apply LoRA network to the model") logger.info(f"apply LoRA network to the model")
lora_network.apply_to(multiplier=1.0) lora_network.apply_to(multiplier=1.0)
print(f"create image with applied LoRA") logger.info(f"create image with applied LoRA")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "applied_lora.png") image.save(image_prefix + "applied_lora.png")
# unapply LoRA network to the model # unapply LoRA network to the model
print(f"unapply LoRA network to the model") logger.info(f"unapply LoRA network to the model")
lora_network.unapply_to() lora_network.unapply_to()
print(f"create image with unapplied LoRA") logger.info(f"create image with unapplied LoRA")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unapplied_lora.png") image.save(image_prefix + "unapplied_lora.png")
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to) # merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
print(f"merge LoRA network to the model") logger.info(f"merge LoRA network to the model")
lora_network.merge_to(multiplier=1.0) lora_network.merge_to(multiplier=1.0)
print(f"create image with LoRA") logger.info(f"create image with LoRA")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "merged_lora.png") image.save(image_prefix + "merged_lora.png")
@@ -579,31 +582,31 @@ if __name__ == "__main__":
# restore (unmerge) LoRA weights: numerically unstable # restore (unmerge) LoRA weights: numerically unstable
# restore the merged weights; due to accumulated floating-point error they may not exactly match the originals # restore the merged weights; due to accumulated floating-point error they may not exactly match the originals
# restoring the original weights from a saved state_dict is the reliable way # restoring the original weights from a saved state_dict is the reliable way
print(f"restore (unmerge) LoRA weights") logger.info(f"restore (unmerge) LoRA weights")
lora_network.restore_from(multiplier=1.0) lora_network.restore_from(multiplier=1.0)
print(f"create image without LoRA") logger.info(f"create image without LoRA")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unmerged_lora.png") image.save(image_prefix + "unmerged_lora.png")
# restore original weights # restore original weights
print(f"restore original weights") logger.info(f"restore original weights")
pipe.unet.load_state_dict(org_unet_sd) pipe.unet.load_state_dict(org_unet_sd)
pipe.text_encoder.load_state_dict(org_text_encoder_sd) pipe.text_encoder.load_state_dict(org_text_encoder_sd)
if args.sdxl: if args.sdxl:
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd) pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
print(f"create image with restored original weights") logger.info(f"create image with restored original weights")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "restore_original.png") image.save(image_prefix + "restore_original.png")
# use convenience function to merge LoRA weights # use convenience function to merge LoRA weights
print(f"merge LoRA weights with convenience function") logger.info(f"merge LoRA weights with convenience function")
merge_lora_weights(pipe, lora_sd, multiplier=1.0) merge_lora_weights(pipe, lora_sd, multiplier=1.0)
print(f"create image with merged LoRA weights") logger.info(f"create image with merged LoRA weights")
seed_everything(args.seed) seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0] image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "convenience_merged_lora.png") image.save(image_prefix + "convenience_merged_lora.png")

View File

@@ -14,7 +14,10 @@ from transformers import CLIPTextModel
import numpy as np import numpy as np
import torch import torch
import re import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
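This four-line preamble recurs at the top of every converted file. The helper itself is outside this diff; presumably setup_logging() in library/utils.py amounts to something like the sketch below, using the rich handler this PR adds to requirements.txt (names and defaults here are assumptions):

    import logging

    def setup_logging(log_level=logging.INFO):
        # configure the root logger once; module-level loggers created via
        # logging.getLogger(__name__) then inherit the handler and level
        if logging.root.handlers:
            return
        try:
            from rich.logging import RichHandler
            handler = RichHandler()
        except ImportError:
            handler = logging.StreamHandler()
        logging.root.setLevel(log_level)
        logging.root.addHandler(handler)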
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank: # if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim) # self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim: # if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else: # else:
self.lora_dim = lora_dim self.lora_dim = lora_dim
@@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule):
else: else:
# conv2d 3x3 # conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding) # logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale weight = weight + self.multiplier * conved * self.scale
# set weight to org_module # set weight to org_module
@@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None self.region_mask = None
def default_forward(self, x): def default_forward(self, x):
# print("default_forward", self.lora_name, x.size()) # logger.info("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x): def forward(self, x):
@@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result # apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx) mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
lx = lx * mask lx = lx * mask
x = self.org_forward(x) x = self.org_forward(x)
@@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond: if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :] query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts) # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
return query return query
def sub_prompt_forward(self, x): def sub_prompt_forward(self, x):
@@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts] lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx) # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
x = self.org_forward(x) x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule):
return x return x
def to_out_forward(self, x): def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network) # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
if self.network.is_last_network: if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts masks = [None] * self.network.num_sub_prompts
@@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule):
) )
self.network.shared[self.lora_name] = (lx, masks) self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts) # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1 lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1) masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond: if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts) # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# for i in range(len(masks)): # for i in range(len(masks)):
# if masks[i] is None: # if masks[i] is None:
# masks[i] = torch.zeros_like(masks[-1]) # masks[i] = torch.zeros_like(masks[-1])
@@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1 x1 = x1 + lx1
out[self.network.batch_size + i] = x1 out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond) # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
return out return out
@@ -526,7 +529,7 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else: else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
block_dims = [network_dim] * num_total_blocks block_dims = [network_dim] * num_total_blocks
if block_alphas is not None: if block_alphas is not None:
@@ -535,7 +538,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください" ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else: else:
print( logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります" f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
) )
block_alphas = [network_alpha] * num_total_blocks block_alphas = [network_alpha] * num_total_blocks
@@ -555,13 +558,13 @@ def get_block_dims_and_alphas(
else: else:
if conv_alpha is None: if conv_alpha is None:
conv_alpha = 1.0 conv_alpha = 1.0
print( logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります" f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
) )
conv_block_alphas = [conv_alpha] * num_total_blocks conv_block_alphas = [conv_alpha] * num_total_blocks
else: else:
if conv_dim is not None: if conv_dim is not None:
print( logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります" f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
) )
conv_block_dims = [conv_dim] * num_total_blocks conv_block_dims = [conv_dim] * num_total_blocks
@@ -601,7 +604,7 @@ def get_block_lr_weight(
elif name == "zeros": elif name == "zeros":
return [0.0 + base_lr] * max_len return [0.0 + base_lr] * max_len
else: else:
print( logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name, name) % (name, name)
) )
@@ -613,14 +616,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight) up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len] up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len] down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len: if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -628,24 +631,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。") logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None: if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight) logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else: else:
print("down_lr_weight: all 1.0, すべて1.0") logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None: if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight) logger.info(f"mid_lr_weight: {mid_lr_weight}")
else: else:
print("mid_lr_weight: 1.0") logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None: if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight) logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else: else:
print("up_lr_weight: all 1.0, すべて1.0") logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight return down_lr_weight, mid_lr_weight, up_lr_weight
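For orientation, get_list() (mostly truncated from this hunk) resolves each curve name into max_len multipliers, shallow to deep. A sketch consistent with the zeros branch visible above; the other formulas are assumptions about the elided code:

    import math

    def get_list(name: str, max_len: int, base_lr: float = 0.0):
        # each curve maps block index -> LR multiplier
        if name == "sine":
            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
        if name == "cosine":
            return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
        if name == "linear":
            return [i / (max_len - 1) + base_lr for i in range(max_len)]
        if name == "reverse_linear":
            return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
        if name == "zeros":
            return [0.0 + base_lr] * max_len  # matches the branch shown in the hunk
        raise ValueError(f"unknown lr_weight curve: {name}")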
@@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key: elif "lora_down" in key:
dim = value.size()[0] dim = value.size()[0]
modules_dim[lora_name] = dim modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim) # logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha # support old LoRA without alpha
for key in modules_dim.keys(): for key in modules_dim.keys():
@@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout self.module_dropout = module_dropout
if modules_dim is not None: if modules_dim is not None:
print(f"create LoRA network from weights") logger.info(f"create LoRA network from weights")
elif block_dims is not None: elif block_dims is not None:
print(f"create LoRA network from block_dims") logger.info(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}") logger.info(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}") logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None: if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}") logger.info(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}") logger.info(f"conv_block_alphas: {conv_block_alphas}")
else: else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None: if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances # create module instances
def create_modules( def create_modules(
@@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders): for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1: if len(text_encoders) > 1:
index = i + 1 index = i + 1
print(f"create LoRA for Text Encoder {index}:") logger.info(f"create LoRA for Text Encoder {index}:")
else: else:
index = None index = None
print(f"create LoRA for Text Encoder:") logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras) self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0: if varbose and len(skipped) > 0:
print( logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
) )
for name in skipped: for name in skipped:
print(f"\t{name}") logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None self.down_lr_weight: List[float] = None
@@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
else: else:
self.text_encoder_loras = [] self.text_encoder_loras = []
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
else: else:
self.unet_loras = [] self.unet_loras = []
@@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True apply_unet = True
if apply_text_encoder: if apply_text_encoder:
print("enable LoRA for text encoder") logger.info("enable LoRA for text encoder")
else: else:
self.text_encoder_loras = [] self.text_encoder_loras = []
if apply_unet: if apply_unet:
print("enable LoRA for U-Net") logger.info("enable LoRA for U-Net")
else: else:
self.unet_loras = [] self.unet_loras = []
@@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device) lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged") logger.info(f"weights are merged")
# define per-block multipliers applied to the learning rate for block-wise LR (the argument order is reversed, but don't worry about it for now) # define per-block multipliers applied to the learning rate for block-wise LR (the argument order is reversed, but don't worry about it for now)
def set_block_lr_weight( def set_block_lr_weight(
@@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device device = ref_weight.device
def resize_add(mh, mw): def resize_add(mh, mw):
# print(mh, mw, mh * mw) # logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16 m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype) m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m mask_dic[mh * mw] = m
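resize_add() files each downscaled regional mask under the key h*w, so a LoRA module can later select the right mask from nothing but its activation's sequence length. Roughly, as a standalone sketch (assuming mask is a (1, 1, H, W) float tensor):

    import torch

    def build_mask_dic(mask, sizes, device, dtype):
        # key each resized mask by h*w: an activation of length h*w can then
        # look up its mask without knowing h and w separately
        mask_dic = {}
        for mh, mw in sizes:
            # interpolate before casting; bilinear doesn't work in bf16
            m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear")
            mask_dic[mh * mw] = m.to(device, dtype=dtype)
        return mask_dic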

View File

@@ -9,6 +9,10 @@ import torch
import library.model_util as model_util import library.model_util as model_util
import lora import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_PATH = "openai/clip-vit-large-patch14" TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this model V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this model
@@ -20,12 +24,12 @@ def interrogate(args):
weights_dtype = torch.float16 weights_dtype = torch.float16
# set everything up # set everything up
print(f"loading SD model: {args.sd_model}") logger.info(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model args.pretrained_model_name_or_path = args.sd_model
args.vae = None args.vae = None
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}") logger.info(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# check whether the weights include Text Encoder modules (this really belongs on the lora side) # check whether the weights include Text Encoder modules (this really belongs on the lora side)
@@ -35,11 +39,11 @@ def interrogate(args):
has_te_weight = True has_te_weight = True
break break
if not has_te_weight: if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません") logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return return
del vae del vae
print("loading tokenizer") logger.info("loading tokenizer")
if args.v2: if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else: else:
@@ -53,7 +57,7 @@ def interrogate(args):
# probe the tokens one by one # probe the tokens one by one
token_id_start = 0 token_id_start = 0
token_id_end = max(tokenizer.all_special_ids) token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}") logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder): def get_all_embeddings(text_encoder):
embs = [] embs = []
@@ -79,24 +83,24 @@ def interrogate(args):
embs.extend(encoder_hidden_states) embs.extend(encoder_hidden_states)
return torch.stack(embs) return torch.stack(embs)
print("get original text encoder embeddings.") logger.info("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder) orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0) network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
info = network.load_state_dict(weights_sd, strict=False) info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}") logger.info(f"Loading LoRA weights: {info}")
network.to(DEVICE, dtype=weights_dtype) network.to(DEVICE, dtype=weights_dtype)
network.eval() network.eval()
del unet del unet
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません") logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.") logger.info("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder) lora_embs = get_all_embeddings(text_encoder)
# compare: for now, simply by the absolute difference # compare: for now, simply by the absolute difference
print("comparing...") logger.info("comparing...")
diffs = {} diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))): for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb)) diff = torch.mean(torch.abs(orig_emb - lora_emb))
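diffs ends up mapping each token to the mean absolute change the LoRA induced in that token's embedding; the natural next step, truncated from this hunk, is to sort and report the most-affected tokens. A hypothetical continuation:

    # rank tokens by how far the LoRA moved their embeddings, largest first
    for token_id, diff in sorted(diffs.items(), key=lambda kv: float(kv[1]), reverse=True)[:30]:
        token = tokenizer.decode([token_id])
        logger.info(f"{float(diff):.5f}: id={token_id} token={token!r}")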

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util from library import sai_model_spec, train_util
import library.model_util as model_util import library.model_util as model_util
import lora import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors": if os.path.splitext(file_name)[1] == ".safetensors":
@@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype) lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...") logger.info(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
@@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora # find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module: if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}") logger.info(f"no module found for LoRA weight: {key}")
continue continue
module = name_to_module[module_name] module = name_to_module[module_name]
# print(f"apply {key} to {module}") # logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key] down_weight = lora_sd[key]
up_weight = lora_sd[up_key] up_weight = lora_sd[up_key]
@@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
else: else:
# conv2d 3x3 # conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding) # logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight) module.weight = torch.nn.Parameter(weight)
@@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
v2 = None v2 = None
base_model = None base_model = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype) lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None: if lora_metadata is not None:
@@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if lora_module_name not in base_alphas: if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge # merge
print(f"merging...") logger.info(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if "alpha" in key: if "alpha" in key:
continue continue
@@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm] merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model") logger.info("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same # check all dims are same
dims_list = list(set(base_dims.values())) dims_list = list(set(base_dims.values()))
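A note on the shuffle step above: applying one permutation to the rows of lora_down and the same permutation to the columns of lora_up leaves the product lora_up @ lora_down unchanged, so merged behavior is preserved and only the factorization is scrambled. The concat path stacks ranks instead of summing deltas; a sketch of the idea (not the script's exact bookkeeping, and assuming non-negative ratios):

    import math
    import torch

    def concat_pair(downs, ups, ratios):
        # stack each model's rank dimension; sqrt(ratio) on both factors
        # scales that model's contribution to up @ down by ratio
        down = torch.cat([d * math.sqrt(r) for d, r in zip(downs, ratios)], dim=0)
        up = torch.cat([u * math.sqrt(r) for u, r in zip(ups, ratios)], dim=1)
        return down, up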
@@ -239,7 +242,7 @@ def merge(args):
save_dtype = merge_dtype save_dtype = merge_dtype
if args.sd_model is not None: if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}") logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
@@ -264,18 +267,18 @@ def merge(args):
) )
if args.v2: if args.v2:
# TODO read sai modelspec # TODO read sai modelspec
print( logger.warning(
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" "Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
) )
print(f"saving SD model to: {args.save_to}") logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint( model_util.save_stable_diffusion_checkpoint(
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
) )
else: else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...") logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash metadata["sshs_model_hash"] = model_hash
@@ -289,12 +292,12 @@ def merge(args):
) )
if v2: if v2:
# TODO read sai modelspec # TODO read sai modelspec
print( logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
) )
metadata.update(sai_metadata) metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}") logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)

View File

@@ -6,7 +6,10 @@ import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
import library.model_util as model_util import library.model_util as model_util
import lora import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if os.path.splitext(file_name)[1] == '.safetensors':
@@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...") logger.info(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora # find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module: if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}") logger.info(f"no module found for LoRA weight: {key}")
continue continue
module = name_to_module[module_name] module = name_to_module[module_name]
# print(f"apply {key} to {module}") # logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key] down_weight = lora_sd[key]
up_weight = lora_sd[up_key] up_weight = lora_sd[up_key]
@@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = None alpha = None
dim = None dim = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...") logger.info(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if 'alpha' in key:
if key in merged_sd: if key in merged_sd:
@@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype):
dim = lora_sd[key].size()[0] dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}") logger.info(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None: if alpha is None:
alpha = dim alpha = dim
@@ -142,19 +145,21 @@ def merge(args):
save_dtype = merge_dtype save_dtype = merge_dtype
if args.sd_model is not None: if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}") logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"\nsaving SD model to: {args.save_to}") logger.info("")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae) args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"\nsaving model to: {args.save_to}") logger.info(f"")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)

View File

@@ -8,7 +8,10 @@ from transformers import CLIPTextModel
import numpy as np import numpy as np
import torch import torch
import re import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -237,7 +240,7 @@ class OFTNetwork(torch.nn.Module):
self.dim = dim self.dim = dim
self.alpha = alpha self.alpha = alpha
print( logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
) )
@@ -258,7 +261,7 @@ class OFTNetwork(torch.nn.Module):
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_") oft_name = oft_name.replace(".", "_")
# print(oft_name) # logger.info(oft_name)
oft = module_class( oft = module_class(
oft_name, oft_name,
@@ -279,7 +282,7 @@ class OFTNetwork(torch.nn.Module):
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion # assertion
names = set() names = set()
@@ -316,7 +319,7 @@ class OFTNetwork(torch.nn.Module):
# TODO refactor to common function with apply_to # TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device): def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net") logger.info("enable OFT for U-Net")
for oft in self.unet_ofts: for oft in self.unet_ofts:
sd_for_lora = {} sd_for_lora = {}
@@ -326,7 +329,7 @@ class OFTNetwork(torch.nn.Module):
oft.load_state_dict(sd_for_lora, False) oft.load_state_dict(sd_for_lora, False)
oft.merge_to() oft.merge_to()
print(f"weights are merged") logger.info(f"weights are merged")
# it might be good to allow setting separate learning rates for the two Text Encoders # it might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
@@ -338,11 +341,11 @@ class OFTNetwork(torch.nn.Module):
for oft in ofts: for oft in ofts:
params.extend(oft.parameters()) params.extend(oft.parameters())
# print num of params # log the number of params
num_params = 0 num_params = 0
for p in params: for p in params:
num_params += p.numel() num_params += p.numel()
print(f"OFT params: {num_params}") logger.info(f"OFT params: {num_params}")
return params return params
param_data = {"params": enumerate_params(self.unet_ofts)} param_data = {"params": enumerate_params(self.unet_ofts)}

View File

@@ -8,6 +8,10 @@ from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm from tqdm import tqdm
from library import train_util, model_util from library import train_util, model_util
import numpy as np import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6 MIN_SV = 1e-6
@@ -206,7 +210,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
scale = network_alpha/network_dim scale = network_alpha/network_dim
if dynamic_method: if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}") logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
lora_down_weight = None lora_down_weight = None
lora_up_weight = None lora_up_weight = None
@@ -275,10 +279,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
del param_dict del param_dict
if verbose: if verbose:
print(verbose_str) logger.info(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete") logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha return o_lora_sd, network_dim, new_alpha
@@ -304,10 +308,10 @@ def resize(args):
if save_dtype is None: if save_dtype is None:
save_dtype = merge_dtype save_dtype = merge_dtype
print("loading Model...") logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype) lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("Resizing Lora...") logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose) state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
# update metadata # update metadata
@@ -329,7 +333,7 @@ def resize(args):
metadata["sshs_model_hash"] = model_hash metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}") logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
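resize_lora_model (largely elided here) rebuilds each module's delta weight, re-factorizes it at the target rank via SVD, and reports how much of the delta's Frobenius norm the truncation kept. A minimal per-module sketch under those assumptions:

    import torch

    def resize_pair(up_weight, down_weight, new_rank):
        # rebuild the full delta for this module, then truncate its SVD
        mat = (up_weight @ down_weight).float()
        U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
        # split the kept singular values between the two factors
        sqrt_s = torch.sqrt(S[:new_rank])
        new_up = U[:, :new_rank] * sqrt_s               # (out_dim, new_rank)
        new_down = sqrt_s.unsqueeze(1) * Vh[:new_rank]  # (new_rank, in_dim)
        # fraction of the delta's Frobenius norm surviving the truncation
        retention = torch.sqrt((S[:new_rank] ** 2).sum() / (S ** 2).sum())
        return new_up, new_down, float(retention)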

View File

@@ -8,7 +8,10 @@ from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util import library.model_util as model_util
import lora import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors": if os.path.splitext(file_name)[1] == ".safetensors":
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
name_to_module[lora_name] = child_module name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype) lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...") logger.info(f"merging...")
for key in tqdm(lora_sd.keys()): for key in tqdm(lora_sd.keys()):
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
@@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# find original module for this lora # find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module: if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}") logger.info(f"no module found for LoRA weight: {key}")
continue continue
module = name_to_module[module_name] module = name_to_module[module_name]
# print(f"apply {key} to {module}") # logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key] down_weight = lora_sd[key]
up_weight = lora_sd[up_key] up_weight = lora_sd[up_key]
@@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# W <- W + U * D # W <- W + U * D
weight = module.weight weight = module.weight
# print(module_name, down_weight.size(), up_weight.size()) # logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2: if len(weight.size()) == 2:
# linear # linear
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
@@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
else: else:
# conv2d 3x3 # conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding) # logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight) module.weight = torch.nn.Parameter(weight)
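Pulling the branches together, the merge computes W <- W + ratio * dW * scale, reconstructing dW differently per layer type. A standalone sketch (the conv2d-1x1 branch is elided from this hunk, so that part is an assumption):

    import torch

    def merged_weight(weight, up_weight, down_weight, ratio, scale):
        if weight.dim() == 2:
            # linear: the delta is a plain matrix product
            return weight + ratio * (up_weight @ down_weight) * scale
        if down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1: matmul on squeezed kernels, then restore the 1x1 dims
            delta = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
            return weight + ratio * delta * scale
        # conv2d 3x3: compose the two kernels with conv2d, as in the hunk above
        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
        return weight + ratio * conved * scale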
@@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
v2 = None v2 = None
base_model = None base_model = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype) lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None: if lora_metadata is not None:
@@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if lora_module_name not in base_alphas: if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}") logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge # merge
print(f"merging...") logger.info(f"merging...")
for key in tqdm(lora_sd.keys()): for key in tqdm(lora_sd.keys()):
if "alpha" in key: if "alpha" in key:
continue continue
@@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
merged_sd[key_down] = merged_sd[key_down][perm] merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm] merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model") logger.info("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}") logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same # check all dims are same
dims_list = list(set(base_dims.values())) dims_list = list(set(base_dims.values()))
@@ -243,7 +246,7 @@ def merge(args):
save_dtype = merge_dtype save_dtype = merge_dtype
if args.sd_model is not None: if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}") logger.info(f"loading SD model: {args.sd_model}")
( (
text_model1, text_model1,
@@ -265,14 +268,14 @@ def merge(args):
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
) )
print(f"saving SD model to: {args.save_to}") logger.info(f"saving SD model to: {args.save_to}")
sdxl_model_util.save_stable_diffusion_checkpoint( sdxl_model_util.save_stable_diffusion_checkpoint(
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
) )
else: else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle) state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...") logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash metadata["sshs_model_hash"] = model_hash
@@ -286,7 +289,7 @@ def merge(args):
) )
metadata.update(sai_metadata) metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}") logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)

View File

@@ -5,7 +5,12 @@ import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
from library import sai_model_spec, train_util from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99
@@ -38,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata):
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype): def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}") logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {} merged_sd = {}
v2 = None v2 = None
base_model = None base_model = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype) lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None: if lora_metadata is not None:
@@ -53,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None) base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# merge # merge
print(f"merging...") logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())): for key in tqdm(list(lora_sd.keys())):
if "lora_down" not in key: if "lora_down" not in key:
continue continue
@@ -70,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
out_dim = up_weight.size()[0] out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4 conv2d = len(down_weight.size()) == 4
kernel_size = None if not conv2d else down_weight.size()[2:4] kernel_size = None if not conv2d else down_weight.size()[2:4]
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size) # logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# make original weight if not exist # make original weight if not exist
if lora_module_name not in merged_sd: if lora_module_name not in merged_sd:
@@ -107,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
merged_sd[lora_module_name] = weight merged_sd[lora_module_name] = weight
# extract from merged weights # extract from merged weights
print("extract new lora...") logger.info("extract new lora...")
merged_lora_sd = {} merged_lora_sd = {}
with torch.no_grad(): with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())): for lora_module_name, mat in tqdm(list(merged_sd.items())):
@@ -185,7 +190,7 @@ def merge(args):
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
) )
print(f"calculating hashes and creating metadata...") logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash metadata["sshs_model_hash"] = model_hash
@@ -200,12 +205,12 @@ def merge(args):
) )
if v2: if v2:
# TODO read sai modelspec # TODO read sai modelspec
print( logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します" "Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
) )
metadata.update(sai_metadata) metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}") logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata) save_to_file(args.save_to, state_dict, save_dtype, metadata)

View File

@@ -29,5 +29,7 @@ huggingface-hub==0.20.1
# protobuf==3.20.3 # protobuf==3.20.3
# open clip for SDXL # open clip for SDXL
open-clip-torch==2.20.0 open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for kohya_ss library # for kohya_ss library
-e . -e .

View File

@@ -55,6 +55,10 @@ from networks.lora import LoRANetwork
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite from networks.control_net_lllite import ControlNetLLLite
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: # scheduler:
SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_START = 0.00085
@@ -76,12 +80,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn: if mem_eff_attn:
print("Enable memory efficient attention for U-Net") logger.info("Enable memory efficient attention for U-Net")
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
unet.set_use_memory_efficient_attention(False, True) unet.set_use_memory_efficient_attention(False, True)
elif xformers: elif xformers:
print("Enable xformers for U-Net") logger.info("Enable xformers for U-Net")
try: try:
import xformers.ops import xformers.ops
except ImportError: except ImportError:
@@ -89,7 +93,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
unet.set_use_memory_efficient_attention(True, False) unet.set_use_memory_efficient_attention(True, False)
elif sdpa: elif sdpa:
print("Enable SDPA for U-Net") logger.info("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False) unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True) unet.set_use_sdpa(True)
@@ -106,7 +110,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
def replace_vae_attn_to_memory_efficient(): def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states, **kwargs): def forward_flash_attn(self, hidden_states, **kwargs):
@@ -162,7 +166,7 @@ def replace_vae_attn_to_memory_efficient():
def replace_vae_attn_to_xformers(): def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers") logger.info("VAE: Attention.forward has been replaced to xformers")
import xformers.ops import xformers.ops
def forward_xformers(self, hidden_states, **kwargs): def forward_xformers(self, hidden_states, **kwargs):
@@ -218,7 +222,7 @@ def replace_vae_attn_to_xformers():
def replace_vae_attn_to_sdpa(): def replace_vae_attn_to_sdpa():
print("VAE: Attention.forward has been replaced to sdpa") logger.info("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs): def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states residual = hidden_states
@@ -352,7 +356,7 @@ class PipelineLike:
token_replacements = self.token_replacements_list[tokenizer_index] token_replacements = self.token_replacements_list[tokenizer_index]
def replace_tokens(tokens): def replace_tokens(tokens):
# print("replace_tokens", tokens, "=>", token_replacements) # logger.info("replace_tokens", tokens, "=>", token_replacements)
if isinstance(tokens, torch.Tensor): if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist() tokens = tokens.tolist()
@@ -444,7 +448,7 @@ class PipelineLike:
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None: if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scalle <= 1.0") logger.info(f"negative_scale is ignored if guidance scalle <= 1.0")
negative_scale = None negative_scale = None
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
@@ -548,7 +552,7 @@ class PipelineLike:
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
if init_image is not None and self.clip_vision_model is not None: if init_image is not None and self.clip_vision_model is not None:
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
@@ -715,7 +719,7 @@ class PipelineLike:
if not enabled or ratio >= 1.0: if not enabled or ratio >= 1.0:
continue continue
if ratio < i / len(timesteps): if ratio < i / len(timesteps):
print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})") logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None) control_net.set_cond_image(None)
each_control_net_enabled[j] = False each_control_net_enabled[j] = False
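The check above disables a ControlNet once a given fraction of the denoising steps has passed. A small worked example of the cutoff, with illustrative numbers:

    # With ratio = 0.5 over 30 timesteps, i / len(timesteps) first exceeds the
    # ratio at i = 16, so the ControlNet stays active for steps 0..15 only.
    ratio, num_steps = 0.5, 30
    disabled_from = next(i for i in range(num_steps) if ratio < i / num_steps)
    assert disabled_from == 16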
@@ -935,7 +939,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
if word.strip() == "BREAK": if word.strip() == "BREAK":
# pad until next multiple of tokenizer's max token length # pad until next multiple of tokenizer's max token length
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
print(f"BREAK pad_len: {pad_len}") logger.info(f"BREAK pad_len: {pad_len}")
for i in range(pad_len): for i in range(pad_len):
# v2のときEOSをつけるべきかどうかわからないぜ # v2のときEOSをつけるべきかどうかわからないぜ
# if i == 0: # if i == 0:
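BREAK pads the token stream up to the next multiple of the tokenizer window so the following chunk starts on a clean boundary. With CLIP's usual model_max_length of 77, the arithmetic works out as follows (worked example, not the script's code):

    model_max_length = 77   # CLIP tokenizer window
    tokens_so_far = 30      # tokens accumulated before BREAK
    pad_len = model_max_length - (tokens_so_far % model_max_length)
    assert pad_len == 47    # filler tokens inserted before the next chunk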
@@ -965,7 +969,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
tokens.append(text_token) tokens.append(text_token)
weights.append(text_weight) weights.append(text_weight)
if truncated: if truncated:
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights return tokens, weights
@@ -1238,7 +1242,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
elif len(count_range) == 2: elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])] count_range = [int(count_range[0]), int(count_range[1])]
else: else:
print(f"invalid count range: {count_range}") logger.warning(f"invalid count range: {count_range}")
count_range = [1, 1] count_range = [1, 1]
if count_range[0] > count_range[1]: if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]] count_range = [count_range[1], count_range[0]]
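The count range for dynamic prompt variants is normalized defensively, as the fallback and swap above show; a tiny illustration:

    # A reversed range such as "3-1" is swapped rather than rejected; anything
    # unparsable falls back to [1, 1] after the warning above.
    count_range = [3, 1]
    if count_range[0] > count_range[1]:
        count_range = [count_range[1], count_range[0]]
    assert count_range == [1, 3]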
@@ -1308,7 +1312,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
# def load_clip_l14_336(dtype): # def load_clip_l14_336(dtype):
# print(f"loading CLIP: {CLIP_ID_L14_336}") # logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) # text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
# return text_encoder # return text_encoder
@@ -1378,7 +1382,7 @@ def main(args):
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む # tokenizerを読み込む
print("loading tokenizer") logger.info("loading tokenizer")
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# schedulerを用意する # schedulerを用意する
@@ -1452,7 +1456,7 @@ def main(args):
self.sampler_noises = noises self.sampler_noises = noises
def randn(self, shape, device=None, dtype=None, layout=None, generator=None): def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
# print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) # logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
noise = self.sampler_noises[self.sampler_noise_index] noise = self.sampler_noises[self.sampler_noise_index]
if shape != noise.shape: if shape != noise.shape:
@@ -1461,7 +1465,7 @@ def main(args):
noise = None noise = None
if noise is None: if noise is None:
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
self.sampler_noise_index += 1 self.sampler_noise_index += 1
@@ -1493,7 +1497,7 @@ def main(args):
# ↓以下は結局PipeでFalseに設定されるので意味がなかった # ↓以下は結局PipeでFalseに設定されるので意味がなかった
# # clip_sample=Trueにする # # clip_sample=Trueにする
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True") # logger.info("set clip_sample to True")
# scheduler.config.clip_sample = True # scheduler.config.clip_sample = True
# deviceを決定する # deviceを決定する
@@ -1522,7 +1526,7 @@ def main(args):
vae_dtype = dtype vae_dtype = dtype
if args.no_half_vae: if args.no_half_vae:
print("set vae_dtype to float32") logger.info("set vae_dtype to float32")
vae_dtype = torch.float32 vae_dtype = torch.float32
vae.to(vae_dtype).to(device) vae.to(vae_dtype).to(device)
vae.eval() vae.eval()
@@ -1547,10 +1551,10 @@ def main(args):
network_merge = args.network_merge_n_models network_merge = args.network_merge_n_models
else: else:
network_merge = 0 network_merge = 0
print(f"network_merge: {network_merge}") logger.info(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) logger.info(f"import network module: {network_module}")
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -1568,7 +1572,7 @@ def main(args):
raise ValueError("No weight. Weight is required.") raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i] network_weight = args.network_weights[i]
print("load network weights from:", network_weight) logger.info(f"load network weights from: {network_weight}")
if model_util.is_safetensors(network_weight) and args.network_show_meta: if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open from safetensors.torch import safe_open
@@ -1576,7 +1580,7 @@ def main(args):
with safe_open(network_weight, framework="pt") as f: with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
if metadata is not None: if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}") logger.info(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights( network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
@@ -1586,20 +1590,20 @@ def main(args):
mergeable = network.is_mergeable() mergeable = network.is_mergeable()
if network_merge and not mergeable: if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.") logger.warning("network is not mergiable. ignore merge option.")
if not mergeable or i >= network_merge: if not mergeable or i >= network_merge:
# not merging # not merging
network.apply_to([text_encoder1, text_encoder2], unet) network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}") logger.info(f"weights are loaded: {info}")
if args.opt_channels_last: if args.opt_channels_last:
network.to(memory_format=torch.channels_last) network.to(memory_format=torch.channels_last)
network.to(dtype).to(device) network.to(dtype).to(device)
if network_pre_calc: if network_pre_calc:
print("backup original weights") logger.info("backup original weights")
network.backup_weights() network.backup_weights()
networks.append(network) networks.append(network)
@@ -1613,7 +1617,7 @@ def main(args):
# upscalerの指定があれば取得する # upscalerの指定があれば取得する
upscaler = None upscaler = None
if args.highres_fix_upscaler: if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler) logger.info(f"import upscaler module: {args.highres_fix_upscaler}")
imported_module = importlib.import_module(args.highres_fix_upscaler) imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {} us_kwargs = {}
@@ -1622,7 +1626,7 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
us_kwargs[key] = value us_kwargs[key] = value
print("create upscaler") logger.info("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs) upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device) upscaler.to(dtype).to(device)
@@ -1639,7 +1643,7 @@ def main(args):
# control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
if args.control_net_lllite_models: if args.control_net_lllite_models:
for i, model_file in enumerate(args.control_net_lllite_models): for i, model_file in enumerate(args.control_net_lllite_models):
print(f"loading ControlNet-LLLite: {model_file}") logger.info(f"loading ControlNet-LLLite: {model_file}")
from safetensors.torch import load_file from safetensors.torch import load_file
@@ -1670,7 +1674,7 @@ def main(args):
control_nets.append((control_net, ratio)) control_nets.append((control_net, ratio))
if args.opt_channels_last: if args.opt_channels_last:
print(f"set optimizing: channels last") logger.info(f"set optimizing: channels last")
text_encoder1.to(memory_format=torch.channels_last) text_encoder1.to(memory_format=torch.channels_last)
text_encoder2.to(memory_format=torch.channels_last) text_encoder2.to(memory_format=torch.channels_last)
vae.to(memory_format=torch.channels_last) vae.to(memory_format=torch.channels_last)
@@ -1694,7 +1698,7 @@ def main(args):
args.clip_skip, args.clip_skip,
) )
pipe.set_control_nets(control_nets) pipe.set_control_nets(control_nets)
print("pipeline is ready.") logger.info("pipeline is ready.")
if args.diffusers_xformers: if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention() pipe.enable_xformers_memory_efficient_attention()
@@ -1736,7 +1740,7 @@ def main(args):
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
assert ( assert (
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
), f"token ids1 is not ordered" ), f"token ids1 is not ordered"
@@ -1766,7 +1770,7 @@ def main(args):
# promptを取得する # promptを取得する
if args.from_file is not None: if args.from_file is not None:
print(f"reading prompts from {args.from_file}") logger.info(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f: with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines() prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0] prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
@@ -1795,7 +1799,7 @@ def main(args):
for p in paths: for p in paths:
image = Image.open(p) image = Image.open(p)
if image.mode != "RGB": if image.mode != "RGB":
print(f"convert image to RGB from {image.mode}: {p}") logger.info(f"convert image to RGB from {image.mode}: {p}")
image = image.convert("RGB") image = image.convert("RGB")
images.append(image) images.append(image)
@@ -1811,14 +1815,14 @@ def main(args):
return resized return resized
if args.image_path is not None: if args.image_path is not None:
print(f"load image for img2img: {args.image_path}") logger.info(f"load image for img2img: {args.image_path}")
init_images = load_images(args.image_path) init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
print(f"loaded {len(init_images)} images for img2img") logger.info(f"loaded {len(init_images)} images for img2img")
# CLIP Vision # CLIP Vision
if args.clip_vision_strength is not None: if args.clip_vision_strength is not None:
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
vision_model.to(device, dtype) vision_model.to(device, dtype)
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
@@ -1826,22 +1830,22 @@ def main(args):
pipe.clip_vision_model = vision_model pipe.clip_vision_model = vision_model
pipe.clip_vision_processor = processor pipe.clip_vision_processor = processor
pipe.clip_vision_strength = args.clip_vision_strength pipe.clip_vision_strength = args.clip_vision_strength
print(f"CLIP Vision model loaded.") logger.info(f"CLIP Vision model loaded.")
else: else:
init_images = None init_images = None
if args.mask_path is not None: if args.mask_path is not None:
print(f"load mask for inpainting: {args.mask_path}") logger.info(f"load mask for inpainting: {args.mask_path}")
mask_images = load_images(args.mask_path) mask_images = load_images(args.mask_path)
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
print(f"loaded {len(mask_images)} mask images for inpainting") logger.info(f"loaded {len(mask_images)} mask images for inpainting")
else: else:
mask_images = None mask_images = None
# promptがないとき、画像のPngInfoから取得する # promptがないとき、画像のPngInfoから取得する
if init_images is not None and len(prompt_list) == 0 and not args.interactive: if init_images is not None and len(prompt_list) == 0 and not args.interactive:
print("get prompts from images' metadata") logger.info("get prompts from images' metadata")
for img in init_images: for img in init_images:
if "prompt" in img.text: if "prompt" in img.text:
prompt = img.text["prompt"] prompt = img.text["prompt"]
@@ -1870,17 +1874,17 @@ def main(args):
h = int(h * args.highres_fix_scale + 0.5) h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None: if init_images is not None:
print(f"resize img2img source images to {w}*{h}") logger.info(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h)) init_images = resize_images(init_images, (w, h))
if mask_images is not None: if mask_images is not None:
print(f"resize img2img mask images to {w}*{h}") logger.info(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h)) mask_images = resize_images(mask_images, (w, h))
regional_network = False regional_network = False
if networks and mask_images: if networks and mask_images:
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True regional_network = True
print("use mask as region") logger.info("use mask as region")
size = None size = None
for i, network in enumerate(networks): for i, network in enumerate(networks):
@@ -1905,14 +1909,14 @@ def main(args):
prev_image = None # for VGG16 guided prev_image = None # for VGG16 guided
if args.guide_image_path is not None: if args.guide_image_path is not None:
print(f"load image for ControlNet guidance: {args.guide_image_path}") logger.info(f"load image for ControlNet guidance: {args.guide_image_path}")
guide_images = [] guide_images = []
for p in args.guide_image_path: for p in args.guide_image_path:
guide_images.extend(load_images(p)) guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance") logger.info(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0: if len(guide_images) == 0:
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") logger.warning(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
guide_images = None guide_images = None
else: else:
guide_images = None guide_images = None
@@ -1938,7 +1942,7 @@ def main(args):
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for gen_iter in range(args.n_iter): for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}") logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF) iter_seed = random.randint(0, 0x7FFFFFFF)
# バッチ処理の関数 # バッチ処理の関数
@@ -1950,7 +1954,7 @@ def main(args):
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す # 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage") logger.info("process 1st stage")
batch_1st = [] batch_1st = []
for _, base, ext in batch: for _, base, ext in batch:
@@ -1995,7 +1999,7 @@ def main(args):
images_1st = process_batch(batch_1st, True, True) images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する # 2nd stageのバッチを作成して以下処理する
print("process 2nd stage") logger.info("process 2nd stage")
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler: if upscaler:
@@ -2161,7 +2165,7 @@ def main(args):
n.restore_weights() n.restore_weights()
for n in networks: for n in networks:
n.pre_calculation() n.pre_calculation()
print("pre-calculation... done") logger.info("pre-calculation... done")
images = pipe( images = pipe(
prompts, prompts,
@@ -2240,7 +2244,7 @@ def main(args):
cv2.waitKey() cv2.waitKey()
cv2.destroyAllWindows() cv2.destroyAllWindows()
except ImportError: except ImportError:
print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") logger.error("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
return images return images
@@ -2253,7 +2257,8 @@ def main(args):
# interactive # interactive
valid = False valid = False
while not valid: while not valid:
print("\nType prompt:") logger.info("")
logger.info("Type prompt:")
try: try:
raw_prompt = input() raw_prompt = input()
except EOFError: except EOFError:
@@ -2302,74 +2307,74 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --") prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0] prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
for parg in prompt_args[1:]: for parg in prompt_args[1:]:
try: try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE) m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m: if m:
width = int(m.group(1)) width = int(m.group(1))
print(f"width: {width}") logger.info(f"width: {width}")
continue continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE) m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m: if m:
height = int(m.group(1)) height = int(m.group(1))
print(f"height: {height}") logger.info(f"height: {height}")
continue continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE) m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m: if m:
original_width = int(m.group(1)) original_width = int(m.group(1))
print(f"original width: {original_width}") logger.info(f"original width: {original_width}")
continue continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE) m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m: if m:
original_height = int(m.group(1)) original_height = int(m.group(1))
print(f"original height: {original_height}") logger.info(f"original height: {original_height}")
continue continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE) m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m: if m:
original_width_negative = int(m.group(1)) original_width_negative = int(m.group(1))
print(f"original width negative: {original_width_negative}") logger.info(f"original width negative: {original_width_negative}")
continue continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE) m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m: if m:
original_height_negative = int(m.group(1)) original_height_negative = int(m.group(1))
print(f"original height negative: {original_height_negative}") logger.info(f"original height negative: {original_height_negative}")
continue continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE) m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m: if m:
crop_top = int(m.group(1)) crop_top = int(m.group(1))
print(f"crop top: {crop_top}") logger.info(f"crop top: {crop_top}")
continue continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE) m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m: if m:
crop_left = int(m.group(1)) crop_left = int(m.group(1))
print(f"crop left: {crop_left}") logger.info(f"crop left: {crop_left}")
continue continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE) m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps if m: # steps
steps = max(1, min(1000, int(m.group(1)))) steps = max(1, min(1000, int(m.group(1))))
print(f"steps: {steps}") logger.info(f"steps: {steps}")
continue continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed if m: # seed
seeds = [int(d) for d in m.group(1).split(",")] seeds = [int(d) for d in m.group(1).split(",")]
print(f"seeds: {seeds}") logger.info(f"seeds: {seeds}")
continue continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale if m: # scale
scale = float(m.group(1)) scale = float(m.group(1))
print(f"scale: {scale}") logger.info(f"scale: {scale}")
continue continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
@@ -2378,25 +2383,25 @@ def main(args):
negative_scale = None negative_scale = None
else: else:
negative_scale = float(m.group(1)) negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}") logger.info(f"negative scale: {negative_scale}")
continue continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength if m: # strength
strength = float(m.group(1)) strength = float(m.group(1))
print(f"strength: {strength}") logger.info(f"strength: {strength}")
continue continue
m = re.match(r"n (.+)", parg, re.IGNORECASE) m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt if m: # negative prompt
negative_prompt = m.group(1) negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}") logger.info(f"negative prompt: {negative_prompt}")
continue continue
m = re.match(r"c (.+)", parg, re.IGNORECASE) m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt if m: # clip prompt
clip_prompt = m.group(1) clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}") logger.info(f"clip prompt: {clip_prompt}")
continue continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -2404,47 +2409,47 @@ def main(args):
network_muls = [float(v) for v in m.group(1).split(",")] network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks): while len(network_muls) < len(networks):
network_muls.append(network_muls[-1]) network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}") logger.info(f"network mul: {network_muls}")
continue continue
# Deep Shrink # Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1 if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1)) ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}") logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1 if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1)) ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}") logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2 if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1)) ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}") logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2 if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1)) ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}") logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio if m: # deep shrink ratio
ds_ratio = float(m.group(1)) ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}") logger.info(f"deep shrink ratio: {ds_ratio}")
continue continue
except ValueError as ex: except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"Exception in parsing / 解析エラー: {parg}")
print(ex) logger.error(f"{ex}")
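For reference, the " --key value" options handled in this loop can be exercised on their own. A self-contained sketch reusing a few of the regexes shown above; the prompt line itself is made up:

    import logging
    import re

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger = logging.getLogger(__name__)

    raw_prompt = "a castle on a hill --w 1024 --h 768 --s 40 --d 42,43 --l 7.0 --n blurry"
    prompt, *pargs = raw_prompt.strip().split(" --")
    logger.info(f"prompt: {prompt}")
    for parg in pargs:
        if m := re.match(r"w (\d+)", parg, re.IGNORECASE):
            logger.info(f"width: {int(m.group(1))}")
        elif m := re.match(r"h (\d+)", parg, re.IGNORECASE):
            logger.info(f"height: {int(m.group(1))}")
        elif m := re.match(r"s (\d+)", parg, re.IGNORECASE):
            logger.info(f"steps: {max(1, min(1000, int(m.group(1))))}")
        elif m := re.match(r"d ([\d,]+)", parg, re.IGNORECASE):
            logger.info(f"seeds: {[int(d) for d in m.group(1).split(',')]}")
        elif m := re.match(r"l ([\d\.]+)", parg, re.IGNORECASE):
            logger.info(f"scale: {float(m.group(1))}")
        elif m := re.match(r"n (.+)", parg, re.IGNORECASE):
            logger.info(f"negative prompt: {m.group(1)}")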
# override Deep Shrink # override Deep Shrink
if ds_depth_1 is not None: if ds_depth_1 is not None:
@@ -2462,7 +2467,7 @@ def main(args):
if len(predefined_seeds) > 0: if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0) seed = predefined_seeds.pop(0)
else: else:
print("predefined seeds are exhausted") logger.error("predefined seeds are exhausted")
seed = None seed = None
elif args.iter_same_seed: elif args.iter_same_seed:
seeds = iter_seed seeds = iter_seed
@@ -2472,7 +2477,7 @@ def main(args):
if seed is None: if seed is None:
seed = random.randint(0, 0x7FFFFFFF) seed = random.randint(0, 0x7FFFFFFF)
if args.interactive: if args.interactive:
print(f"seed: {seed}") logger.info(f"seed: {seed}")
# prepare init image, guide image and mask # prepare init image, guide image and mask
init_image = mask_image = guide_image = None init_image = mask_image = guide_image = None
@@ -2488,7 +2493,7 @@ def main(args):
width = width - width % 32 width = width - width % 32
height = height - height % 32 height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]: if width != init_image.size[0] or height != init_image.size[1]:
print( logger.warning(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
) )
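The warning fires because width and height are floored to multiples of 32 independently, which can nudge the aspect ratio. A worked example with illustrative sizes:

    # A 1000x900 img2img source is floored to 992x896, slightly changing the
    # aspect ratio, hence the warning above.
    width, height = 1000, 900
    width, height = width - width % 32, height - height % 32
    assert (width, height) == (992, 896)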
@@ -2548,7 +2553,7 @@ def main(args):
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
print("done!") logger.info("done!")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
View File
@@ -23,6 +23,10 @@ from safetensors.torch import load_file
from library import model_util, sdxl_model_util from library import model_util, sdxl_model_util
import networks.lora as lora import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい # scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2 # scheduler: The settings around here seem to be the same as SD1/2
@@ -140,7 +144,7 @@ if __name__ == "__main__":
vae_dtype = DTYPE vae_dtype = DTYPE
if DTYPE == torch.float16: if DTYPE == torch.float16:
print("use float32 for vae") logger.info("use float32 for vae")
vae_dtype = torch.float32 vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype) vae.to(DEVICE, dtype=vae_dtype)
vae.eval() vae.eval()
@@ -187,7 +191,7 @@ if __name__ == "__main__":
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256) emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape) # logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
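The three get_timestep_embedding calls build SDXL's size conditioning. A quick dimension check under the usual SDXL layout (a sketch, not the script's code):

    # 3 pairs x 2 values x 256-dim sinusoidal embedding = 1536 dims for c_vector;
    # concatenating the 1280-dim pooled text embedding later (see the torch.cat
    # below) gives 2816, the width SDXL's add_embedding expects.
    size_cond_dims = 3 * 2 * 256
    assert size_cond_dims + 1280 == 2816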
@@ -217,7 +221,7 @@ if __name__ == "__main__":
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2] text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape) # logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish # 連結して終了 concat and finish
@@ -226,7 +230,7 @@ if __name__ == "__main__":
# cond # cond
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2) c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape) # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1) c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond # uncond
@@ -323,4 +327,4 @@ if __name__ == "__main__":
seed = int(seed) seed = int(seed)
generate_image(prompt, prompt2, negative_prompt, seed) generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!") logger.info("Done!")
View File
@@ -20,6 +20,10 @@ from diffusers import DDPMScheduler
from library import sdxl_model_util from library import sdxl_model_util
import library.train_util as train_util import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util import library.config_util as config_util
import library.sdxl_train_util as sdxl_train_util import library.sdxl_train_util as sdxl_train_util
from library.config_util import ( from library.config_util import (
@@ -117,18 +121,18 @@ def train(args):
if args.dataset_class is None: if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None: if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}") logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"] ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored): if any(getattr(args, attr) is not None for attr in ignored):
print( logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored) ", ".join(ignored)
) )
) )
else: else:
if use_dreambooth_method: if use_dreambooth_method:
print("Using DreamBooth method.") logger.info("Using DreamBooth method.")
user_config = { user_config = {
"datasets": [ "datasets": [
{ {
@@ -139,7 +143,7 @@ def train(args):
] ]
} }
else: else:
print("Training with captions.") logger.info("Training with captions.")
user_config = { user_config = {
"datasets": [ "datasets": [
{ {
@@ -169,7 +173,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, True) train_util.debug_dataset(train_dataset_group, True)
return return
if len(train_dataset_group) == 0: if len(train_dataset_group) == 0:
print( logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
) )
return return
@@ -185,7 +189,7 @@ def train(args):
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
@@ -537,7 +541,7 @@ def train(args):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified") # logger.info("text encoder outputs verified")
# get size embeddings # get size embeddings
orig_size = batch["original_sizes_hw"] orig_size = batch["original_sizes_hw"]
@@ -724,7 +728,7 @@ def train(args):
logit_scale, logit_scale,
ckpt_info, ckpt_info,
) )
print("model saved.") logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
View File
@@ -45,7 +45,10 @@ from library.custom_train_functions import (
apply_debiased_estimation, apply_debiased_estimation,
) )
import networks.control_net_lllite_for_train as control_net_lllite_for_train import networks.control_net_lllite_for_train as control_net_lllite_for_train
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
@@ -78,11 +81,11 @@ def train(args):
# データセットを準備する # データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config: if use_user_config:
print(f"Load dataset config from {args.dataset_config}") logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"] ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored): if any(getattr(args, attr) is not None for attr in ignored):
print( logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored) ", ".join(ignored)
) )
@@ -114,7 +117,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
if len(train_dataset_group) == 0: if len(train_dataset_group) == 0:
print( logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります" "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
) )
return return
@@ -124,7 +127,7 @@ def train(args):
train_dataset_group.is_latent_cacheable() train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else: else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
assert ( assert (
@@ -132,7 +135,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
@@ -231,8 +234,8 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.") accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(unet.prepare_params()) trainable_params = list(unet.prepare_params())
print(f"trainable params count: {len(trainable_params)}") logger.info(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params) _, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -324,7 +327,7 @@ def train(args):
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -548,7 +551,7 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True) save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.") logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
View File
@@ -41,7 +41,10 @@ from library.custom_train_functions import (
apply_debiased_estimation, apply_debiased_estimation,
) )
import networks.control_net_lllite as control_net_lllite import networks.control_net_lllite as control_net_lllite
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する # TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
@@ -74,11 +77,11 @@ def train(args):
# データセットを準備する # データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config: if use_user_config:
print(f"Load dataset config from {args.dataset_config}") logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"] ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored): if any(getattr(args, attr) is not None for attr in ignored):
print( logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored) ", ".join(ignored)
) )
@@ -110,7 +113,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group) train_util.debug_dataset(train_dataset_group)
return return
if len(train_dataset_group) == 0: if len(train_dataset_group) == 0:
print( logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります" "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
) )
return return
@@ -120,7 +123,7 @@ def train(args):
train_dataset_group.is_latent_cacheable() train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else: else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
assert ( assert (
@@ -128,7 +131,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
@@ -199,8 +202,8 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.") accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(network.prepare_optimizer_params()) trainable_params = list(network.prepare_optimizer_params())
print(f"trainable params count: {len(trainable_params)}") logger.info(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params) _, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -297,7 +300,7 @@ def train(args):
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -516,7 +519,7 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.") logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
View File
@@ -7,7 +7,10 @@ init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util from library import sdxl_model_util, sdxl_train_util, train_util
import train_network import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class SdxlNetworkTrainer(train_network.NetworkTrainer): class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self): def __init__(self):
@@ -60,7 +63,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
if not args.lowram: if not args.lowram:
# メモリ消費を減らす # メモリ消費を減らす
print("move vae and unet to cpu to save memory") logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device org_vae_device = vae.device
org_unet_device = unet.device org_unet_device = unet.device
vae.to("cpu") vae.to("cpu")
@@ -85,7 +88,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
if not args.lowram: if not args.lowram:
print("move vae and unet back to original device") logger.info("move vae and unet back to original device")
vae.to(org_vae_device) vae.to(org_vae_device)
unet.to(org_unet_device) unet.to(org_unet_device)
else: else:
@@ -143,7 +146,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2 # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified") # logger.info("text encoder outputs verified")
return encoder_hidden_states1, encoder_hidden_states2, pool2 return encoder_hidden_states1, encoder_hidden_states2, pool2
View File
@@ -16,7 +16,10 @@ from library.config_util import (
     ConfigSanitizer,
     BlueprintGenerator,
 )
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 def cache_to_disk(args: argparse.Namespace) -> None:
     train_util.prepare_dataset_args(args, True)
@@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     if args.dataset_class is None:
         blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
-                print("Using DreamBooth method.")
+                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
@@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
                    ]
                }
            else:
-                print("Training with captions.")
+                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
@@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)
     # mixed precisionに対応した型を用意しておき適宜castする
@@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
     # モデルを読み込む
-    print("load model")
+    logger.info("load model")
     if args.sdxl:
         (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
     else:
@@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
             if args.skip_existing:
                 if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
-                    print(f"Skipping {image_info.latents_npz} because it already exists.")
+                    logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
                     continue
             image_infos.append(image_info)
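Across these hunks the level chosen for each replaced print follows the severity of the message. Roughly, as a self-contained illustration (the path is a placeholder, not a value from this commit):

    import logging

    logger = logging.getLogger(__name__)
    latents_npz = "example.npz"  # placeholder path for illustration

    logger.info("prepare accelerator")  # progress and status messages
    logger.warning(f"Skipping {latents_npz} because it already exists.")  # recoverable, worth attention
    logger.error("No data found. Please verify the metadata file and train_data_dir option.")  # fatal, followed by return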

View File

@@ -16,7 +16,10 @@ from library.config_util import (
     ConfigSanitizer,
     BlueprintGenerator,
 )
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 def cache_to_disk(args: argparse.Namespace) -> None:
     train_util.prepare_dataset_args(args, True)
@@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     if args.dataset_class is None:
         blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
-                print("Using DreamBooth method.")
+                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
@@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
                    ]
                }
            else:
-                print("Training with captions.")
+                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
@@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)
     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, _ = train_util.prepare_dtype(args)
     # モデルを読み込む
-    print("load model")
+    logger.info("load model")
     if args.sdxl:
         (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
         text_encoders = [text_encoder1, text_encoder2]
@@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
             if args.skip_existing:
                 if os.path.exists(image_info.text_encoder_outputs_npz):
-                    print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
+                    logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
                     continue
             image_info.input_ids1 = input_ids1

View File

@@ -1,6 +1,10 @@
 import argparse
 import cv2
+import logging
+from library.utils import setup_logging
+setup_logging()
+logger = logging.getLogger(__name__)
 def canny(args):
     img = cv2.imread(args.input)
@@ -10,7 +14,7 @@ def canny(args):
     # canny_img = 255 - canny_img
     cv2.imwrite(args.output, canny_img)
-    print("done!")
+    logger.info("done!")
 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -6,7 +6,10 @@ import torch
 from diffusers import StableDiffusionPipeline
 import library.model_util as model_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 def convert(args):
     # 引数を確認する
@@ -30,7 +33,7 @@ def convert(args):
     # モデルを読み込む
     msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
-    print(f"loading {msg}: {args.model_to_load}")
+    logger.info(f"loading {msg}: {args.model_to_load}")
     if is_load_ckpt:
         v2_model = args.v2
@@ -48,13 +51,13 @@ def convert(args):
         if args.v1 == args.v2:
             # 自動判定する
             v2_model = unet.config.cross_attention_dim == 1024
-            print("checking model version: model is " + ("v2" if v2_model else "v1"))
+            logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
         else:
             v2_model = not args.v1
     # 変換して保存する
     msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
-    print(f"converting and saving as {msg}: {args.model_to_save}")
+    logger.info(f"converting and saving as {msg}: {args.model_to_save}")
     if is_save_ckpt:
         original_model = args.model_to_load if is_load_ckpt else None
@@ -70,15 +73,15 @@ def convert(args):
             save_dtype=save_dtype,
             vae=vae,
         )
-        print(f"model saved. total converted state_dict keys: {key_count}")
+        logger.info(f"model saved. total converted state_dict keys: {key_count}")
     else:
-        print(
+        logger.info(
             f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
         )
         model_util.save_diffusers_checkpoint(
             v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
         )
-        print("model saved.")
+        logger.info("model saved.")
 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -15,6 +15,10 @@ import os
 from anime_face_detector import create_detector
 from tqdm import tqdm
 import numpy as np
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 KP_REYE = 11
 KP_LEYE = 19
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
 def detect_faces(detector, image, min_size):
     preds = detector(image)  # bgr
-    # print(len(preds))
+    # logger.info(len(preds))
     faces = []
     for pred in preds:
@@ -78,7 +82,7 @@ def process(args):
     assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
     # アニメ顔検出モデルを読み込む
-    print("loading face detector.")
+    logger.info("loading face detector.")
     detector = create_detector('yolov3')
     # cropの引数を解析する
@@ -97,7 +101,7 @@ def process(args):
         crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
     # 画像を処理する
-    print("processing.")
+    logger.info("processing.")
     output_extension = ".png"
     os.makedirs(args.dst_dir, exist_ok=True)
@@ -111,7 +115,7 @@ def process(args):
         if len(image.shape) == 2:
             image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
         if image.shape[2] == 4:
-            print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
+            logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
            image = image[:, :, :3].copy()  # copyをしないと内部的に透明度情報が付いたままになるらしい
        h, w = image.shape[:2]
@@ -144,11 +148,11 @@ def process(args):
             # 顔サイズを基準にリサイズする
             scale = args.resize_face_size / face_size
             if scale < cur_crop_width / w:
-                print(
+                logger.warning(
                     f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
                 scale = cur_crop_width / w
             if scale < cur_crop_height / h:
-                print(
+                logger.warning(
                     f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
                 scale = cur_crop_height / h
         elif crop_h_ratio is not None:
@@ -157,10 +161,10 @@ def process(args):
         else:
             # 切り出しサイズ指定あり
             if w < cur_crop_width:
-                print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
+                logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
                 scale = cur_crop_width / w
             if h < cur_crop_height:
-                print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
+                logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
                 scale = cur_crop_height / h
         if args.resize_fit:
             scale = max(cur_crop_width / w, cur_crop_height / h)
@@ -198,7 +202,7 @@ def process(args):
             face_img = face_img[y:y + cur_crop_height]
         # # debug
-        # print(path, cx, cy, angle)
+        # logger.info(path, cx, cy, angle)
         # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
         # cv2.imshow("image", crp)
         # if cv2.waitKey() == 27:

View File

@@ -14,7 +14,10 @@ import torch
 from torch import nn
 from tqdm import tqdm
 from PIL import Image
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 class ResidualBlock(nn.Module):
     def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -216,7 +219,7 @@ class Upscaler(nn.Module):
         upsampled_images = upsampled_images / 127.5 - 1.0
         # convert upsample images to latents with batch size
-        # print("Encoding upsampled (LANCZOS4) images...")
+        # logger.info("Encoding upsampled (LANCZOS4) images...")
         upsampled_latents = []
         for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
             batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
@@ -227,7 +230,7 @@ class Upscaler(nn.Module):
         upsampled_latents = torch.cat(upsampled_latents, dim=0)
         # upscale (refine) latents with this model with batch size
-        print("Upscaling latents...")
+        logger.info("Upscaling latents...")
         upscaled_latents = []
         for i in range(0, upsampled_latents.shape[0], batch_size):
             with torch.no_grad():
@@ -242,7 +245,7 @@ def create_upscaler(**kwargs):
     weights = kwargs["weights"]
     model = Upscaler()
-    print(f"Loading weights from {weights}...")
+    logger.info(f"Loading weights from {weights}...")
     if os.path.splitext(weights)[1] == ".safetensors":
         from safetensors.torch import load_file
@@ -261,14 +264,14 @@ def upscale_images(args: argparse.Namespace):
     # load VAE with Diffusers
     assert args.vae_path is not None, "VAE path is required"
-    print(f"Loading VAE from {args.vae_path}...")
+    logger.info(f"Loading VAE from {args.vae_path}...")
     vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
     vae.to(DEVICE, dtype=us_dtype)
     # prepare model
-    print("Preparing model...")
+    logger.info("Preparing model...")
     upscaler: Upscaler = create_upscaler(weights=args.weights)
-    # print("Loading weights from", args.weights)
+    # logger.info("Loading weights from", args.weights)
     # upscaler.load_state_dict(torch.load(args.weights))
     upscaler.eval()
     upscaler.to(DEVICE, dtype=us_dtype)
@@ -303,14 +306,14 @@ def upscale_images(args: argparse.Namespace):
         image_debug.save(dest_file_name)
     # upscale
-    print("Upscaling...")
+    logger.info("Upscaling...")
     upscaled_latents = upscaler.upscale(
         vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
     )
     upscaled_latents /= 0.18215
     # decode with batch
-    print("Decoding...")
+    logger.info("Decoding...")
     upscaled_images = []
     for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
         with torch.no_grad():

View File

@@ -5,7 +5,10 @@ import torch
 from safetensors import safe_open
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 def is_unet_key(key):
     # VAE or TextEncoder, the last one is for SDXL
@@ -45,10 +48,10 @@ def merge(args):
     # check if all models are safetensors
     for model in args.models:
         if not model.endswith("safetensors"):
-            print(f"Model {model} is not a safetensors model")
+            logger.error(f"Model {model} is not a safetensors model")
             exit()
         if not os.path.isfile(model):
-            print(f"Model {model} does not exist")
+            logger.error(f"Model {model} does not exist")
             exit()
     assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
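Because both messages are immediately followed by exit(), error is the fitting level (the commit's stated goal includes fixing log levels). A common fail-fast shape for this kind of validation, as an illustrative sketch only (check_model_path is not a function in this codebase):

    import logging
    import os
    import sys

    logger = logging.getLogger(__name__)

    def check_model_path(model: str) -> None:
        # Validate inputs up front and abort with an error-level record.
        if not model.endswith(".safetensors"):
            logger.error(f"Model {model} is not a safetensors model")
            sys.exit(1)
        if not os.path.isfile(model):
            logger.error(f"Model {model} does not exist")
            sys.exit(1)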
@@ -65,7 +68,7 @@ def merge(args):
         if merged_sd is None:
             # load first model
-            print(f"Loading model {model}, ratio = {ratio}...")
+            logger.info(f"Loading model {model}, ratio = {ratio}...")
             merged_sd = {}
             with safe_open(model, framework="pt", device=args.device) as f:
                 for key in tqdm(f.keys()):
@@ -81,11 +84,11 @@ def merge(args):
                     value = ratio * value.to(dtype)  # first model's value * ratio
                     merged_sd[key] = value
-            print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
+            logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
             continue
         # load other models
-        print(f"Loading model {model}, ratio = {ratio}...")
+        logger.info(f"Loading model {model}, ratio = {ratio}...")
         with safe_open(model, framework="pt", device=args.device) as f:
             model_keys = f.keys()
@@ -93,7 +96,7 @@ def merge(args):
                 _, new_key = replace_text_encoder_key(key)
                 if new_key not in merged_sd:
                     if args.show_skipped and new_key not in first_model_keys:
-                        print(f"Skip: {new_key}")
+                        logger.info(f"Skip: {new_key}")
                     continue
                 value = f.get_tensor(key)
@@ -104,7 +107,7 @@ def merge(args):
             for key in merged_sd.keys():
                 if key in model_keys:
                     continue
-                print(f"Key {key} not in model {model}, use first model's value")
+                logger.warning(f"Key {key} not in model {model}, use first model's value")
                 if key in supplementary_key_ratios:
                     supplementary_key_ratios[key] += ratio
                 else:
@@ -112,7 +115,7 @@ def merge(args):
     # add supplementary keys' value (including VAE and TextEncoder)
     if len(supplementary_key_ratios) > 0:
-        print("add first model's value")
+        logger.info("add first model's value")
         with safe_open(args.models[0], framework="pt", device=args.device) as f:
             for key in tqdm(f.keys()):
                 _, new_key = replace_text_encoder_key(key)
@@ -120,7 +123,7 @@ def merge(args):
                     continue
                 if is_unet_key(new_key):  # not VAE or TextEncoder
-                    print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
+                    logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
                 value = f.get_tensor(key)  # original key
@@ -134,7 +137,7 @@ def merge(args):
     if not output_file.endswith(".safetensors"):
         output_file = output_file + ".safetensors"
-    print(f"Saving to {output_file}...")
+    logger.info(f"Saving to {output_file}...")
     # convert to save_dtype
     for k in merged_sd.keys():
@@ -142,7 +145,7 @@ def merge(args):
     save_file(merged_sd, output_file)
-    print("Done!")
+    logger.info("Done!")
 if __name__ == "__main__":

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file
 from library.original_unet import UNet2DConditionModel, SampleOutput
 import library.model_util as model_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 class ControlNetInfo(NamedTuple):
     unet: Any
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
     # control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
     # state dictを読み込む
-    print(f"ControlNet: loading control SD model : {model}")
+    logger.info(f"ControlNet: loading control SD model : {model}")
     if model_util.is_safetensors(model):
         ctrl_sd_sd = load_file(model)
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
     # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
     is_difference = "difference" in ctrl_sd_sd
-    print("ControlNet: loading difference:", is_difference)
+    logger.info(f"ControlNet: loading difference: {is_difference}")
     # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
     # またTransfer Controlの元weightとなる
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
     # ControlNetのU-Netを作成する
     ctrl_unet = UNet2DConditionModel(**unet_config)
     info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
-    print("ControlNet: loading Control U-Net:", info)
+    logger.info(f"ControlNet: loading Control U-Net: {info}")
     # U-Net以外のControlNetを作成する
     # TODO support middle only
     ctrl_net = ControlNet()
     info = ctrl_net.load_state_dict(zero_conv_sd)
-    print("ControlNet: loading ControlNet:", info)
+    logger.info(f"ControlNet: loading ControlNet: {info}")
     ctrl_unet.to(unet.device, dtype=unet.dtype)
     ctrl_net.to(unet.device, dtype=unet.dtype)
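The second message above was missing its f prefix in the converted line (corrected here), which would have logged the literal text {info}. Logging's %-style arguments sidestep that class of mistake and defer string building until the record is actually emitted; a small illustration (info is a placeholder value):

    import logging

    logger = logging.getLogger(__name__)
    info = "<missing_keys=[], unexpected_keys=[]>"  # placeholder for a load_state_dict() result

    # The argument is interpolated only if this record passes the level filter.
    logger.info("ControlNet: loading Control U-Net: %s", info)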
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
         return canny
-    print("Unsupported prep type:", prep_type)
+    logger.info(f"Unsupported prep type: {prep_type}")
     return None
@@ -174,7 +177,7 @@ def call_unet_and_control_net(
         cnet_idx = step % cnet_cnt
         cnet_info = control_nets[cnet_idx]
-        # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+        # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
         if cnet_info.ratio < current_ratio:
             return original_unet(sample, timestep, encoder_hidden_states)
@@ -192,7 +195,7 @@ def call_unet_and_control_net(
     # ControlNet
     cnet_outs_list = []
     for i, cnet_info in enumerate(control_nets):
-        # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+        # logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
         if cnet_info.ratio < current_ratio:
             continue
         guided_hint = guided_hints[i]
@@ -232,7 +235,7 @@ def unet_forward(
     upsample_size = None
     if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-        print("Forward upsample size to force interpolation output size.")
+        logger.info("Forward upsample size to force interpolation output size.")
         forward_upsample_size = True
     # 1. time

View File

@@ -6,7 +6,10 @@ import shutil
 import math
 from PIL import Image
 import numpy as np
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
     # Split the max_resolution string by "," and strip any whitespaces
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
         image.save(os.path.join(dst_img_folder, new_filename), quality=100)
         proc = "Resized" if current_pixels > max_pixels else "Saved"
-        print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
+        logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
     # If other files with same basename, copy them with resolution suffix
     if copy_associated_files:
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
             continue
         for max_resolution in max_resolutions:
             new_asoc_file = base + '+' + max_resolution + ext
-            print(f"Copy {asoc_file} as {new_asoc_file}")
+            logger.info(f"Copy {asoc_file} as {new_asoc_file}")
             shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))

View File

@@ -1,6 +1,10 @@
 import json
 import argparse
 from safetensors import safe_open
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", type=str, required=True)
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
     metadata = f.metadata()
 if metadata is None:
-    print("No metadata found")
+    logger.error("No metadata found")
 else:
     # metadata is json dict, but not pretty printed
     # sort by key and pretty print
     print(json.dumps(metadata, indent=4, sort_keys=True))
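Note that the final line deliberately stays print: the JSON dump is the tool's actual output (meant for stdout and piping), while only the diagnostic goes through the logger. The split, as a runnable illustration with placeholder metadata:

    import json
    import logging

    logger = logging.getLogger(__name__)
    metadata = {"ss_network_dim": "16"}  # placeholder metadata for illustration

    if metadata is None:
        logger.error("No metadata found")  # diagnostic -> logger
    else:
        print(json.dumps(metadata, indent=4, sort_keys=True))  # payload -> stdout, stays print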

View File

@@ -35,7 +35,10 @@ from library.custom_train_functions import (
     pyramid_noise_like,
     apply_noise_offset,
 )
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 # TODO 他のスクリプトと共通化する
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
@@ -69,11 +72,11 @@ def train(args):
     # データセットを準備する
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
     if use_user_config:
-        print(f"Load dataset config from {args.dataset_config}")
+        logger.info(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "conditioning_data_dir"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print(
+            logger.warning(
                 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
@@ -103,7 +106,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group)
         return
     if len(train_dataset_group) == 0:
-        print(
+        logger.error(
             "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
         )
         return
@@ -114,7 +117,7 @@ def train(args):
     ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)
     is_main_process = accelerator.is_main_process
@@ -310,7 +313,7 @@ def train(args):
     accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
     accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
     accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
-    # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
     accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
     accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -567,7 +570,7 @@ def train(args):
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
         save_model(ckpt_name, controlnet, force_sync_upload=True)
-        print("model saved.")
+        logger.info("model saved.")
 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -35,6 +35,10 @@ from library.custom_train_functions import (
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
 )
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 # perlin_noise,
@@ -54,11 +58,11 @@ def train(args):
     if args.dataset_class is None:
         blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
             ignored = ["train_data_dir", "reg_data_dir"]
             if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                     "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                         ", ".join(ignored)
                     )
@@ -93,13 +97,13 @@ def train(args):
     ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     if args.gradient_accumulation_steps > 1:
-        print(
+        logger.warning(
             f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
         )
-        print(
+        logger.warning(
             f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
         )
@@ -449,7 +453,7 @@ def train(args):
         train_util.save_sd_model_on_train_end(
             args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
         )
-        print("model saved.")
+        logger.info("model saved.")
 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -41,7 +41,10 @@ from library.custom_train_functions import (
     add_v_prediction_like_loss,
     apply_debiased_estimation,
 )
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 class NetworkTrainer:
     def __init__(self):
@@ -153,18 +156,18 @@ class NetworkTrainer:
         if args.dataset_class is None:
             blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
             if use_user_config:
-                print(f"Loading dataset config from {args.dataset_config}")
+                logger.info(f"Loading dataset config from {args.dataset_config}")
                 user_config = config_util.load_user_config(args.dataset_config)
                 ignored = ["train_data_dir", "reg_data_dir", "in_json"]
                 if any(getattr(args, attr) is not None for attr in ignored):
-                    print(
+                    logger.warning(
                         "ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                             ", ".join(ignored)
                         )
                     )
             else:
                 if use_dreambooth_method:
-                    print("Using DreamBooth method.")
+                    logger.info("Using DreamBooth method.")
                     user_config = {
                         "datasets": [
                             {
@@ -175,7 +178,7 @@ class NetworkTrainer:
                         ]
                     }
                 else:
-                    print("Training with captions.")
+                    logger.info("Training with captions.")
                     user_config = {
                         "datasets": [
                             {
@@ -204,7 +207,7 @@ class NetworkTrainer:
             train_util.debug_dataset(train_dataset_group)
             return
         if len(train_dataset_group) == 0:
-            print(
+            logger.error(
                 "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
             )
             return
@@ -217,7 +220,7 @@ class NetworkTrainer:
         self.assert_extra_args(args, train_dataset_group)
         # acceleratorを準備する
-        print("preparing accelerator")
+        logger.info("preparing accelerator")
         accelerator = train_util.prepare_accelerator(args)
         is_main_process = accelerator.is_main_process
@@ -310,7 +313,7 @@ class NetworkTrainer:
         if hasattr(network, "prepare_network"):
             network.prepare_network(args)
         if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
-            print(
+            logger.warning(
                 "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
             )
             args.scale_weight_norms = False
@@ -938,7 +941,7 @@ class NetworkTrainer:
             ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
             save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
-            print("model saved.")
+            logger.info("model saved.")
 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -32,6 +32,10 @@ from library.custom_train_functions import (
     add_v_prediction_like_loss,
     apply_debiased_estimation,
 )
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 imagenet_templates_small = [
     "a photo of a {}",
@@ -178,7 +182,7 @@ class TextualInversionTrainer:
         tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
         # acceleratorを準備する
-        print("prepare accelerator")
+        logger.info("prepare accelerator")
         accelerator = train_util.prepare_accelerator(args)
         # mixed precisionに対応した型を用意しておき適宜castする
@@ -288,7 +292,7 @@ class TextualInversionTrainer:
                 ]
             }
         else:
-            print("Train with captions.")
+            logger.info("Train with captions.")
             user_config = {
                 "datasets": [
                     {
@@ -736,7 +740,7 @@ class TextualInversionTrainer:
             ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
             save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
-            print("model saved.")
+            logger.info("model saved.")
 def setup_parser() -> argparse.ArgumentParser:

View File

@@ -36,6 +36,10 @@ from library.custom_train_functions import (
 )
 import library.original_unet as original_unet
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 imagenet_templates_small = [
     "a photo of a {}",
@@ -99,7 +103,7 @@ def train(args):
     train_util.prepare_dataset_args(args, True)
     if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
-        print(
+        logger.warning(
             "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
         )
     assert (
@@ -114,7 +118,7 @@ def train(args):
     tokenizer = train_util.load_tokenizer(args)
     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)
     # mixed precisionに対応した型を用意しておき適宜castする
@@ -127,7 +131,7 @@ def train(args):
     if args.init_word is not None:
         init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
         if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
-            print(
+            logger.warning(
                 f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
             )
     else:
@@ -141,7 +145,7 @@ def train(args):
     ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
     token_ids = tokenizer.convert_tokens_to_ids(token_strings)
-    print(f"tokens are added: {token_ids}")
+    logger.info(f"tokens are added: {token_ids}")
     assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
     assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
@@ -169,7 +173,7 @@ def train(args):
     tokenizer.add_tokens(token_strings_XTI)
     token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
-    print(f"tokens are added (XTI): {token_ids_XTI}")
+    logger.info(f"tokens are added (XTI): {token_ids_XTI}")
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
@@ -178,7 +182,7 @@ def train(args):
     if init_token_ids is not None:
         for i, token_id in enumerate(token_ids_XTI):
             token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
     # load weights
     if args.weights is not None:
@@ -186,22 +190,22 @@ def train(args):
         assert len(token_ids) == len(
             embeddings
         ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
-        # print(token_ids, embeddings.size())
+        # logger.info(token_ids, embeddings.size())
         for token_id, embedding in zip(token_ids_XTI, embeddings):
             token_embeds[token_id] = embedding
-            # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+            # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
-        print(f"weighs loaded")
+        logger.info("weights loaded")
-    print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
+    logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
     # データセットを準備する
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
     if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
+        logger.info(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)
         ignored = ["train_data_dir", "reg_data_dir", "in_json"]
         if any(getattr(args, attr) is not None for attr in ignored):
-            print(
+            logger.warning(
                 "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                     ", ".join(ignored)
                 )
@@ -209,14 +213,14 @@ def train(args):
     else:
         use_dreambooth_method = args.in_json is None
         if use_dreambooth_method:
-            print("Use DreamBooth method.")
+            logger.info("Use DreamBooth method.")
             user_config = {
                 "datasets": [
                     {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
                 ]
             }
         else:
-            print("Train with captions.")
+            logger.info("Train with captions.")
             user_config = {
                 "datasets": [
                     {
@@ -240,7 +244,7 @@ def train(args):
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
-        print(f"use template for training captions. is object: {args.use_object_template}")
+        logger.info(f"use template for training captions. is object: {args.use_object_template}")
         templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
         replace_to = " ".join(token_strings)
         captions = []
@@ -264,7 +268,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group, show_input_ids=True)
         return
     if len(train_dataset_group) == 0:
-        print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
+        logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
         return
     if cache_latents:
@@ -297,7 +301,7 @@ def train(args):
         text_encoder.gradient_checkpointing_enable()
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    logger.info("prepare optimizer, data loader etc.")
     trainable_params = text_encoder.get_input_embeddings().parameters()
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -318,7 +322,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        logger.info(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -332,7 +336,7 @@ def train(args):
     )
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
-    # print(len(index_no_updates), torch.sum(index_no_updates))
+    # logger.info(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
     # Freeze all parameters except for the token embeddings in text encoder
@@ -370,15 +374,15 @@ def train(args):
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f" num epochs / epoch数: {num_train_epochs}")
-    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    logger.info("running training / 学習開始")
+    logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    logger.info(f" num epochs / epoch数: {num_train_epochs}")
+    logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
+    logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    logger.info(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
@@ -403,7 +407,8 @@ def train(args):
             os.makedirs(args.output_dir, exist_ok=True)
             ckpt_file = os.path.join(args.output_dir, ckpt_name)
-            print(f"\nsaving checkpoint: {ckpt_file}")
+            logger.info("")
+            logger.info(f"saving checkpoint: {ckpt_file}")
             save_weights(ckpt_file, embs, save_dtype)
             if args.huggingface_repo_id is not None:
                 huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -411,12 +416,13 @@ def train(args):
     def remove_model(old_ckpt_name):
         old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
         if os.path.exists(old_ckpt_file):
-            print(f"removing old checkpoint: {old_ckpt_file}")
+            logger.info(f"removing old checkpoint: {old_ckpt_file}")
             os.remove(old_ckpt_file)
     # training loop
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        logger.info("")
+        logger.info(f"epoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
         text_encoder.train()
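The old print calls carried a leading \n to visually separate epochs; a logger would place its prefix before that newline, so the separator is instead emitted as its own empty record. A minimal illustration (placeholder counters):

    import logging

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    logger = logging.getLogger(__name__)
    epoch, num_train_epochs = 0, 10  # placeholders

    logger.info("")  # blank separator record replaces the embedded "\n"
    logger.info(f"epoch {epoch+1}/{num_train_epochs}")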
@@ -586,7 +592,7 @@ def train(args):
             ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
             save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
-            print("model saved.")
+            logger.info("model saved.")
 def save_weights(file, updated_embs, save_dtype):