Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
Replace print with logger if they are logs (#905)
* Add get_my_logger()
* Use logger instead of print
* Fix log level
* Removed line-breaks for readability
* Use setup_logging()
* Add rich to requirements.txt
* Make simple
* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
14
fine_tune.py
14
fine_tune.py
@@ -18,6 +18,10 @@ init_ipex()
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
@@ -49,11 +53,11 @@ def train(args):
|
|||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "in_json"]
|
ignored = ["train_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -86,7 +90,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print(
|
logger.error(
|
||||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -97,7 +101,7 @@ def train(args):
|
|||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
@@ -461,7 +465,7 @@ def train(args):
|
|||||||
train_util.save_sd_model_on_train_end(
|
train_util.save_sd_model_on_train_end(
|
||||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||||
)
|
)
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ import torch.nn.functional as F
|
|||||||
import os
|
import os
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from timm.models.hub import download_cached_file
|
from timm.models.hub import download_cached_file
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BLIP_Base(nn.Module):
|
class BLIP_Base(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename):
|
|||||||
del state_dict[key]
|
del state_dict[key]
|
||||||
|
|
||||||
msg = model.load_state_dict(state_dict,strict=False)
|
msg = model.load_state_dict(state_dict,strict=False)
|
||||||
print('load checkpoint from %s'%url_or_filename)
|
logger.info('load checkpoint from %s'%url_or_filename)
|
||||||
return model,msg
|
return model,msg
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import json
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
||||||
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
||||||
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
|
|||||||
tokens = tags.split(", rating")
|
tokens = tags.split(", rating")
|
||||||
if len(tokens) == 1:
|
if len(tokens) == 1:
|
||||||
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
||||||
# print("no rating:")
|
# logger.info("no rating:")
|
||||||
# print(f"{image_key} {tags}")
|
# logger.info(f"{image_key} {tags}")
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if len(tokens) > 2:
|
if len(tokens) > 2:
|
||||||
print("multiple ratings:")
|
logger.info("multiple ratings:")
|
||||||
print(f"{image_key} {tags}")
|
logger.info(f"{image_key} {tags}")
|
||||||
tags = tokens[0]
|
tags = tokens[0]
|
||||||
|
|
||||||
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
||||||
@@ -124,43 +128,43 @@ def clean_caption(caption):
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
if os.path.exists(args.in_json):
|
if os.path.exists(args.in_json):
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
logger.info(f"loading existing metadata: {args.in_json}")
|
||||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||||
metadata = json.load(f)
|
metadata = json.load(f)
|
||||||
else:
|
else:
|
||||||
print("no metadata / メタデータファイルがありません")
|
logger.error("no metadata / メタデータファイルがありません")
|
||||||
return
|
return
|
||||||
|
|
||||||
print("cleaning captions and tags.")
|
logger.info("cleaning captions and tags.")
|
||||||
image_keys = list(metadata.keys())
|
image_keys = list(metadata.keys())
|
||||||
for image_key in tqdm(image_keys):
|
for image_key in tqdm(image_keys):
|
||||||
tags = metadata[image_key].get('tags')
|
tags = metadata[image_key].get('tags')
|
||||||
if tags is None:
|
if tags is None:
|
||||||
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
||||||
else:
|
else:
|
||||||
org = tags
|
org = tags
|
||||||
tags = clean_tags(image_key, tags)
|
tags = clean_tags(image_key, tags)
|
||||||
metadata[image_key]['tags'] = tags
|
metadata[image_key]['tags'] = tags
|
||||||
if args.debug and org != tags:
|
if args.debug and org != tags:
|
||||||
print("FROM: " + org)
|
logger.info("FROM: " + org)
|
||||||
print("TO: " + tags)
|
logger.info("TO: " + tags)
|
||||||
|
|
||||||
caption = metadata[image_key].get('caption')
|
caption = metadata[image_key].get('caption')
|
||||||
if caption is None:
|
if caption is None:
|
||||||
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
||||||
else:
|
else:
|
||||||
org = caption
|
org = caption
|
||||||
caption = clean_caption(caption)
|
caption = clean_caption(caption)
|
||||||
metadata[image_key]['caption'] = caption
|
metadata[image_key]['caption'] = caption
|
||||||
if args.debug and org != caption:
|
if args.debug and org != caption:
|
||||||
print("FROM: " + org)
|
logger.info("FROM: " + org)
|
||||||
print("TO: " + caption)
|
logger.info("TO: " + caption)
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
logger.info(f"writing metadata: {args.out_json}")
|
||||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||||
json.dump(metadata, f, indent=2)
|
json.dump(metadata, f, indent=2)
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
@@ -178,10 +182,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
if len(unknown) == 1:
|
if len(unknown) == 1:
|
||||||
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
||||||
print("All captions and tags in the metadata are processed.")
|
logger.warning("All captions and tags in the metadata are processed.")
|
||||||
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
||||||
print("メタデータ内のすべてのキャプションとタグが処理されます。")
|
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
|
||||||
args.in_json = args.out_json
|
args.in_json = args.out_json
|
||||||
args.out_json = unknown[0]
|
args.out_json = unknown[0]
|
||||||
elif len(unknown) > 0:
|
elif len(unknown) > 0:
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ from torchvision.transforms.functional import InterpolationMode
|
|||||||
sys.path.append(os.path.dirname(__file__))
|
sys.path.append(os.path.dirname(__file__))
|
||||||
from blip.blip import blip_decoder, is_url
|
from blip.blip import blip_decoder, is_url
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
@@ -47,7 +51,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
|||||||
# convert to tensor temporarily so dataloader will accept it
|
# convert to tensor temporarily so dataloader will accept it
|
||||||
tensor = IMAGE_TRANSFORM(image)
|
tensor = IMAGE_TRANSFORM(image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return (tensor, img_path)
|
return (tensor, img_path)
|
||||||
@@ -74,21 +78,21 @@ def main(args):
|
|||||||
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
||||||
|
|
||||||
cwd = os.getcwd()
|
cwd = os.getcwd()
|
||||||
print("Current Working Directory is: ", cwd)
|
logger.info(f"Current Working Directory is: {cwd}")
|
||||||
os.chdir("finetune")
|
os.chdir("finetune")
|
||||||
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
||||||
args.caption_weights = os.path.join("..", args.caption_weights)
|
args.caption_weights = os.path.join("..", args.caption_weights)
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
logger.info(f"load images from {args.train_data_dir}")
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
logger.info(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
print(f"loading BLIP caption: {args.caption_weights}")
|
logger.info(f"loading BLIP caption: {args.caption_weights}")
|
||||||
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
|
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(DEVICE)
|
model = model.to(DEVICE)
|
||||||
print("BLIP loaded")
|
logger.info("BLIP loaded")
|
||||||
|
|
||||||
# captioningする
|
# captioningする
|
||||||
def run_batch(path_imgs):
|
def run_batch(path_imgs):
|
||||||
@@ -108,7 +112,7 @@ def main(args):
|
|||||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||||
f.write(caption + "\n")
|
f.write(caption + "\n")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_path, caption)
|
logger.info(f'{image_path} {caption}')
|
||||||
|
|
||||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||||
if args.max_data_loader_n_workers is not None:
|
if args.max_data_loader_n_workers is not None:
|
||||||
@@ -138,7 +142,7 @@ def main(args):
|
|||||||
raw_image = raw_image.convert("RGB")
|
raw_image = raw_image.convert("RGB")
|
||||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
b_imgs.append((image_path, img_tensor))
|
b_imgs.append((image_path, img_tensor))
|
||||||
@@ -148,7 +152,7 @@ def main(args):
|
|||||||
if len(b_imgs) > 0:
|
if len(b_imgs) > 0:
|
||||||
run_batch(b_imgs)
|
run_batch(b_imgs)
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ from transformers import AutoProcessor, AutoModelForCausalLM
|
|||||||
from transformers.generation.utils import GenerationMixin
|
from transformers.generation.utils import GenerationMixin
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
@@ -35,8 +38,8 @@ def remove_words(captions, debug):
|
|||||||
for pat in PATTERN_REPLACE:
|
for pat in PATTERN_REPLACE:
|
||||||
cap = pat.sub("", cap)
|
cap = pat.sub("", cap)
|
||||||
if debug and cap != caption:
|
if debug and cap != caption:
|
||||||
print(caption)
|
logger.info(caption)
|
||||||
print(cap)
|
logger.info(cap)
|
||||||
removed_caps.append(cap)
|
removed_caps.append(cap)
|
||||||
return removed_caps
|
return removed_caps
|
||||||
|
|
||||||
@@ -70,16 +73,16 @@ def main(args):
|
|||||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
logger.info(f"load images from {args.train_data_dir}")
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
logger.info(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
# できればcacheに依存せず明示的にダウンロードしたい
|
# できればcacheに依存せず明示的にダウンロードしたい
|
||||||
print(f"loading GIT: {args.model_id}")
|
logger.info(f"loading GIT: {args.model_id}")
|
||||||
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
||||||
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
||||||
print("GIT loaded")
|
logger.info("GIT loaded")
|
||||||
|
|
||||||
# captioningする
|
# captioningする
|
||||||
def run_batch(path_imgs):
|
def run_batch(path_imgs):
|
||||||
@@ -97,7 +100,7 @@ def main(args):
|
|||||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||||
f.write(caption + "\n")
|
f.write(caption + "\n")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_path, caption)
|
logger.info(f"{image_path} {caption}")
|
||||||
|
|
||||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||||
if args.max_data_loader_n_workers is not None:
|
if args.max_data_loader_n_workers is not None:
|
||||||
@@ -126,7 +129,7 @@ def main(args):
|
|||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
b_imgs.append((image_path, image))
|
b_imgs.append((image_path, image))
|
||||||
@@ -137,7 +140,7 @@ def main(args):
|
|||||||
if len(b_imgs) > 0:
|
if len(b_imgs) > 0:
|
||||||
run_batch(b_imgs)
|
run_batch(b_imgs)
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -5,26 +5,30 @@ from typing import List
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import os
|
import os
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||||
|
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
logger.info(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is None and Path(args.out_json).is_file():
|
if args.in_json is None and Path(args.out_json).is_file():
|
||||||
args.in_json = args.out_json
|
args.in_json = args.out_json
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
logger.info(f"loading existing metadata: {args.in_json}")
|
||||||
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||||
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
||||||
else:
|
else:
|
||||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
print("merge caption texts to metadata json.")
|
logger.info("merge caption texts to metadata json.")
|
||||||
for image_path in tqdm(image_paths):
|
for image_path in tqdm(image_paths):
|
||||||
caption_path = image_path.with_suffix(args.caption_extension)
|
caption_path = image_path.with_suffix(args.caption_extension)
|
||||||
caption = caption_path.read_text(encoding='utf-8').strip()
|
caption = caption_path.read_text(encoding='utf-8').strip()
|
||||||
@@ -38,12 +42,12 @@ def main(args):
|
|||||||
|
|
||||||
metadata[image_key]['caption'] = caption
|
metadata[image_key]['caption'] = caption
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_key, caption)
|
logger.info(f"{image_key} {caption}")
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
logger.info(f"writing metadata: {args.out_json}")
|
||||||
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -5,26 +5,30 @@ from typing import List
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import os
|
import os
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||||
|
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
logger.info(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is None and Path(args.out_json).is_file():
|
if args.in_json is None and Path(args.out_json).is_file():
|
||||||
args.in_json = args.out_json
|
args.in_json = args.out_json
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
logger.info(f"loading existing metadata: {args.in_json}")
|
||||||
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||||
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
||||||
else:
|
else:
|
||||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
print("merge tags to metadata json.")
|
logger.info("merge tags to metadata json.")
|
||||||
for image_path in tqdm(image_paths):
|
for image_path in tqdm(image_paths):
|
||||||
tags_path = image_path.with_suffix(args.caption_extension)
|
tags_path = image_path.with_suffix(args.caption_extension)
|
||||||
tags = tags_path.read_text(encoding='utf-8').strip()
|
tags = tags_path.read_text(encoding='utf-8').strip()
|
||||||
@@ -38,13 +42,13 @@ def main(args):
|
|||||||
|
|
||||||
metadata[image_key]['tags'] = tags
|
metadata[image_key]['tags'] = tags
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_key, tags)
|
logger.info(f"{image_key} {tags}")
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
logger.info(f"writing metadata: {args.out_json}")
|
||||||
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ from torchvision import transforms
|
|||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
@@ -51,22 +55,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
|
|||||||
def main(args):
|
def main(args):
|
||||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||||
if args.bucket_reso_steps % 8 > 0:
|
if args.bucket_reso_steps % 8 > 0:
|
||||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||||
if args.bucket_reso_steps % 32 > 0:
|
if args.bucket_reso_steps % 32 > 0:
|
||||||
print(
|
logger.warning(
|
||||||
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
|
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
|
||||||
)
|
)
|
||||||
|
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||||
print(f"found {len(image_paths)} images.")
|
logger.info(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if os.path.exists(args.in_json):
|
if os.path.exists(args.in_json):
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
logger.info(f"loading existing metadata: {args.in_json}")
|
||||||
with open(args.in_json, "rt", encoding="utf-8") as f:
|
with open(args.in_json, "rt", encoding="utf-8") as f:
|
||||||
metadata = json.load(f)
|
metadata = json.load(f)
|
||||||
else:
|
else:
|
||||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||||
return
|
return
|
||||||
|
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
@@ -89,7 +93,7 @@ def main(args):
|
|||||||
if not args.bucket_no_upscale:
|
if not args.bucket_no_upscale:
|
||||||
bucket_manager.make_buckets()
|
bucket_manager.make_buckets()
|
||||||
else:
|
else:
|
||||||
print(
|
logger.warning(
|
||||||
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
|
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,7 +134,7 @@ def main(args):
|
|||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||||
@@ -183,15 +187,15 @@ def main(args):
|
|||||||
for i, reso in enumerate(bucket_manager.resos):
|
for i, reso in enumerate(bucket_manager.resos):
|
||||||
count = bucket_counts.get(reso, 0)
|
count = bucket_counts.get(reso, 0)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
print(f"bucket {i} {reso}: {count}")
|
logger.info(f"bucket {i} {reso}: {count}")
|
||||||
img_ar_errors = np.array(img_ar_errors)
|
img_ar_errors = np.array(img_ar_errors)
|
||||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
logger.info(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
logger.info(f"writing metadata: {args.out_json}")
|
||||||
with open(args.out_json, "wt", encoding="utf-8") as f:
|
with open(args.out_json, "wt", encoding="utf-8") as f:
|
||||||
json.dump(metadata, f, indent=2)
|
json.dump(metadata, f, indent=2)
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -11,6 +11,10 @@ from PIL import Image
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# from wd14 tagger
|
# from wd14 tagger
|
||||||
IMAGE_SIZE = 448
|
IMAGE_SIZE = 448
|
||||||
@@ -58,7 +62,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
|||||||
image = preprocess_image(image)
|
image = preprocess_image(image)
|
||||||
tensor = torch.tensor(image)
|
tensor = torch.tensor(image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return (tensor, img_path)
|
return (tensor, img_path)
|
||||||
@@ -79,7 +83,7 @@ def main(args):
|
|||||||
# depreacatedの警告が出るけどなくなったらその時
|
# depreacatedの警告が出るけどなくなったらその時
|
||||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||||
if not os.path.exists(args.model_dir) or args.force_download:
|
if not os.path.exists(args.model_dir) or args.force_download:
|
||||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||||
files = FILES
|
files = FILES
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
files += FILES_ONNX
|
files += FILES_ONNX
|
||||||
@@ -95,7 +99,7 @@ def main(args):
|
|||||||
force_filename=file,
|
force_filename=file,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("using existing wd14 tagger model")
|
logger.info("using existing wd14 tagger model")
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
if args.onnx:
|
if args.onnx:
|
||||||
@@ -103,8 +107,8 @@ def main(args):
|
|||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
onnx_path = f"{args.model_dir}/model.onnx"
|
onnx_path = f"{args.model_dir}/model.onnx"
|
||||||
print("Running wd14 tagger with onnx")
|
logger.info("Running wd14 tagger with onnx")
|
||||||
print(f"loading onnx model: {onnx_path}")
|
logger.info(f"loading onnx model: {onnx_path}")
|
||||||
|
|
||||||
if not os.path.exists(onnx_path):
|
if not os.path.exists(onnx_path):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@@ -121,7 +125,7 @@ def main(args):
|
|||||||
|
|
||||||
if args.batch_size != batch_size and type(batch_size) != str:
|
if args.batch_size != batch_size and type(batch_size) != str:
|
||||||
# some rebatch model may use 'N' as dynamic axes
|
# some rebatch model may use 'N' as dynamic axes
|
||||||
print(
|
logger.warning(
|
||||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||||
)
|
)
|
||||||
args.batch_size = batch_size
|
args.batch_size = batch_size
|
||||||
@@ -156,7 +160,7 @@ def main(args):
|
|||||||
|
|
||||||
train_data_dir_path = Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
logger.info(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
tag_freq = {}
|
tag_freq = {}
|
||||||
|
|
||||||
@@ -237,7 +241,10 @@ def main(args):
|
|||||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||||
f.write(tag_text + "\n")
|
f.write(tag_text + "\n")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
logger.info("")
|
||||||
|
logger.info(f"{image_path}:")
|
||||||
|
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||||
|
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||||
|
|
||||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||||
if args.max_data_loader_n_workers is not None:
|
if args.max_data_loader_n_workers is not None:
|
||||||
@@ -269,7 +276,7 @@ def main(args):
|
|||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
image = preprocess_image(image)
|
image = preprocess_image(image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||||
continue
|
continue
|
||||||
b_imgs.append((image_path, image))
|
b_imgs.append((image_path, image))
|
||||||
|
|
||||||
@@ -284,11 +291,11 @@ def main(args):
|
|||||||
|
|
||||||
if args.frequency_tags:
|
if args.frequency_tags:
|
||||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||||
print("\nTag frequencies:")
|
logger.info("Tag frequencies:")
|
||||||
for tag, freq in sorted_tags:
|
for tag, freq in sorted_tags:
|
||||||
print(f"{tag}: {freq}")
|
logger.info(f"{tag}: {freq}")
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -104,6 +104,10 @@ from library.original_unet import UNet2DConditionModel, InferUNet2DConditionMode
|
|||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
|
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
@@ -139,12 +143,12 @@ USE_CUTOUTS = False
|
|||||||
|
|
||||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||||
if mem_eff_attn:
|
if mem_eff_attn:
|
||||||
print("Enable memory efficient attention for U-Net")
|
logger.info("Enable memory efficient attention for U-Net")
|
||||||
|
|
||||||
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
|
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
|
||||||
unet.set_use_memory_efficient_attention(False, True)
|
unet.set_use_memory_efficient_attention(False, True)
|
||||||
elif xformers:
|
elif xformers:
|
||||||
print("Enable xformers for U-Net")
|
logger.info("Enable xformers for U-Net")
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -152,7 +156,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
|||||||
|
|
||||||
unet.set_use_memory_efficient_attention(True, False)
|
unet.set_use_memory_efficient_attention(True, False)
|
||||||
elif sdpa:
|
elif sdpa:
|
||||||
print("Enable SDPA for U-Net")
|
logger.info("Enable SDPA for U-Net")
|
||||||
unet.set_use_memory_efficient_attention(False, False)
|
unet.set_use_memory_efficient_attention(False, False)
|
||||||
unet.set_use_sdpa(True)
|
unet.set_use_sdpa(True)
|
||||||
|
|
||||||
@@ -168,7 +172,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_memory_efficient():
|
def replace_vae_attn_to_memory_efficient():
|
||||||
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
||||||
flash_func = FlashAttentionFunction
|
flash_func = FlashAttentionFunction
|
||||||
|
|
||||||
def forward_flash_attn(self, hidden_states, **kwargs):
|
def forward_flash_attn(self, hidden_states, **kwargs):
|
||||||
@@ -224,7 +228,7 @@ def replace_vae_attn_to_memory_efficient():
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_xformers():
|
def replace_vae_attn_to_xformers():
|
||||||
print("VAE: Attention.forward has been replaced to xformers")
|
logger.info("VAE: Attention.forward has been replaced to xformers")
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
def forward_xformers(self, hidden_states, **kwargs):
|
def forward_xformers(self, hidden_states, **kwargs):
|
||||||
@@ -280,7 +284,7 @@ def replace_vae_attn_to_xformers():
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_sdpa():
|
def replace_vae_attn_to_sdpa():
|
||||||
print("VAE: Attention.forward has been replaced to sdpa")
|
logger.info("VAE: Attention.forward has been replaced to sdpa")
|
||||||
|
|
||||||
def forward_sdpa(self, hidden_states, **kwargs):
|
def forward_sdpa(self, hidden_states, **kwargs):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -684,7 +688,7 @@ class PipelineLike:
|
|||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
if not do_classifier_free_guidance and negative_scale is not None:
|
if not do_classifier_free_guidance and negative_scale is not None:
|
||||||
print(f"negative_scale is ignored if guidance scalle <= 1.0")
|
logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0")
|
||||||
negative_scale = None
|
negative_scale = None
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
# get unconditional embeddings for classifier free guidance
|
||||||
@@ -766,11 +770,11 @@ class PipelineLike:
|
|||||||
clip_text_input = prompt_tokens
|
clip_text_input = prompt_tokens
|
||||||
if clip_text_input.shape[1] > self.tokenizer.model_max_length:
|
if clip_text_input.shape[1] > self.tokenizer.model_max_length:
|
||||||
# TODO 75文字を超えたら警告を出す?
|
# TODO 75文字を超えたら警告を出す?
|
||||||
print("trim text input", clip_text_input.shape)
|
logger.info(f"trim text input {clip_text_input.shape}")
|
||||||
clip_text_input = torch.cat(
|
clip_text_input = torch.cat(
|
||||||
[clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1
|
[clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1
|
||||||
)
|
)
|
||||||
print("trimmed", clip_text_input.shape)
|
logger.info(f"trimmed {clip_text_input.shape}")
|
||||||
|
|
||||||
for i, clip_prompt in enumerate(clip_prompts):
|
for i, clip_prompt in enumerate(clip_prompts):
|
||||||
if clip_prompt is not None: # clip_promptがあれば上書きする
|
if clip_prompt is not None: # clip_promptがあれば上書きする
|
||||||
@@ -1699,7 +1703,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
if word.strip() == "BREAK":
|
if word.strip() == "BREAK":
|
||||||
# pad until next multiple of tokenizer's max token length
|
# pad until next multiple of tokenizer's max token length
|
||||||
pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
|
pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
|
||||||
print(f"BREAK pad_len: {pad_len}")
|
logger.info(f"BREAK pad_len: {pad_len}")
|
||||||
for i in range(pad_len):
|
for i in range(pad_len):
|
||||||
# v2のときEOSをつけるべきかどうかわからないぜ
|
# v2のときEOSをつけるべきかどうかわからないぜ
|
||||||
# if i == 0:
|
# if i == 0:
|
||||||
@@ -1729,7 +1733,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
tokens.append(text_token)
|
tokens.append(text_token)
|
||||||
weights.append(text_weight)
|
weights.append(text_weight)
|
||||||
if truncated:
|
if truncated:
|
||||||
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||||
return tokens, weights
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
@@ -2041,7 +2045,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
|
|||||||
elif len(count_range) == 2:
|
elif len(count_range) == 2:
|
||||||
count_range = [int(count_range[0]), int(count_range[1])]
|
count_range = [int(count_range[0]), int(count_range[1])]
|
||||||
else:
|
else:
|
||||||
print(f"invalid count range: {count_range}")
|
logger.warning(f"invalid count range: {count_range}")
|
||||||
count_range = [1, 1]
|
count_range = [1, 1]
|
||||||
if count_range[0] > count_range[1]:
|
if count_range[0] > count_range[1]:
|
||||||
count_range = [count_range[1], count_range[0]]
|
count_range = [count_range[1], count_range[0]]
|
||||||
@@ -2111,7 +2115,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
|
|||||||
|
|
||||||
|
|
||||||
# def load_clip_l14_336(dtype):
|
# def load_clip_l14_336(dtype):
|
||||||
# print(f"loading CLIP: {CLIP_ID_L14_336}")
|
# logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
|
||||||
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
|
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
|
||||||
# return text_encoder
|
# return text_encoder
|
||||||
|
|
||||||
@@ -2158,9 +2162,9 @@ def main(args):
|
|||||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||||
|
|
||||||
if args.v_parameterization and not args.v2:
|
if args.v_parameterization and not args.v2:
|
||||||
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||||
if args.v2 and args.clip_skip is not None:
|
if args.v2 and args.clip_skip is not None:
|
||||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
|
if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
|
||||||
@@ -2170,10 +2174,10 @@ def main(args):
|
|||||||
|
|
||||||
use_stable_diffusion_format = os.path.isfile(args.ckpt)
|
use_stable_diffusion_format = os.path.isfile(args.ckpt)
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
print("load StableDiffusion checkpoint")
|
logger.info("load StableDiffusion checkpoint")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
||||||
else:
|
else:
|
||||||
print("load Diffusers pretrained models")
|
logger.info("load Diffusers pretrained models")
|
||||||
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
||||||
text_encoder = loading_pipe.text_encoder
|
text_encoder = loading_pipe.text_encoder
|
||||||
vae = loading_pipe.vae
|
vae = loading_pipe.vae
|
||||||
@@ -2196,21 +2200,21 @@ def main(args):
|
|||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
vae = model_util.load_vae(args.vae, dtype)
|
vae = model_util.load_vae(args.vae, dtype)
|
||||||
print("additional VAE loaded")
|
logger.info("additional VAE loaded")
|
||||||
|
|
||||||
# # 置換するCLIPを読み込む
|
# # 置換するCLIPを読み込む
|
||||||
# if args.replace_clip_l14_336:
|
# if args.replace_clip_l14_336:
|
||||||
# text_encoder = load_clip_l14_336(dtype)
|
# text_encoder = load_clip_l14_336(dtype)
|
||||||
# print(f"large clip {CLIP_ID_L14_336} is loaded")
|
# logger.info(f"large clip {CLIP_ID_L14_336} is loaded")
|
||||||
|
|
||||||
if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale:
|
if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale:
|
||||||
print("prepare clip model")
|
logger.info("prepare clip model")
|
||||||
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype)
|
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype)
|
||||||
else:
|
else:
|
||||||
clip_model = None
|
clip_model = None
|
||||||
|
|
||||||
if args.vgg16_guidance_scale > 0.0:
|
if args.vgg16_guidance_scale > 0.0:
|
||||||
print("prepare resnet model")
|
logger.info("prepare resnet model")
|
||||||
vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1)
|
vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1)
|
||||||
else:
|
else:
|
||||||
vgg16_model = None
|
vgg16_model = None
|
||||||
@@ -2222,7 +2226,7 @@ def main(args):
|
|||||||
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
|
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
|
||||||
|
|
||||||
# tokenizerを読み込む
|
# tokenizerを読み込む
|
||||||
print("loading tokenizer")
|
logger.info("loading tokenizer")
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
@@ -2281,7 +2285,7 @@ def main(args):
|
|||||||
self.sampler_noises = noises
|
self.sampler_noises = noises
|
||||||
|
|
||||||
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
|
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
|
||||||
# print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
|
# logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}")
|
||||||
if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
|
if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
|
||||||
noise = self.sampler_noises[self.sampler_noise_index]
|
noise = self.sampler_noises[self.sampler_noise_index]
|
||||||
if shape != noise.shape:
|
if shape != noise.shape:
|
||||||
@@ -2290,7 +2294,7 @@ def main(args):
|
|||||||
noise = None
|
noise = None
|
||||||
|
|
||||||
if noise == None:
|
if noise == None:
|
||||||
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
|
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
|
||||||
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
|
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
|
||||||
|
|
||||||
self.sampler_noise_index += 1
|
self.sampler_noise_index += 1
|
||||||
@@ -2321,7 +2325,7 @@ def main(args):
|
|||||||
|
|
||||||
# clip_sample=Trueにする
|
# clip_sample=Trueにする
|
||||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||||
print("set clip_sample to True")
|
logger.info("set clip_sample to True")
|
||||||
scheduler.config.clip_sample = True
|
scheduler.config.clip_sample = True
|
||||||
|
|
||||||
# deviceを決定する
|
# deviceを決定する
|
||||||
@@ -2378,7 +2382,7 @@ def main(args):
|
|||||||
network_merge = 0
|
network_merge = 0
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
logger.info(f"import network module: {network_module}")
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
@@ -2396,7 +2400,7 @@ def main(args):
|
|||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
|
|
||||||
network_weight = args.network_weights[i]
|
network_weight = args.network_weights[i]
|
||||||
print("load network weights from:", network_weight)
|
logger.info(f"load network weights from: {network_weight}")
|
||||||
|
|
||||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
@@ -2404,7 +2408,7 @@ def main(args):
|
|||||||
with safe_open(network_weight, framework="pt") as f:
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
logger.info(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||||
@@ -2414,20 +2418,20 @@ def main(args):
|
|||||||
|
|
||||||
mergeable = network.is_mergeable()
|
mergeable = network.is_mergeable()
|
||||||
if network_merge and not mergeable:
|
if network_merge and not mergeable:
|
||||||
print("network is not mergiable. ignore merge option.")
|
logger.warning("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
if not mergeable or i >= network_merge:
|
if not mergeable or i >= network_merge:
|
||||||
# not merging
|
# not merging
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
logger.info(f"weights are loaded: {info}")
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
if network_pre_calc:
|
if network_pre_calc:
|
||||||
print("backup original weights")
|
logger.info("backup original weights")
|
||||||
network.backup_weights()
|
network.backup_weights()
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
@@ -2441,7 +2445,7 @@ def main(args):
|
|||||||
# upscalerの指定があれば取得する
|
# upscalerの指定があれば取得する
|
||||||
upscaler = None
|
upscaler = None
|
||||||
if args.highres_fix_upscaler:
|
if args.highres_fix_upscaler:
|
||||||
print("import upscaler module:", args.highres_fix_upscaler)
|
logger.info(f"import upscaler module {args.highres_fix_upscaler}")
|
||||||
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
||||||
|
|
||||||
us_kwargs = {}
|
us_kwargs = {}
|
||||||
@@ -2450,7 +2454,7 @@ def main(args):
|
|||||||
key, value = net_arg.split("=")
|
key, value = net_arg.split("=")
|
||||||
us_kwargs[key] = value
|
us_kwargs[key] = value
|
||||||
|
|
||||||
print("create upscaler")
|
logger.info("create upscaler")
|
||||||
upscaler = imported_module.create_upscaler(**us_kwargs)
|
upscaler = imported_module.create_upscaler(**us_kwargs)
|
||||||
upscaler.to(dtype).to(device)
|
upscaler.to(dtype).to(device)
|
||||||
|
|
||||||
@@ -2467,7 +2471,7 @@ def main(args):
|
|||||||
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
print(f"set optimizing: channels last")
|
logger.info(f"set optimizing: channels last")
|
||||||
text_encoder.to(memory_format=torch.channels_last)
|
text_encoder.to(memory_format=torch.channels_last)
|
||||||
vae.to(memory_format=torch.channels_last)
|
vae.to(memory_format=torch.channels_last)
|
||||||
unet.to(memory_format=torch.channels_last)
|
unet.to(memory_format=torch.channels_last)
|
||||||
@@ -2499,7 +2503,7 @@ def main(args):
|
|||||||
args.vgg16_guidance_layer,
|
args.vgg16_guidance_layer,
|
||||||
)
|
)
|
||||||
pipe.set_control_nets(control_nets)
|
pipe.set_control_nets(control_nets)
|
||||||
print("pipeline is ready.")
|
logger.info("pipeline is ready.")
|
||||||
|
|
||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
@@ -2542,7 +2546,7 @@ def main(args):
|
|||||||
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||||
|
|
||||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||||
assert (
|
assert (
|
||||||
min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
|
min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
|
||||||
), f"token ids is not ordered"
|
), f"token ids is not ordered"
|
||||||
@@ -2601,7 +2605,7 @@ def main(args):
|
|||||||
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||||
|
|
||||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||||
|
|
||||||
# if num_vectors_per_token > 1:
|
# if num_vectors_per_token > 1:
|
||||||
pipe.add_token_replacement(token_ids[0], token_ids)
|
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||||
@@ -2626,7 +2630,7 @@ def main(args):
|
|||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
logger.info(f"reading prompts from {args.from_file}")
|
||||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||||
prompt_list = f.read().splitlines()
|
prompt_list = f.read().splitlines()
|
||||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
|
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
|
||||||
@@ -2655,7 +2659,7 @@ def main(args):
|
|||||||
for p in paths:
|
for p in paths:
|
||||||
image = Image.open(p)
|
image = Image.open(p)
|
||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
print(f"convert image to RGB from {image.mode}: {p}")
|
logger.info(f"convert image to RGB from {image.mode}: {p}")
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
|
||||||
@@ -2671,24 +2675,24 @@ def main(args):
|
|||||||
return resized
|
return resized
|
||||||
|
|
||||||
if args.image_path is not None:
|
if args.image_path is not None:
|
||||||
print(f"load image for img2img: {args.image_path}")
|
logger.info(f"load image for img2img: {args.image_path}")
|
||||||
init_images = load_images(args.image_path)
|
init_images = load_images(args.image_path)
|
||||||
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
||||||
print(f"loaded {len(init_images)} images for img2img")
|
logger.info(f"loaded {len(init_images)} images for img2img")
|
||||||
else:
|
else:
|
||||||
init_images = None
|
init_images = None
|
||||||
|
|
||||||
if args.mask_path is not None:
|
if args.mask_path is not None:
|
||||||
print(f"load mask for inpainting: {args.mask_path}")
|
logger.info(f"load mask for inpainting: {args.mask_path}")
|
||||||
mask_images = load_images(args.mask_path)
|
mask_images = load_images(args.mask_path)
|
||||||
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
|
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
|
||||||
print(f"loaded {len(mask_images)} mask images for inpainting")
|
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
|
||||||
else:
|
else:
|
||||||
mask_images = None
|
mask_images = None
|
||||||
|
|
||||||
# promptがないとき、画像のPngInfoから取得する
|
# promptがないとき、画像のPngInfoから取得する
|
||||||
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
|
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
|
||||||
print("get prompts from images' meta data")
|
logger.info("get prompts from images' meta data")
|
||||||
for img in init_images:
|
for img in init_images:
|
||||||
if "prompt" in img.text:
|
if "prompt" in img.text:
|
||||||
prompt = img.text["prompt"]
|
prompt = img.text["prompt"]
|
||||||
@@ -2717,17 +2721,17 @@ def main(args):
|
|||||||
h = int(h * args.highres_fix_scale + 0.5)
|
h = int(h * args.highres_fix_scale + 0.5)
|
||||||
|
|
||||||
if init_images is not None:
|
if init_images is not None:
|
||||||
print(f"resize img2img source images to {w}*{h}")
|
logger.info(f"resize img2img source images to {w}*{h}")
|
||||||
init_images = resize_images(init_images, (w, h))
|
init_images = resize_images(init_images, (w, h))
|
||||||
if mask_images is not None:
|
if mask_images is not None:
|
||||||
print(f"resize img2img mask images to {w}*{h}")
|
logger.info(f"resize img2img mask images to {w}*{h}")
|
||||||
mask_images = resize_images(mask_images, (w, h))
|
mask_images = resize_images(mask_images, (w, h))
|
||||||
|
|
||||||
regional_network = False
|
regional_network = False
|
||||||
if networks and mask_images:
|
if networks and mask_images:
|
||||||
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||||
regional_network = True
|
regional_network = True
|
||||||
print("use mask as region")
|
logger.info("use mask as region")
|
||||||
|
|
||||||
size = None
|
size = None
|
||||||
for i, network in enumerate(networks):
|
for i, network in enumerate(networks):
|
||||||
@@ -2752,14 +2756,14 @@ def main(args):
|
|||||||
|
|
||||||
prev_image = None # for VGG16 guided
|
prev_image = None # for VGG16 guided
|
||||||
if args.guide_image_path is not None:
|
if args.guide_image_path is not None:
|
||||||
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
|
logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
|
||||||
guide_images = []
|
guide_images = []
|
||||||
for p in args.guide_image_path:
|
for p in args.guide_image_path:
|
||||||
guide_images.extend(load_images(p))
|
guide_images.extend(load_images(p))
|
||||||
|
|
||||||
print(f"loaded {len(guide_images)} guide images for guidance")
|
logger.info(f"loaded {len(guide_images)} guide images for guidance")
|
||||||
if len(guide_images) == 0:
|
if len(guide_images) == 0:
|
||||||
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
logger.info(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
||||||
guide_images = None
|
guide_images = None
|
||||||
else:
|
else:
|
||||||
guide_images = None
|
guide_images = None
|
||||||
@@ -2785,7 +2789,7 @@ def main(args):
|
|||||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||||
|
|
||||||
for gen_iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
iter_seed = random.randint(0, 0x7FFFFFFF)
|
iter_seed = random.randint(0, 0x7FFFFFFF)
|
||||||
|
|
||||||
# shuffle prompt list
|
# shuffle prompt list
|
||||||
@@ -2801,7 +2805,7 @@ def main(args):
|
|||||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||||
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
||||||
|
|
||||||
print("process 1st stage")
|
logger.info("process 1st stage")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for _, base, ext in batch:
|
for _, base, ext in batch:
|
||||||
width_1st = int(ext.width * args.highres_fix_scale + 0.5)
|
width_1st = int(ext.width * args.highres_fix_scale + 0.5)
|
||||||
@@ -2827,7 +2831,7 @@ def main(args):
|
|||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
print("process 2nd stage")
|
logger.info("process 2nd stage")
|
||||||
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
||||||
|
|
||||||
if upscaler:
|
if upscaler:
|
||||||
@@ -2978,7 +2982,7 @@ def main(args):
|
|||||||
n.restore_weights()
|
n.restore_weights()
|
||||||
for n in networks:
|
for n in networks:
|
||||||
n.pre_calculation()
|
n.pre_calculation()
|
||||||
print("pre-calculation... done")
|
logger.info("pre-calculation... done")
|
||||||
|
|
||||||
images = pipe(
|
images = pipe(
|
||||||
prompts,
|
prompts,
|
||||||
@@ -3045,7 +3049,7 @@ def main(args):
|
|||||||
cv2.waitKey()
|
cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
|
logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
@@ -3058,7 +3062,8 @@ def main(args):
|
|||||||
# interactive
|
# interactive
|
||||||
valid = False
|
valid = False
|
||||||
while not valid:
|
while not valid:
|
||||||
print("\nType prompt:")
|
logger.info("")
|
||||||
|
logger.info("Type prompt:")
|
||||||
try:
|
try:
|
||||||
raw_prompt = input()
|
raw_prompt = input()
|
||||||
except EOFError:
|
except EOFError:
|
||||||
@@ -3101,38 +3106,38 @@ def main(args):
|
|||||||
|
|
||||||
prompt_args = raw_prompt.strip().split(" --")
|
prompt_args = raw_prompt.strip().split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||||
|
|
||||||
for parg in prompt_args[1:]:
|
for parg in prompt_args[1:]:
|
||||||
try:
|
try:
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
width = int(m.group(1))
|
width = int(m.group(1))
|
||||||
print(f"width: {width}")
|
logger.info(f"width: {width}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
height = int(m.group(1))
|
height = int(m.group(1))
|
||||||
print(f"height: {height}")
|
logger.info(f"height: {height}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # steps
|
if m: # steps
|
||||||
steps = max(1, min(1000, int(m.group(1))))
|
steps = max(1, min(1000, int(m.group(1))))
|
||||||
print(f"steps: {steps}")
|
logger.info(f"steps: {steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||||
if m: # seed
|
if m: # seed
|
||||||
seeds = [int(d) for d in m.group(1).split(",")]
|
seeds = [int(d) for d in m.group(1).split(",")]
|
||||||
print(f"seeds: {seeds}")
|
logger.info(f"seeds: {seeds}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # scale
|
if m: # scale
|
||||||
scale = float(m.group(1))
|
scale = float(m.group(1))
|
||||||
print(f"scale: {scale}")
|
logger.info(f"scale: {scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||||
@@ -3141,25 +3146,25 @@ def main(args):
|
|||||||
negative_scale = None
|
negative_scale = None
|
||||||
else:
|
else:
|
||||||
negative_scale = float(m.group(1))
|
negative_scale = float(m.group(1))
|
||||||
print(f"negative scale: {negative_scale}")
|
logger.info(f"negative scale: {negative_scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # strength
|
if m: # strength
|
||||||
strength = float(m.group(1))
|
strength = float(m.group(1))
|
||||||
print(f"strength: {strength}")
|
logger.info(f"strength: {strength}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # negative prompt
|
||||||
negative_prompt = m.group(1)
|
negative_prompt = m.group(1)
|
||||||
print(f"negative prompt: {negative_prompt}")
|
logger.info(f"negative prompt: {negative_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||||
if m: # clip prompt
|
if m: # clip prompt
|
||||||
clip_prompt = m.group(1)
|
clip_prompt = m.group(1)
|
||||||
print(f"clip prompt: {clip_prompt}")
|
logger.info(f"clip prompt: {clip_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
@@ -3167,47 +3172,47 @@ def main(args):
|
|||||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||||
while len(network_muls) < len(networks):
|
while len(network_muls) < len(networks):
|
||||||
network_muls.append(network_muls[-1])
|
network_muls.append(network_muls[-1])
|
||||||
print(f"network mul: {network_muls}")
|
logger.info(f"network mul: {network_muls}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink depth 1
|
if m: # deep shrink depth 1
|
||||||
ds_depth_1 = int(m.group(1))
|
ds_depth_1 = int(m.group(1))
|
||||||
print(f"deep shrink depth 1: {ds_depth_1}")
|
logger.info(f"deep shrink depth 1: {ds_depth_1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink timesteps 1
|
if m: # deep shrink timesteps 1
|
||||||
ds_timesteps_1 = int(m.group(1))
|
ds_timesteps_1 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink depth 2
|
if m: # deep shrink depth 2
|
||||||
ds_depth_2 = int(m.group(1))
|
ds_depth_2 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink depth 2: {ds_depth_2}")
|
logger.info(f"deep shrink depth 2: {ds_depth_2}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink timesteps 2
|
if m: # deep shrink timesteps 2
|
||||||
ds_timesteps_2 = int(m.group(1))
|
ds_timesteps_2 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink ratio
|
if m: # deep shrink ratio
|
||||||
ds_ratio = float(m.group(1))
|
ds_ratio = float(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink ratio: {ds_ratio}")
|
logger.info(f"deep shrink ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
logger.info(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
logger.info(ex)
|
||||||
|
|
||||||
# override Deep Shrink
|
# override Deep Shrink
|
||||||
if ds_depth_1 is not None:
|
if ds_depth_1 is not None:
|
||||||
@@ -3225,7 +3230,7 @@ def main(args):
|
|||||||
if len(predefined_seeds) > 0:
|
if len(predefined_seeds) > 0:
|
||||||
seed = predefined_seeds.pop(0)
|
seed = predefined_seeds.pop(0)
|
||||||
else:
|
else:
|
||||||
print("predefined seeds are exhausted")
|
logger.info("predefined seeds are exhausted")
|
||||||
seed = None
|
seed = None
|
||||||
elif args.iter_same_seed:
|
elif args.iter_same_seed:
|
||||||
seed = iter_seed
|
seed = iter_seed
|
||||||
@@ -3235,7 +3240,7 @@ def main(args):
|
|||||||
if seed is None:
|
if seed is None:
|
||||||
seed = random.randint(0, 0x7FFFFFFF)
|
seed = random.randint(0, 0x7FFFFFFF)
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
print(f"seed: {seed}")
|
logger.info(f"seed: {seed}")
|
||||||
|
|
||||||
# prepare init image, guide image and mask
|
# prepare init image, guide image and mask
|
||||||
init_image = mask_image = guide_image = None
|
init_image = mask_image = guide_image = None
|
||||||
@@ -3251,7 +3256,7 @@ def main(args):
|
|||||||
width = width - width % 32
|
width = width - width % 32
|
||||||
height = height - height % 32
|
height = height - height % 32
|
||||||
if width != init_image.size[0] or height != init_image.size[1]:
|
if width != init_image.size[0] or height != init_image.size[1]:
|
||||||
print(
|
logger.info(
|
||||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -3267,9 +3272,9 @@ def main(args):
|
|||||||
guide_image = guide_images[global_step % len(guide_images)]
|
guide_image = guide_images[global_step % len(guide_images)]
|
||||||
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
|
||||||
if prev_image is None:
|
if prev_image is None:
|
||||||
print("Generate 1st image without guide image.")
|
logger.info("Generate 1st image without guide image.")
|
||||||
else:
|
else:
|
||||||
print("Use previous image as guide image.")
|
logger.info("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
if regional_network:
|
if regional_network:
|
||||||
@@ -3311,7 +3316,7 @@ def main(args):
|
|||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -40,7 +40,10 @@ from .train_util import (
|
|||||||
ControlNetDataset,
|
ControlNetDataset,
|
||||||
DatasetGroup,
|
DatasetGroup,
|
||||||
)
|
)
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def add_config_arguments(parser: argparse.ArgumentParser):
|
def add_config_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
|
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
|
||||||
@@ -345,7 +348,7 @@ class ConfigSanitizer:
|
|||||||
return self.user_config_validator(user_config)
|
return self.user_config_validator(user_config)
|
||||||
except MultipleInvalid:
|
except MultipleInvalid:
|
||||||
# TODO: エラー発生時のメッセージをわかりやすくする
|
# TODO: エラー発生時のメッセージをわかりやすくする
|
||||||
print("Invalid user config / ユーザ設定の形式が正しくないようです")
|
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# NOTE: In nature, argument parser result is not needed to be sanitize
|
# NOTE: In nature, argument parser result is not needed to be sanitize
|
||||||
@@ -355,7 +358,7 @@ class ConfigSanitizer:
|
|||||||
return self.argparse_config_validator(argparse_namespace)
|
return self.argparse_config_validator(argparse_namespace)
|
||||||
except MultipleInvalid:
|
except MultipleInvalid:
|
||||||
# XXX: this should be a bug
|
# XXX: this should be a bug
|
||||||
print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# NOTE: value would be overwritten by latter dict if there is already the same key
|
# NOTE: value would be overwritten by latter dict if there is already the same key
|
||||||
@@ -538,13 +541,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
|||||||
" ",
|
" ",
|
||||||
)
|
)
|
||||||
|
|
||||||
print(info)
|
logger.info(f'{info}')
|
||||||
|
|
||||||
# make buckets first because it determines the length of dataset
|
# make buckets first because it determines the length of dataset
|
||||||
# and set the same seed for all datasets
|
# and set the same seed for all datasets
|
||||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||||
for i, dataset in enumerate(datasets):
|
for i, dataset in enumerate(datasets):
|
||||||
print(f"[Dataset {i}]")
|
logger.info(f"[Dataset {i}]")
|
||||||
dataset.make_buckets()
|
dataset.make_buckets()
|
||||||
dataset.set_seed(seed)
|
dataset.set_seed(seed)
|
||||||
|
|
||||||
@@ -557,7 +560,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
|
|||||||
try:
|
try:
|
||||||
n_repeats = int(tokens[0])
|
n_repeats = int(tokens[0])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
||||||
return 0, ""
|
return 0, ""
|
||||||
caption_by_folder = "_".join(tokens[1:])
|
caption_by_folder = "_".join(tokens[1:])
|
||||||
return n_repeats, caption_by_folder
|
return n_repeats, caption_by_folder
|
||||||
@@ -629,17 +632,13 @@ def load_user_config(file: str) -> dict:
|
|||||||
with open(file, "r") as f:
|
with open(file, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(
|
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||||
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
elif file.name.lower().endswith(".toml"):
|
elif file.name.lower().endswith(".toml"):
|
||||||
try:
|
try:
|
||||||
config = toml.load(file)
|
config = toml.load(file)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(
|
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||||
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
||||||
@@ -665,23 +664,26 @@ if __name__ == "__main__":
|
|||||||
argparse_namespace = parser.parse_args(remain)
|
argparse_namespace = parser.parse_args(remain)
|
||||||
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
||||||
|
|
||||||
print("[argparse_namespace]")
|
logger.info("[argparse_namespace]")
|
||||||
print(vars(argparse_namespace))
|
logger.info(f'{vars(argparse_namespace)}')
|
||||||
|
|
||||||
user_config = load_user_config(config_args.dataset_config)
|
user_config = load_user_config(config_args.dataset_config)
|
||||||
|
|
||||||
print("\n[user_config]")
|
logger.info("")
|
||||||
print(user_config)
|
logger.info("[user_config]")
|
||||||
|
logger.info(f'{user_config}')
|
||||||
|
|
||||||
sanitizer = ConfigSanitizer(
|
sanitizer = ConfigSanitizer(
|
||||||
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
||||||
)
|
)
|
||||||
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
||||||
|
|
||||||
print("\n[sanitized_user_config]")
|
logger.info("")
|
||||||
print(sanitized_user_config)
|
logger.info("[sanitized_user_config]")
|
||||||
|
logger.info(f'{sanitized_user_config}')
|
||||||
|
|
||||||
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
||||||
|
|
||||||
print("\n[blueprint]")
|
logger.info("")
|
||||||
print(blueprint)
|
logger.info("[blueprint]")
|
||||||
|
logger.info(f'{blueprint}')
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ import argparse
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||||
if hasattr(noise_scheduler, "all_snr"):
|
if hasattr(noise_scheduler, "all_snr"):
|
||||||
@@ -21,7 +24,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
|||||||
|
|
||||||
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||||
# fix beta: zero terminal SNR
|
# fix beta: zero terminal SNR
|
||||||
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
||||||
|
|
||||||
def enforce_zero_terminal_snr(betas):
|
def enforce_zero_terminal_snr(betas):
|
||||||
# Convert betas to alphas_bar_sqrt
|
# Convert betas to alphas_bar_sqrt
|
||||||
@@ -49,8 +52,8 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
|||||||
alphas = 1.0 - betas
|
alphas = 1.0 - betas
|
||||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||||
|
|
||||||
# print("original:", noise_scheduler.betas)
|
# logger.info(f"original: {noise_scheduler.betas}")
|
||||||
# print("fixed:", betas)
|
# logger.info(f"fixed: {betas}")
|
||||||
|
|
||||||
noise_scheduler.betas = betas
|
noise_scheduler.betas = betas
|
||||||
noise_scheduler.alphas = alphas
|
noise_scheduler.alphas = alphas
|
||||||
@@ -79,13 +82,13 @@ def get_snr_scale(timesteps, noise_scheduler):
|
|||||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||||
scale = snr_t / (snr_t + 1)
|
scale = snr_t / (snr_t + 1)
|
||||||
# # show debug info
|
# # show debug info
|
||||||
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
||||||
return scale
|
return scale
|
||||||
|
|
||||||
|
|
||||||
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
||||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||||
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||||
loss = loss + loss / scale * v_pred_like_loss
|
loss = loss + loss / scale * v_pred_like_loss
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@@ -268,7 +271,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
|||||||
tokens.append(text_token)
|
tokens.append(text_token)
|
||||||
weights.append(text_weight)
|
weights.append(text_weight)
|
||||||
if truncated:
|
if truncated:
|
||||||
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||||
return tokens, weights
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ from pathlib import Path
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from library.utils import fire_in_thread
|
from library.utils import fire_in_thread
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
||||||
api = HfApi(
|
api = HfApi(
|
||||||
@@ -33,9 +36,9 @@ def upload(
|
|||||||
try:
|
try:
|
||||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||||
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
||||||
print("===========================================")
|
logger.error("===========================================")
|
||||||
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
||||||
print("===========================================")
|
logger.error("===========================================")
|
||||||
|
|
||||||
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
||||||
|
|
||||||
@@ -56,9 +59,9 @@ def upload(
|
|||||||
path_in_repo=path_in_repo,
|
path_in_repo=path_in_repo,
|
||||||
)
|
)
|
||||||
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
||||||
print("===========================================")
|
logger.error("===========================================")
|
||||||
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
||||||
print("===========================================")
|
logger.error("===========================================")
|
||||||
|
|
||||||
if args.async_upload and not force_sync_upload:
|
if args.async_upload and not force_sync_upload:
|
||||||
fire_in_thread(uploader)
|
fire_in_thread(uploader)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
|||||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||||
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||||
return module.to("xpu")
|
return module.to("xpu")
|
||||||
|
|
||||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
|||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||||
from diffusers.utils import logging
|
from diffusers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from diffusers.utils import PIL_INTERPOLATION
|
from diffusers.utils import PIL_INTERPOLATION
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -626,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||||
|
|
||||||
if height % 8 != 0 or width % 8 != 0:
|
if height % 8 != 0 or width % 8 != 0:
|
||||||
print(height, width)
|
logger.info(f'{height} {width}')
|
||||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||||
|
|
||||||
if (callback_steps is None) or (
|
if (callback_steps is None) or (
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
|||||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||||
NUM_TRAIN_TIMESTEPS = 1000
|
NUM_TRAIN_TIMESTEPS = 1000
|
||||||
@@ -944,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict):
|
|||||||
for k, v in new_state_dict.items():
|
for k, v in new_state_dict.items():
|
||||||
for weight_name in weights_to_convert:
|
for weight_name in weights_to_convert:
|
||||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||||
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
||||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
@@ -1002,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
|||||||
|
|
||||||
unet = UNet2DConditionModel(**unet_config).to(device)
|
unet = UNet2DConditionModel(**unet_config).to(device)
|
||||||
info = unet.load_state_dict(converted_unet_checkpoint)
|
info = unet.load_state_dict(converted_unet_checkpoint)
|
||||||
print("loading u-net:", info)
|
logger.info(f"loading u-net: {info}")
|
||||||
|
|
||||||
# Convert the VAE model.
|
# Convert the VAE model.
|
||||||
vae_config = create_vae_diffusers_config()
|
vae_config = create_vae_diffusers_config()
|
||||||
@@ -1010,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
|||||||
|
|
||||||
vae = AutoencoderKL(**vae_config).to(device)
|
vae = AutoencoderKL(**vae_config).to(device)
|
||||||
info = vae.load_state_dict(converted_vae_checkpoint)
|
info = vae.load_state_dict(converted_vae_checkpoint)
|
||||||
print("loading vae:", info)
|
logger.info(f"loading vae: {info}")
|
||||||
|
|
||||||
# convert text_model
|
# convert text_model
|
||||||
if v2:
|
if v2:
|
||||||
@@ -1044,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
|||||||
# logging.set_verbosity_error() # don't show annoying warning
|
# logging.set_verbosity_error() # don't show annoying warning
|
||||||
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||||
# logging.set_verbosity_warning()
|
# logging.set_verbosity_warning()
|
||||||
# print(f"config: {text_model.config}")
|
# logger.info(f"config: {text_model.config}")
|
||||||
cfg = CLIPTextConfig(
|
cfg = CLIPTextConfig(
|
||||||
vocab_size=49408,
|
vocab_size=49408,
|
||||||
hidden_size=768,
|
hidden_size=768,
|
||||||
@@ -1067,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
|||||||
)
|
)
|
||||||
text_model = CLIPTextModel._from_config(cfg)
|
text_model = CLIPTextModel._from_config(cfg)
|
||||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||||
print("loading text encoder:", info)
|
logger.info(f"loading text encoder: {info}")
|
||||||
|
|
||||||
return text_model, vae, unet
|
return text_model, vae, unet
|
||||||
|
|
||||||
@@ -1142,7 +1146,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
|||||||
|
|
||||||
# 最後の層などを捏造するか
|
# 最後の層などを捏造するか
|
||||||
if make_dummy_weights:
|
if make_dummy_weights:
|
||||||
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
|
||||||
keys = list(new_sd.keys())
|
keys = list(new_sd.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("transformer.resblocks.22."):
|
if key.startswith("transformer.resblocks.22."):
|
||||||
@@ -1261,14 +1265,14 @@ VAE_PREFIX = "first_stage_model."
|
|||||||
|
|
||||||
|
|
||||||
def load_vae(vae_id, dtype):
|
def load_vae(vae_id, dtype):
|
||||||
print(f"load VAE: {vae_id}")
|
logger.info(f"load VAE: {vae_id}")
|
||||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||||
# Diffusers local/remote
|
# Diffusers local/remote
|
||||||
try:
|
try:
|
||||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
||||||
except EnvironmentError as e:
|
except EnvironmentError as e:
|
||||||
print(f"exception occurs in loading vae: {e}")
|
logger.error(f"exception occurs in loading vae: {e}")
|
||||||
print("retry with subfolder='vae'")
|
logger.error("retry with subfolder='vae'")
|
||||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
@@ -1340,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
resos = make_bucket_resolutions((512, 768))
|
resos = make_bucket_resolutions((512, 768))
|
||||||
print(len(resos))
|
logger.info(f"{len(resos)}")
|
||||||
print(resos)
|
logger.info(f"{resos}")
|
||||||
aspect_ratios = [w / h for w, h in resos]
|
aspect_ratios = [w / h for w, h in resos]
|
||||||
print(aspect_ratios)
|
logger.info(f"{aspect_ratios}")
|
||||||
|
|
||||||
ars = set()
|
ars = set()
|
||||||
for ar in aspect_ratios:
|
for ar in aspect_ratios:
|
||||||
if ar in ars:
|
if ar in ars:
|
||||||
print("error! duplicate ar:", ar)
|
logger.error(f"error! duplicate ar: {ar}")
|
||||||
ars.add(ar)
|
ars.add(ar)
|
||||||
|
|||||||
@@ -113,6 +113,10 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
||||||
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
||||||
@@ -1380,7 +1384,7 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert sample_size is not None, "sample_size must be specified"
|
assert sample_size is not None, "sample_size must be specified"
|
||||||
print(
|
logger.info(
|
||||||
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1514,7 +1518,7 @@ class UNet2DConditionModel(nn.Module):
|
|||||||
def set_gradient_checkpointing(self, value=False):
|
def set_gradient_checkpointing(self, value=False):
|
||||||
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
||||||
for module in modules:
|
for module in modules:
|
||||||
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
|
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
@@ -1709,14 +1713,14 @@ class InferUNet2DConditionModel:
|
|||||||
|
|
||||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
if ds_depth_1 is None:
|
if ds_depth_1 is None:
|
||||||
print("Deep Shrink is disabled.")
|
logger.info("Deep Shrink is disabled.")
|
||||||
self.ds_depth_1 = None
|
self.ds_depth_1 = None
|
||||||
self.ds_timesteps_1 = None
|
self.ds_timesteps_1 = None
|
||||||
self.ds_depth_2 = None
|
self.ds_depth_2 = None
|
||||||
self.ds_timesteps_2 = None
|
self.ds_timesteps_2 = None
|
||||||
self.ds_ratio = None
|
self.ds_ratio = None
|
||||||
else:
|
else:
|
||||||
print(
|
logger.info(
|
||||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
)
|
)
|
||||||
self.ds_depth_1 = ds_depth_1
|
self.ds_depth_1 = ds_depth_1
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ from io import BytesIO
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
import safetensors
|
import safetensors
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
r"""
|
r"""
|
||||||
# Metadata Example
|
# Metadata Example
|
||||||
@@ -231,7 +235,7 @@ def build_metadata(
|
|||||||
# # assert all values are filled
|
# # assert all values are filled
|
||||||
# assert all([v is not None for v in metadata.values()]), metadata
|
# assert all([v is not None for v in metadata.values()]), metadata
|
||||||
if not all([v is not None for v in metadata.values()]):
|
if not all([v is not None for v in metadata.values()]):
|
||||||
print(f"Internal error: some metadata values are None: {metadata}")
|
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ from typing import List
|
|||||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||||
from library import model_util
|
from library import model_util
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VAE_SCALE_FACTOR = 0.13025
|
VAE_SCALE_FACTOR = 0.13025
|
||||||
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
||||||
@@ -131,7 +134,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
|
|
||||||
# temporary workaround for text_projection.weight.weight for Playground-v2
|
# temporary workaround for text_projection.weight.weight for Playground-v2
|
||||||
if "text_projection.weight.weight" in new_sd:
|
if "text_projection.weight.weight" in new_sd:
|
||||||
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||||
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
||||||
del new_sd["text_projection.weight.weight"]
|
del new_sd["text_projection.weight.weight"]
|
||||||
|
|
||||||
@@ -186,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
|||||||
checkpoint = None
|
checkpoint = None
|
||||||
|
|
||||||
# U-Net
|
# U-Net
|
||||||
print("building U-Net")
|
logger.info("building U-Net")
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
|
|
||||||
print("loading U-Net from checkpoint")
|
logger.info("loading U-Net from checkpoint")
|
||||||
unet_sd = {}
|
unet_sd = {}
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
if k.startswith("model.diffusion_model."):
|
if k.startswith("model.diffusion_model."):
|
||||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
||||||
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
|
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
|
||||||
print("U-Net: ", info)
|
logger.info(f"U-Net: {info}")
|
||||||
|
|
||||||
# Text Encoders
|
# Text Encoders
|
||||||
print("building text encoders")
|
logger.info("building text encoders")
|
||||||
|
|
||||||
# Text Encoder 1 is same to Stability AI's SDXL
|
# Text Encoder 1 is same to Stability AI's SDXL
|
||||||
text_model1_cfg = CLIPTextConfig(
|
text_model1_cfg = CLIPTextConfig(
|
||||||
@@ -252,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
|||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
||||||
|
|
||||||
print("loading text encoders from checkpoint")
|
logger.info("loading text encoders from checkpoint")
|
||||||
te1_sd = {}
|
te1_sd = {}
|
||||||
te2_sd = {}
|
te2_sd = {}
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
@@ -266,22 +269,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
|||||||
te1_sd.pop("text_model.embeddings.position_ids")
|
te1_sd.pop("text_model.embeddings.position_ids")
|
||||||
|
|
||||||
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
||||||
print("text encoder 1:", info1)
|
logger.info(f"text encoder 1: {info1}")
|
||||||
|
|
||||||
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||||
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
|
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
|
||||||
print("text encoder 2:", info2)
|
logger.info(f"text encoder 2: {info2}")
|
||||||
|
|
||||||
# prepare vae
|
# prepare vae
|
||||||
print("building VAE")
|
logger.info("building VAE")
|
||||||
vae_config = model_util.create_vae_diffusers_config()
|
vae_config = model_util.create_vae_diffusers_config()
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
|
|
||||||
print("loading VAE from checkpoint")
|
logger.info("loading VAE from checkpoint")
|
||||||
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
|
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
|
||||||
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
|
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
|
||||||
print("VAE:", info)
|
logger.info(f"VAE: {info}")
|
||||||
|
|
||||||
ckpt_info = (epoch, global_step) if epoch is not None else None
|
ckpt_info = (epoch, global_step) if epoch is not None else None
|
||||||
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|||||||
@@ -30,7 +30,10 @@ import torch.utils.checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
IN_CHANNELS: int = 4
|
IN_CHANNELS: int = 4
|
||||||
OUT_CHANNELS: int = 4
|
OUT_CHANNELS: int = 4
|
||||||
@@ -332,7 +335,7 @@ class ResnetBlock2D(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
# print("ResnetBlock2D: gradient_checkpointing")
|
# logger.info("ResnetBlock2D: gradient_checkpointing")
|
||||||
|
|
||||||
def create_custom_forward(func):
|
def create_custom_forward(func):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -366,7 +369,7 @@ class Downsample2D(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
# print("Downsample2D: gradient_checkpointing")
|
# logger.info("Downsample2D: gradient_checkpointing")
|
||||||
|
|
||||||
def create_custom_forward(func):
|
def create_custom_forward(func):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -653,7 +656,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states, context=None, timestep=None):
|
def forward(self, hidden_states, context=None, timestep=None):
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
# print("BasicTransformerBlock: checkpointing")
|
# logger.info("BasicTransformerBlock: checkpointing")
|
||||||
|
|
||||||
def create_custom_forward(func):
|
def create_custom_forward(func):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -796,7 +799,7 @@ class Upsample2D(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states, output_size=None):
|
def forward(self, hidden_states, output_size=None):
|
||||||
if self.training and self.gradient_checkpointing:
|
if self.training and self.gradient_checkpointing:
|
||||||
# print("Upsample2D: gradient_checkpointing")
|
# logger.info("Upsample2D: gradient_checkpointing")
|
||||||
|
|
||||||
def create_custom_forward(func):
|
def create_custom_forward(func):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -1046,7 +1049,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
for block in blocks:
|
for block in blocks:
|
||||||
for module in block:
|
for module in block:
|
||||||
if hasattr(module, "set_use_memory_efficient_attention"):
|
if hasattr(module, "set_use_memory_efficient_attention"):
|
||||||
# print(module.__class__.__name__)
|
# logger.info(module.__class__.__name__)
|
||||||
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||||
|
|
||||||
def set_use_sdpa(self, sdpa: bool) -> None:
|
def set_use_sdpa(self, sdpa: bool) -> None:
|
||||||
@@ -1061,7 +1064,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
for block in blocks:
|
for block in blocks:
|
||||||
for module in block.modules():
|
for module in block.modules():
|
||||||
if hasattr(module, "gradient_checkpointing"):
|
if hasattr(module, "gradient_checkpointing"):
|
||||||
# print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
|
# logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
@@ -1083,7 +1086,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
|||||||
def call_module(module, h, emb, context):
|
def call_module(module, h, emb, context):
|
||||||
x = h
|
x = h
|
||||||
for layer in module:
|
for layer in module:
|
||||||
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
# logger.info(f"{layer.__class__.__name__} {x.dtype} {emb.dtype} {context.dtype if context is not None else None}")
|
||||||
if isinstance(layer, ResnetBlock2D):
|
if isinstance(layer, ResnetBlock2D):
|
||||||
x = layer(x, emb)
|
x = layer(x, emb)
|
||||||
elif isinstance(layer, Transformer2DModel):
|
elif isinstance(layer, Transformer2DModel):
|
||||||
@@ -1135,14 +1138,14 @@ class InferSdxlUNet2DConditionModel:
|
|||||||
|
|
||||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||||
if ds_depth_1 is None:
|
if ds_depth_1 is None:
|
||||||
print("Deep Shrink is disabled.")
|
logger.info("Deep Shrink is disabled.")
|
||||||
self.ds_depth_1 = None
|
self.ds_depth_1 = None
|
||||||
self.ds_timesteps_1 = None
|
self.ds_timesteps_1 = None
|
||||||
self.ds_depth_2 = None
|
self.ds_depth_2 = None
|
||||||
self.ds_timesteps_2 = None
|
self.ds_timesteps_2 = None
|
||||||
self.ds_ratio = None
|
self.ds_ratio = None
|
||||||
else:
|
else:
|
||||||
print(
|
logger.info(
|
||||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||||
)
|
)
|
||||||
self.ds_depth_1 = ds_depth_1
|
self.ds_depth_1 = ds_depth_1
|
||||||
@@ -1229,7 +1232,7 @@ class InferSdxlUNet2DConditionModel:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import time
|
import time
|
||||||
|
|
||||||
print("create unet")
|
logger.info("create unet")
|
||||||
unet = SdxlUNet2DConditionModel()
|
unet = SdxlUNet2DConditionModel()
|
||||||
|
|
||||||
unet.to("cuda")
|
unet.to("cuda")
|
||||||
@@ -1238,7 +1241,7 @@ if __name__ == "__main__":
|
|||||||
unet.train()
|
unet.train()
|
||||||
|
|
||||||
# 使用メモリ量確認用の疑似学習ループ
|
# 使用メモリ量確認用の疑似学習ループ
|
||||||
print("preparing optimizer")
|
logger.info("preparing optimizer")
|
||||||
|
|
||||||
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
||||||
|
|
||||||
@@ -1253,12 +1256,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||||
|
|
||||||
print("start training")
|
logger.info("start training")
|
||||||
steps = 10
|
steps = 10
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
print(f"step {step}")
|
logger.info(f"step {step}")
|
||||||
if step == 1:
|
if step == 1:
|
||||||
time_start = time.perf_counter()
|
time_start = time.perf_counter()
|
||||||
|
|
||||||
@@ -1278,4 +1281,4 @@ if __name__ == "__main__":
|
|||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
time_end = time.perf_counter()
|
time_end = time.perf_counter()
|
||||||
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ from tqdm import tqdm
|
|||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||||
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
@@ -21,7 +25,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||||
for pi in range(accelerator.state.num_processes):
|
for pi in range(accelerator.state.num_processes):
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
(
|
(
|
||||||
load_stable_diffusion_format,
|
load_stable_diffusion_format,
|
||||||
@@ -62,7 +66,7 @@ def _load_target_model(
|
|||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
|
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||||
(
|
(
|
||||||
text_encoder1,
|
text_encoder1,
|
||||||
text_encoder2,
|
text_encoder2,
|
||||||
@@ -76,7 +80,7 @@ def _load_target_model(
|
|||||||
from diffusers import StableDiffusionXLPipeline
|
from diffusers import StableDiffusionXLPipeline
|
||||||
|
|
||||||
variant = "fp16" if weight_dtype == torch.float16 else None
|
variant = "fp16" if weight_dtype == torch.float16 else None
|
||||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||||
@@ -84,12 +88,12 @@ def _load_target_model(
|
|||||||
)
|
)
|
||||||
except EnvironmentError as ex:
|
except EnvironmentError as ex:
|
||||||
if variant is not None:
|
if variant is not None:
|
||||||
print("try to load fp32 model")
|
logger.info("try to load fp32 model")
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
||||||
else:
|
else:
|
||||||
raise ex
|
raise ex
|
||||||
except EnvironmentError as ex:
|
except EnvironmentError as ex:
|
||||||
print(
|
logger.error(
|
||||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
||||||
)
|
)
|
||||||
raise ex
|
raise ex
|
||||||
@@ -112,7 +116,7 @@ def _load_target_model(
|
|||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
||||||
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
|
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
|
||||||
print("U-Net converted to original U-Net")
|
logger.info("U-Net converted to original U-Net")
|
||||||
|
|
||||||
logit_scale = None
|
logit_scale = None
|
||||||
ckpt_info = None
|
ckpt_info = None
|
||||||
@@ -120,13 +124,13 @@ def _load_target_model(
|
|||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if vae_path is not None:
|
if vae_path is not None:
|
||||||
vae = model_util.load_vae(vae_path, weight_dtype)
|
vae = model_util.load_vae(vae_path, weight_dtype)
|
||||||
print("additional VAE loaded")
|
logger.info("additional VAE loaded")
|
||||||
|
|
||||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizers(args: argparse.Namespace):
|
def load_tokenizers(args: argparse.Namespace):
|
||||||
print("prepare tokenizers")
|
logger.info("prepare tokenizers")
|
||||||
|
|
||||||
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
||||||
tokeniers = []
|
tokeniers = []
|
||||||
@@ -135,14 +139,14 @@ def load_tokenizers(args: argparse.Namespace):
|
|||||||
if args.tokenizer_cache_dir:
|
if args.tokenizer_cache_dir:
|
||||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||||
if os.path.exists(local_tokenizer_path):
|
if os.path.exists(local_tokenizer_path):
|
||||||
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
||||||
|
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||||
|
|
||||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||||
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||||
tokenizer.save_pretrained(local_tokenizer_path)
|
tokenizer.save_pretrained(local_tokenizer_path)
|
||||||
|
|
||||||
if i == 1:
|
if i == 1:
|
||||||
@@ -151,7 +155,7 @@ def load_tokenizers(args: argparse.Namespace):
|
|||||||
tokeniers.append(tokenizer)
|
tokeniers.append(tokenizer)
|
||||||
|
|
||||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||||
print(f"update token length: {args.max_token_length}")
|
logger.info(f"update token length: {args.max_token_length}")
|
||||||
|
|
||||||
return tokeniers
|
return tokeniers
|
||||||
|
|
||||||
@@ -332,23 +336,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
|||||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||||
|
|
||||||
if args.clip_skip is not None:
|
if args.clip_skip is not None:
|
||||||
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||||
|
|
||||||
# if args.multires_noise_iterations:
|
# if args.multires_noise_iterations:
|
||||||
# print(
|
# logger.info(
|
||||||
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
||||||
# )
|
# )
|
||||||
# else:
|
# else:
|
||||||
# if args.noise_offset is None:
|
# if args.noise_offset is None:
|
||||||
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
||||||
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
||||||
# print(
|
# logger.info(
|
||||||
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
||||||
# )
|
# )
|
||||||
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||||
@@ -357,7 +361,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
|||||||
if supportTextEncoderCaching:
|
if supportTextEncoderCaching:
|
||||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||||
args.cache_text_encoder_outputs = True
|
args.cache_text_encoder_outputs = True
|
||||||
print(
|
logger.warning(
|
||||||
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
||||||
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin
|
|||||||
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||||
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
|
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
|
||||||
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def slice_h(x, num_slices):
|
def slice_h(x, num_slices):
|
||||||
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
||||||
@@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
|||||||
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
||||||
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
||||||
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
||||||
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
||||||
# normed_tensor = []
|
# normed_tensor = []
|
||||||
# for i in range(num_div):
|
# for i in range(num_div):
|
||||||
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
||||||
@@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module):
|
|||||||
|
|
||||||
self.num_slices = num_slices
|
self.num_slices = num_slices
|
||||||
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
|
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
|
||||||
# print(f"initial divisor: {div}")
|
# logger.info(f"initial divisor: {div}")
|
||||||
if div >= 2:
|
if div >= 2:
|
||||||
div = int(div)
|
div = int(div)
|
||||||
for resnet in self.mid_block.resnets:
|
for resnet in self.mid_block.resnets:
|
||||||
@@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module):
|
|||||||
for i, down_block in enumerate(self.down_blocks[::-1]):
|
for i, down_block in enumerate(self.down_blocks[::-1]):
|
||||||
if div >= 2:
|
if div >= 2:
|
||||||
div = int(div)
|
div = int(div)
|
||||||
# print(f"down block: {i} divisor: {div}")
|
# logger.info(f"down block: {i} divisor: {div}")
|
||||||
for resnet in down_block.resnets:
|
for resnet in down_block.resnets:
|
||||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||||
if down_block.downsamplers is not None:
|
if down_block.downsamplers is not None:
|
||||||
# print("has downsample")
|
# logger.info("has downsample")
|
||||||
for downsample in down_block.downsamplers:
|
for downsample in down_block.downsamplers:
|
||||||
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
||||||
div *= 2
|
div *= 2
|
||||||
@@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module):
|
|||||||
def downsample_forward(self, _self, num_slices, hidden_states):
|
def downsample_forward(self, _self, num_slices, hidden_states):
|
||||||
assert hidden_states.shape[1] == _self.channels
|
assert hidden_states.shape[1] == _self.channels
|
||||||
assert _self.use_conv and _self.padding == 0
|
assert _self.use_conv and _self.padding == 0
|
||||||
print("downsample forward", num_slices, hidden_states.shape)
|
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
|
||||||
|
|
||||||
org_device = hidden_states.device
|
org_device = hidden_states.device
|
||||||
cpu_device = torch.device("cpu")
|
cpu_device = torch.device("cpu")
|
||||||
@@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module):
|
|||||||
hidden_states = torch.cat([hidden_states, x], dim=2)
|
hidden_states = torch.cat([hidden_states, x], dim=2)
|
||||||
|
|
||||||
hidden_states = hidden_states.to(org_device)
|
hidden_states = hidden_states.to(org_device)
|
||||||
# print("downsample forward done", hidden_states.shape)
|
# logger.info(f"downsample forward done {hidden_states.shape}")
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module):
|
|||||||
|
|
||||||
self.num_slices = num_slices
|
self.num_slices = num_slices
|
||||||
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
||||||
print(f"initial divisor: {div}")
|
logger.info(f"initial divisor: {div}")
|
||||||
if div >= 2:
|
if div >= 2:
|
||||||
div = int(div)
|
div = int(div)
|
||||||
for resnet in self.mid_block.resnets:
|
for resnet in self.mid_block.resnets:
|
||||||
@@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module):
|
|||||||
for i, up_block in enumerate(self.up_blocks):
|
for i, up_block in enumerate(self.up_blocks):
|
||||||
if div >= 2:
|
if div >= 2:
|
||||||
div = int(div)
|
div = int(div)
|
||||||
# print(f"up block: {i} divisor: {div}")
|
# logger.info(f"up block: {i} divisor: {div}")
|
||||||
for resnet in up_block.resnets:
|
for resnet in up_block.resnets:
|
||||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||||
if up_block.upsamplers is not None:
|
if up_block.upsamplers is not None:
|
||||||
# print("has upsample")
|
# logger.info("has upsample")
|
||||||
for upsample in up_block.upsamplers:
|
for upsample in up_block.upsamplers:
|
||||||
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
||||||
div *= 2
|
div *= 2
|
||||||
@@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module):
|
|||||||
del x
|
del x
|
||||||
|
|
||||||
hidden_states = torch.cat(sliced, dim=2)
|
hidden_states = torch.cat(sliced, dim=2)
|
||||||
# print("us hidden_states", hidden_states.shape)
|
# logger.info(f"us hidden_states {hidden_states.shape}")
|
||||||
del sliced
|
del sliced
|
||||||
|
|
||||||
hidden_states = hidden_states.to(org_device)
|
hidden_states = hidden_states.to(org_device)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,23 @@
|
|||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import *
|
from typing import *
|
||||||
|
|
||||||
|
|
||||||
def fire_in_thread(f, *args, **kwargs):
|
def fire_in_thread(f, *args, **kwargs):
|
||||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(log_level=logging.INFO):
|
||||||
|
if logging.root.handlers: # Already configured
|
||||||
|
return
|
||||||
|
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
|
handler = RichHandler()
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
fmt="%(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
handler.setFormatter(formatter)
|
||||||
|
logging.root.setLevel(log_level)
|
||||||
|
logging.root.addHandler(handler)
|
||||||
|
|||||||
@@ -2,10 +2,13 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def main(file):
|
def main(file):
|
||||||
print(f"loading: {file}")
|
logger.info(f"loading: {file}")
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
sd = load_file(file)
|
sd = load_file(file)
|
||||||
else:
|
else:
|
||||||
@@ -17,16 +20,16 @@ def main(file):
|
|||||||
for key in keys:
|
for key in keys:
|
||||||
if "lora_up" in key or "lora_down" in key:
|
if "lora_up" in key or "lora_down" in key:
|
||||||
values.append((key, sd[key]))
|
values.append((key, sd[key]))
|
||||||
print(f"number of LoRA modules: {len(values)}")
|
logger.info(f"number of LoRA modules: {len(values)}")
|
||||||
|
|
||||||
if args.show_all_keys:
|
if args.show_all_keys:
|
||||||
for key in [k for k in keys if k not in values]:
|
for key in [k for k in keys if k not in values]:
|
||||||
values.append((key, sd[key]))
|
values.append((key, sd[key]))
|
||||||
print(f"number of all modules: {len(values)}")
|
logger.info(f"number of all modules: {len(values)}")
|
||||||
|
|
||||||
for key, value in values:
|
for key, value in values:
|
||||||
value = value.to(torch.float32)
|
value = value.to(torch.float32)
|
||||||
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
logger.info(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ import os
|
|||||||
from typing import Optional, List, Type
|
from typing import Optional, List, Type
|
||||||
import torch
|
import torch
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||||
SKIP_INPUT_BLOCKS = False
|
SKIP_INPUT_BLOCKS = False
|
||||||
@@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
# logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||||
cx = self.conditioning1(cond_image)
|
cx = self.conditioning1(cond_image)
|
||||||
if not self.is_conv2d:
|
if not self.is_conv2d:
|
||||||
# reshape / b,c,h,w -> b,h*w,c
|
# reshape / b,c,h,w -> b,h*w,c
|
||||||
@@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module):
|
|||||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||||
if self.use_zeros_for_batch_uncond:
|
if self.use_zeros_for_batch_uncond:
|
||||||
cx[0::2] = 0.0 # uncond is zero
|
cx[0::2] = 0.0 # uncond is zero
|
||||||
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
# logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||||
|
|
||||||
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
||||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||||
@@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module):
|
|||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
||||||
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x # dummy
|
return x # dummy
|
||||||
@@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module):
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
print("applying LLLite for U-Net...")
|
logger.info("applying LLLite for U-Net...")
|
||||||
for module in self.unet_modules:
|
for module in self.unet_modules:
|
||||||
module.apply_to()
|
module.apply_to()
|
||||||
self.add_module(module.lllite_name, module)
|
self.add_module(module.lllite_name, module)
|
||||||
@@ -374,19 +377,19 @@ if __name__ == "__main__":
|
|||||||
# sdxl_original_unet.USE_REENTRANT = False
|
# sdxl_original_unet.USE_REENTRANT = False
|
||||||
|
|
||||||
# test shape etc
|
# test shape etc
|
||||||
print("create unet")
|
logger.info("create unet")
|
||||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
unet.to("cuda").to(torch.float16)
|
unet.to("cuda").to(torch.float16)
|
||||||
|
|
||||||
print("create ControlNet-LLLite")
|
logger.info("create ControlNet-LLLite")
|
||||||
control_net = ControlNetLLLite(unet, 32, 64)
|
control_net = ControlNetLLLite(unet, 32, 64)
|
||||||
control_net.apply_to()
|
control_net.apply_to()
|
||||||
control_net.to("cuda")
|
control_net.to("cuda")
|
||||||
|
|
||||||
print(control_net)
|
logger.info(control_net)
|
||||||
|
|
||||||
# print number of parameters
|
# logger.info number of parameters
|
||||||
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
|
logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
|
||||||
|
|
||||||
input()
|
input()
|
||||||
|
|
||||||
@@ -398,12 +401,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# # visualize
|
# # visualize
|
||||||
# import torchviz
|
# import torchviz
|
||||||
# print("run visualize")
|
# logger.info("run visualize")
|
||||||
# controlnet.set_control(conditioning_image)
|
# controlnet.set_control(conditioning_image)
|
||||||
# output = unet(x, t, ctx, y)
|
# output = unet(x, t, ctx, y)
|
||||||
# print("make_dot")
|
# logger.info("make_dot")
|
||||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||||
# print("render")
|
# logger.info("render")
|
||||||
# image.format = "svg" # "png"
|
# image.format = "svg" # "png"
|
||||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||||
# input()
|
# input()
|
||||||
@@ -414,12 +417,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||||
|
|
||||||
print("start training")
|
logger.info("start training")
|
||||||
steps = 10
|
steps = 10
|
||||||
|
|
||||||
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
print(f"step {step}")
|
logger.info(f"step {step}")
|
||||||
|
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
||||||
@@ -439,7 +442,7 @@ if __name__ == "__main__":
|
|||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
print(sample_param)
|
logger.info(f"{sample_param}")
|
||||||
|
|
||||||
# from safetensors.torch import save_file
|
# from safetensors.torch import save_file
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import re
|
|||||||
from typing import Optional, List, Type
|
from typing import Optional, List, Type
|
||||||
import torch
|
import torch
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||||
SKIP_INPUT_BLOCKS = False
|
SKIP_INPUT_BLOCKS = False
|
||||||
@@ -270,7 +273,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
|||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
self.lllite_modules = apply_to_modules(self, target_modules)
|
self.lllite_modules = apply_to_modules(self, target_modules)
|
||||||
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
||||||
|
|
||||||
# def prepare_optimizer_params(self):
|
# def prepare_optimizer_params(self):
|
||||||
def prepare_params(self):
|
def prepare_params(self):
|
||||||
@@ -281,8 +284,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
|||||||
train_params.append(p)
|
train_params.append(p)
|
||||||
else:
|
else:
|
||||||
non_train_params.append(p)
|
non_train_params.append(p)
|
||||||
print(f"count of trainable parameters: {len(train_params)}")
|
logger.info(f"count of trainable parameters: {len(train_params)}")
|
||||||
print(f"count of non-trainable parameters: {len(non_train_params)}")
|
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
|
||||||
|
|
||||||
for p in non_train_params:
|
for p in non_train_params:
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
@@ -388,7 +391,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
|||||||
matches = pattern.findall(module_name)
|
matches = pattern.findall(module_name)
|
||||||
if matches is not None:
|
if matches is not None:
|
||||||
for m in matches:
|
for m in matches:
|
||||||
print(module_name, m)
|
logger.info(f"{module_name} {m}")
|
||||||
module_name = module_name.replace(m, m.replace("_", "@"))
|
module_name = module_name.replace(m, m.replace("_", "@"))
|
||||||
module_name = module_name.replace("_", ".")
|
module_name = module_name.replace("_", ".")
|
||||||
module_name = module_name.replace("@", "_")
|
module_name = module_name.replace("@", "_")
|
||||||
@@ -407,7 +410,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
|||||||
|
|
||||||
|
|
||||||
def replace_unet_linear_and_conv2d():
|
def replace_unet_linear_and_conv2d():
|
||||||
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
||||||
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
|
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
|
||||||
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
|
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
|
||||||
|
|
||||||
@@ -419,10 +422,10 @@ if __name__ == "__main__":
|
|||||||
replace_unet_linear_and_conv2d()
|
replace_unet_linear_and_conv2d()
|
||||||
|
|
||||||
# test shape etc
|
# test shape etc
|
||||||
print("create unet")
|
logger.info("create unet")
|
||||||
unet = SdxlUNet2DConditionModelControlNetLLLite()
|
unet = SdxlUNet2DConditionModelControlNetLLLite()
|
||||||
|
|
||||||
print("enable ControlNet-LLLite")
|
logger.info("enable ControlNet-LLLite")
|
||||||
unet.apply_lllite(32, 64, None, False, 1.0)
|
unet.apply_lllite(32, 64, None, False, 1.0)
|
||||||
unet.to("cuda") # .to(torch.float16)
|
unet.to("cuda") # .to(torch.float16)
|
||||||
|
|
||||||
@@ -439,14 +442,14 @@ if __name__ == "__main__":
|
|||||||
# unet_sd[converted_key] = model_sd[key]
|
# unet_sd[converted_key] = model_sd[key]
|
||||||
|
|
||||||
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
|
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
|
||||||
# print(info)
|
# logger.info(info)
|
||||||
|
|
||||||
# print(unet)
|
# logger.info(unet)
|
||||||
|
|
||||||
# print number of parameters
|
# logger.info number of parameters
|
||||||
params = unet.prepare_params()
|
params = unet.prepare_params()
|
||||||
print("number of parameters", sum(p.numel() for p in params))
|
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
|
||||||
# print("type any key to continue")
|
# logger.info("type any key to continue")
|
||||||
# input()
|
# input()
|
||||||
|
|
||||||
unet.set_use_memory_efficient_attention(True, False)
|
unet.set_use_memory_efficient_attention(True, False)
|
||||||
@@ -455,12 +458,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# # visualize
|
# # visualize
|
||||||
# import torchviz
|
# import torchviz
|
||||||
# print("run visualize")
|
# logger.info("run visualize")
|
||||||
# controlnet.set_control(conditioning_image)
|
# controlnet.set_control(conditioning_image)
|
||||||
# output = unet(x, t, ctx, y)
|
# output = unet(x, t, ctx, y)
|
||||||
# print("make_dot")
|
# logger.info("make_dot")
|
||||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||||
# print("render")
|
# logger.info("render")
|
||||||
# image.format = "svg" # "png"
|
# image.format = "svg" # "png"
|
||||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||||
# input()
|
# input()
|
||||||
@@ -471,13 +474,13 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||||
|
|
||||||
print("start training")
|
logger.info("start training")
|
||||||
steps = 10
|
steps = 10
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
|
||||||
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
|
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
|
||||||
for step in range(steps):
|
for step in range(steps):
|
||||||
print(f"step {step}")
|
logger.info(f"step {step}")
|
||||||
|
|
||||||
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
||||||
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
||||||
@@ -494,9 +497,9 @@ if __name__ == "__main__":
|
|||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
print(sample_param)
|
logger.info(sample_param)
|
||||||
|
|
||||||
# from safetensors.torch import save_file
|
# from safetensors.torch import save_file
|
||||||
|
|
||||||
# print("save weights")
|
# logger.info("save weights")
|
||||||
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
|
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
|
||||||
|
|||||||
@@ -15,7 +15,10 @@ import random
|
|||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DyLoRAModule(torch.nn.Module):
|
class DyLoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -223,7 +226,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
elif "lora_down" in key:
|
elif "lora_down" in key:
|
||||||
dim = value.size()[0]
|
dim = value.size()[0]
|
||||||
modules_dim[lora_name] = dim
|
modules_dim[lora_name] = dim
|
||||||
# print(lora_name, value.size(), dim)
|
# logger.info(f"{lora_name} {value.size()} {dim}")
|
||||||
|
|
||||||
# support old LoRA without alpha
|
# support old LoRA without alpha
|
||||||
for key in modules_dim.keys():
|
for key in modules_dim.keys():
|
||||||
@@ -267,11 +270,11 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
self.apply_to_conv = apply_to_conv
|
self.apply_to_conv = apply_to_conv
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
print(f"create LoRA network from weights")
|
logger.info("create LoRA network from weights")
|
||||||
else:
|
else:
|
||||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
||||||
if self.apply_to_conv:
|
if self.apply_to_conv:
|
||||||
print(f"apply LoRA to Conv2d with kernel size (3,3).")
|
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
||||||
@@ -308,7 +311,7 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
return loras
|
return loras
|
||||||
|
|
||||||
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
|
||||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||||
@@ -316,7 +319,7 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_loras = create_modules(True, unet, target_modules)
|
self.unet_loras = create_modules(True, unet, target_modules)
|
||||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||||
|
|
||||||
def set_multiplier(self, multiplier):
|
def set_multiplier(self, multiplier):
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
@@ -336,12 +339,12 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.unet_loras = []
|
||||||
|
|
||||||
@@ -359,12 +362,12 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
apply_unet = True
|
apply_unet = True
|
||||||
|
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.unet_loras = []
|
||||||
|
|
||||||
@@ -375,7 +378,7 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||||
lora.merge_to(sd_for_lora, dtype, device)
|
lora.merge_to(sd_for_lora, dtype, device)
|
||||||
|
|
||||||
print(f"weights are merged")
|
logger.info(f"weights are merged")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from library import train_util, model_util
|
from library import train_util, model_util
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def load_state_dict(file_name):
|
def load_state_dict(file_name):
|
||||||
if model_util.is_safetensors(file_name):
|
if model_util.is_safetensors(file_name):
|
||||||
@@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit):
|
|||||||
rank = value.size()[0]
|
rank = value.size()[0]
|
||||||
if rank > max_rank:
|
if rank > max_rank:
|
||||||
max_rank = rank
|
max_rank = rank
|
||||||
print(f"Max rank: {max_rank}")
|
logger.info(f"Max rank: {max_rank}")
|
||||||
|
|
||||||
rank = unit
|
rank = unit
|
||||||
split_models = []
|
split_models = []
|
||||||
new_alpha = None
|
new_alpha = None
|
||||||
while rank < max_rank:
|
while rank < max_rank:
|
||||||
print(f"Splitting rank {rank}")
|
logger.info(f"Splitting rank {rank}")
|
||||||
new_sd = {}
|
new_sd = {}
|
||||||
for key, value in lora_sd.items():
|
for key, value in lora_sd.items():
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
@@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit):
|
|||||||
# なぜかscaleするとおかしくなる……
|
# なぜかscaleするとおかしくなる……
|
||||||
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
||||||
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
||||||
# print(key, value.size(), this_rank, rank, value, scale)
|
# logger.info(key, value.size(), this_rank, rank, value, scale)
|
||||||
# new_alpha = value * scale # always same
|
# new_alpha = value * scale # always same
|
||||||
# new_sd[key] = new_alpha
|
# new_sd[key] = new_alpha
|
||||||
new_sd[key] = value
|
new_sd[key] = value
|
||||||
@@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit):
|
|||||||
|
|
||||||
|
|
||||||
def split(args):
|
def split(args):
|
||||||
print("loading Model...")
|
logger.info("loading Model...")
|
||||||
lora_sd, metadata = load_state_dict(args.model)
|
lora_sd, metadata = load_state_dict(args.model)
|
||||||
|
|
||||||
print("Splitting Model...")
|
logger.info("Splitting Model...")
|
||||||
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
||||||
|
|
||||||
comment = metadata.get("ss_training_comment", "")
|
comment = metadata.get("ss_training_comment", "")
|
||||||
@@ -94,7 +97,7 @@ def split(args):
|
|||||||
filename, ext = os.path.splitext(args.save_to)
|
filename, ext = os.path.splitext(args.save_to)
|
||||||
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
||||||
|
|
||||||
print(f"saving model to: {model_file_name}")
|
logger.info(f"saving model to: {model_file_name}")
|
||||||
save_to_file(model_file_name, state_dict, new_metadata)
|
save_to_file(model_file_name, state_dict, new_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,10 @@ from safetensors.torch import load_file, save_file
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from library import sai_model_spec, model_util, sdxl_model_util
|
from library import sai_model_spec, model_util, sdxl_model_util
|
||||||
import lora
|
import lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# CLAMP_QUANTILE = 0.99
|
# CLAMP_QUANTILE = 0.99
|
||||||
# MIN_DIFF = 1e-1
|
# MIN_DIFF = 1e-1
|
||||||
@@ -66,14 +69,14 @@ def svd(
|
|||||||
|
|
||||||
# load models
|
# load models
|
||||||
if not sdxl:
|
if not sdxl:
|
||||||
print(f"loading original SD model : {model_org}")
|
logger.info(f"loading original SD model : {model_org}")
|
||||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||||
text_encoders_o = [text_encoder_o]
|
text_encoders_o = [text_encoder_o]
|
||||||
if load_dtype is not None:
|
if load_dtype is not None:
|
||||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||||
unet_o = unet_o.to(load_dtype)
|
unet_o = unet_o.to(load_dtype)
|
||||||
|
|
||||||
print(f"loading tuned SD model : {model_tuned}")
|
logger.info(f"loading tuned SD model : {model_tuned}")
|
||||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||||
text_encoders_t = [text_encoder_t]
|
text_encoders_t = [text_encoder_t]
|
||||||
if load_dtype is not None:
|
if load_dtype is not None:
|
||||||
@@ -85,7 +88,7 @@ def svd(
|
|||||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||||
|
|
||||||
print(f"loading original SDXL model : {model_org}")
|
logger.info(f"loading original SDXL model : {model_org}")
|
||||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||||
)
|
)
|
||||||
@@ -95,7 +98,7 @@ def svd(
|
|||||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||||
unet_o = unet_o.to(load_dtype)
|
unet_o = unet_o.to(load_dtype)
|
||||||
|
|
||||||
print(f"loading original SDXL model : {model_tuned}")
|
logger.info(f"loading original SDXL model : {model_tuned}")
|
||||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||||
)
|
)
|
||||||
@@ -135,7 +138,7 @@ def svd(
|
|||||||
# Text Encoder might be same
|
# Text Encoder might be same
|
||||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||||
text_encoder_different = True
|
text_encoder_different = True
|
||||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||||
|
|
||||||
diffs[lora_name] = diff
|
diffs[lora_name] = diff
|
||||||
|
|
||||||
@@ -144,7 +147,7 @@ def svd(
|
|||||||
del text_encoder
|
del text_encoder
|
||||||
|
|
||||||
if not text_encoder_different:
|
if not text_encoder_different:
|
||||||
print("Text encoder is same. Extract U-Net only.")
|
logger.warning("Text encoder is same. Extract U-Net only.")
|
||||||
lora_network_o.text_encoder_loras = []
|
lora_network_o.text_encoder_loras = []
|
||||||
diffs = {} # clear diffs
|
diffs = {} # clear diffs
|
||||||
|
|
||||||
@@ -166,7 +169,7 @@ def svd(
|
|||||||
del unet_t
|
del unet_t
|
||||||
|
|
||||||
# make LoRA with svd
|
# make LoRA with svd
|
||||||
print("calculating by svd")
|
logger.info("calculating by svd")
|
||||||
lora_weights = {}
|
lora_weights = {}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_name, mat in tqdm(list(diffs.items())):
|
for lora_name, mat in tqdm(list(diffs.items())):
|
||||||
@@ -185,7 +188,7 @@ def svd(
|
|||||||
if device:
|
if device:
|
||||||
mat = mat.to(device)
|
mat = mat.to(device)
|
||||||
|
|
||||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||||
|
|
||||||
if conv2d:
|
if conv2d:
|
||||||
@@ -230,7 +233,7 @@ def svd(
|
|||||||
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
||||||
|
|
||||||
info = lora_network_save.load_state_dict(lora_sd)
|
info = lora_network_save.load_state_dict(lora_sd)
|
||||||
print(f"Loading extracted LoRA weights: {info}")
|
logger.info(f"Loading extracted LoRA weights: {info}")
|
||||||
|
|
||||||
dir_name = os.path.dirname(save_to)
|
dir_name = os.path.dirname(save_to)
|
||||||
if dir_name and not os.path.exists(dir_name):
|
if dir_name and not os.path.exists(dir_name):
|
||||||
@@ -257,7 +260,7 @@ def svd(
|
|||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||||
print(f"LoRA weights are saved to: {save_to}")
|
logger.info(f"LoRA weights are saved to: {save_to}")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
105
networks/lora.py
105
networks/lora.py
@@ -11,7 +11,10 @@ from transformers import CLIPTextModel
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
@@ -46,7 +49,7 @@ class LoRAModule(torch.nn.Module):
|
|||||||
# if limit_rank:
|
# if limit_rank:
|
||||||
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
||||||
# if self.lora_dim != lora_dim:
|
# if self.lora_dim != lora_dim:
|
||||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||||
# else:
|
# else:
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
|
||||||
@@ -177,7 +180,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
else:
|
else:
|
||||||
# conv2d 3x3
|
# conv2d 3x3
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
weight = weight + self.multiplier * conved * self.scale
|
weight = weight + self.multiplier * conved * self.scale
|
||||||
|
|
||||||
# set weight to org_module
|
# set weight to org_module
|
||||||
@@ -216,7 +219,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
self.region_mask = None
|
self.region_mask = None
|
||||||
|
|
||||||
def default_forward(self, x):
|
def default_forward(self, x):
|
||||||
# print("default_forward", self.lora_name, x.size())
|
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -245,7 +248,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if mask is None:
|
if mask is None:
|
||||||
# raise ValueError(f"mask is None for resolution {area}")
|
# raise ValueError(f"mask is None for resolution {area}")
|
||||||
# emb_layers in SDXL doesn't have mask
|
# emb_layers in SDXL doesn't have mask
|
||||||
# print(f"mask is None for resolution {area}, {x.size()}")
|
# logger.info(f"mask is None for resolution {area}, {x.size()}")
|
||||||
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
||||||
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
||||||
if len(x.size()) != 4:
|
if len(x.size()) != 4:
|
||||||
@@ -262,7 +265,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
# apply mask for LoRA result
|
# apply mask for LoRA result
|
||||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
mask = self.get_mask_for_x(lx)
|
mask = self.get_mask_for_x(lx)
|
||||||
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
# logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}")
|
||||||
lx = lx * mask
|
lx = lx * mask
|
||||||
|
|
||||||
x = self.org_forward(x)
|
x = self.org_forward(x)
|
||||||
@@ -291,7 +294,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if has_real_uncond:
|
if has_real_uncond:
|
||||||
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||||
|
|
||||||
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
|
||||||
return query
|
return query
|
||||||
|
|
||||||
def sub_prompt_forward(self, x):
|
def sub_prompt_forward(self, x):
|
||||||
@@ -306,7 +309,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
lx = x[emb_idx :: self.network.num_sub_prompts]
|
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||||
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||||
|
|
||||||
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
|
||||||
|
|
||||||
x = self.org_forward(x)
|
x = self.org_forward(x)
|
||||||
x[emb_idx :: self.network.num_sub_prompts] += lx
|
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||||
@@ -314,7 +317,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def to_out_forward(self, x):
|
def to_out_forward(self, x):
|
||||||
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
|
||||||
|
|
||||||
if self.network.is_last_network:
|
if self.network.is_last_network:
|
||||||
masks = [None] * self.network.num_sub_prompts
|
masks = [None] * self.network.num_sub_prompts
|
||||||
@@ -332,7 +335,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
)
|
)
|
||||||
self.network.shared[self.lora_name] = (lx, masks)
|
self.network.shared[self.lora_name] = (lx, masks)
|
||||||
|
|
||||||
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
|
||||||
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||||
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||||
|
|
||||||
@@ -351,7 +354,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if has_real_uncond:
|
if has_real_uncond:
|
||||||
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||||
|
|
||||||
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
|
||||||
# if num_sub_prompts > num of LoRAs, fill with zero
|
# if num_sub_prompts > num of LoRAs, fill with zero
|
||||||
for i in range(len(masks)):
|
for i in range(len(masks)):
|
||||||
if masks[i] is None:
|
if masks[i] is None:
|
||||||
@@ -374,7 +377,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
x1 = x1 + lx1
|
x1 = x1 + lx1
|
||||||
out[self.network.batch_size + i] = x1
|
out[self.network.batch_size + i] = x1
|
||||||
|
|
||||||
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -511,7 +514,7 @@ def get_block_dims_and_alphas(
|
|||||||
len(block_dims) == num_total_blocks
|
len(block_dims) == num_total_blocks
|
||||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||||
else:
|
else:
|
||||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||||
block_dims = [network_dim] * num_total_blocks
|
block_dims = [network_dim] * num_total_blocks
|
||||||
|
|
||||||
if block_alphas is not None:
|
if block_alphas is not None:
|
||||||
@@ -520,7 +523,7 @@ def get_block_dims_and_alphas(
|
|||||||
len(block_alphas) == num_total_blocks
|
len(block_alphas) == num_total_blocks
|
||||||
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
||||||
else:
|
else:
|
||||||
print(
|
logger.warning(
|
||||||
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
||||||
)
|
)
|
||||||
block_alphas = [network_alpha] * num_total_blocks
|
block_alphas = [network_alpha] * num_total_blocks
|
||||||
@@ -540,13 +543,13 @@ def get_block_dims_and_alphas(
|
|||||||
else:
|
else:
|
||||||
if conv_alpha is None:
|
if conv_alpha is None:
|
||||||
conv_alpha = 1.0
|
conv_alpha = 1.0
|
||||||
print(
|
logger.warning(
|
||||||
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
||||||
)
|
)
|
||||||
conv_block_alphas = [conv_alpha] * num_total_blocks
|
conv_block_alphas = [conv_alpha] * num_total_blocks
|
||||||
else:
|
else:
|
||||||
if conv_dim is not None:
|
if conv_dim is not None:
|
||||||
print(
|
logger.warning(
|
||||||
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
||||||
)
|
)
|
||||||
conv_block_dims = [conv_dim] * num_total_blocks
|
conv_block_dims = [conv_dim] * num_total_blocks
|
||||||
@@ -586,7 +589,7 @@ def get_block_lr_weight(
|
|||||||
elif name == "zeros":
|
elif name == "zeros":
|
||||||
return [0.0 + base_lr] * max_len
|
return [0.0 + base_lr] * max_len
|
||||||
else:
|
else:
|
||||||
print(
|
logger.error(
|
||||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||||
% (name)
|
% (name)
|
||||||
)
|
)
|
||||||
@@ -598,14 +601,14 @@ def get_block_lr_weight(
|
|||||||
up_lr_weight = get_list(up_lr_weight)
|
up_lr_weight = get_list(up_lr_weight)
|
||||||
|
|
||||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||||
up_lr_weight = up_lr_weight[:max_len]
|
up_lr_weight = up_lr_weight[:max_len]
|
||||||
down_lr_weight = down_lr_weight[:max_len]
|
down_lr_weight = down_lr_weight[:max_len]
|
||||||
|
|
||||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||||
|
|
||||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||||
@@ -613,24 +616,24 @@ def get_block_lr_weight(
|
|||||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||||
|
|
||||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||||
print("apply block learning rate / 階層別学習率を適用します。")
|
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
||||||
if down_lr_weight != None:
|
if down_lr_weight != None:
|
||||||
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
||||||
else:
|
else:
|
||||||
print("down_lr_weight: all 1.0, すべて1.0")
|
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
||||||
|
|
||||||
if mid_lr_weight != None:
|
if mid_lr_weight != None:
|
||||||
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||||
print("mid_lr_weight:", mid_lr_weight)
|
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
||||||
else:
|
else:
|
||||||
print("mid_lr_weight: 1.0")
|
logger.info("mid_lr_weight: 1.0")
|
||||||
|
|
||||||
if up_lr_weight != None:
|
if up_lr_weight != None:
|
||||||
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
||||||
else:
|
else:
|
||||||
print("up_lr_weight: all 1.0, すべて1.0")
|
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
||||||
|
|
||||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||||
|
|
||||||
@@ -711,7 +714,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
elif "lora_down" in key:
|
elif "lora_down" in key:
|
||||||
dim = value.size()[0]
|
dim = value.size()[0]
|
||||||
modules_dim[lora_name] = dim
|
modules_dim[lora_name] = dim
|
||||||
# print(lora_name, value.size(), dim)
|
# logger.info(lora_name, value.size(), dim)
|
||||||
|
|
||||||
# support old LoRA without alpha
|
# support old LoRA without alpha
|
||||||
for key in modules_dim.keys():
|
for key in modules_dim.keys():
|
||||||
@@ -786,20 +789,20 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
print(f"create LoRA network from weights")
|
logger.info(f"create LoRA network from weights")
|
||||||
elif block_dims is not None:
|
elif block_dims is not None:
|
||||||
print(f"create LoRA network from block_dims")
|
logger.info(f"create LoRA network from block_dims")
|
||||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||||
print(f"block_dims: {block_dims}")
|
logger.info(f"block_dims: {block_dims}")
|
||||||
print(f"block_alphas: {block_alphas}")
|
logger.info(f"block_alphas: {block_alphas}")
|
||||||
if conv_block_dims is not None:
|
if conv_block_dims is not None:
|
||||||
print(f"conv_block_dims: {conv_block_dims}")
|
logger.info(f"conv_block_dims: {conv_block_dims}")
|
||||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||||
else:
|
else:
|
||||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||||
if self.conv_lora_dim is not None:
|
if self.conv_lora_dim is not None:
|
||||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(
|
def create_modules(
|
||||||
@@ -884,15 +887,15 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for i, text_encoder in enumerate(text_encoders):
|
for i, text_encoder in enumerate(text_encoders):
|
||||||
if len(text_encoders) > 1:
|
if len(text_encoders) > 1:
|
||||||
index = i + 1
|
index = i + 1
|
||||||
print(f"create LoRA for Text Encoder {index}:")
|
logger.info(f"create LoRA for Text Encoder {index}:")
|
||||||
else:
|
else:
|
||||||
index = None
|
index = None
|
||||||
print(f"create LoRA for Text Encoder:")
|
logger.info(f"create LoRA for Text Encoder:")
|
||||||
|
|
||||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
self.text_encoder_loras.extend(text_encoder_loras)
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
skipped_te += skipped
|
skipped_te += skipped
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
|
||||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||||
@@ -900,15 +903,15 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||||
|
|
||||||
skipped = skipped_te + skipped_un
|
skipped = skipped_te + skipped_un
|
||||||
if varbose and len(skipped) > 0:
|
if varbose and len(skipped) > 0:
|
||||||
print(
|
logger.warning(
|
||||||
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||||
)
|
)
|
||||||
for name in skipped:
|
for name in skipped:
|
||||||
print(f"\t{name}")
|
logger.info(f"\t{name}")
|
||||||
|
|
||||||
self.up_lr_weight: List[float] = None
|
self.up_lr_weight: List[float] = None
|
||||||
self.down_lr_weight: List[float] = None
|
self.down_lr_weight: List[float] = None
|
||||||
@@ -939,12 +942,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.unet_loras = []
|
||||||
|
|
||||||
@@ -966,12 +969,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
apply_unet = True
|
apply_unet = True
|
||||||
|
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.unet_loras = []
|
||||||
|
|
||||||
@@ -982,7 +985,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||||
lora.merge_to(sd_for_lora, dtype, device)
|
lora.merge_to(sd_for_lora, dtype, device)
|
||||||
|
|
||||||
print(f"weights are merged")
|
logger.info(f"weights are merged")
|
||||||
|
|
||||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
||||||
def set_block_lr_weight(
|
def set_block_lr_weight(
|
||||||
@@ -1128,7 +1131,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
device = ref_weight.device
|
device = ref_weight.device
|
||||||
|
|
||||||
def resize_add(mh, mw):
|
def resize_add(mh, mw):
|
||||||
# print(mh, mw, mh * mw)
|
# logger.info(mh, mw, mh * mw)
|
||||||
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||||
m = m.to(device, dtype=dtype)
|
m = m.to(device, dtype=dtype)
|
||||||
mask_dic[mh * mw] = m
|
mask_dic[mh * mw] = m
|
||||||
|
|||||||
@@ -10,7 +10,10 @@ import numpy as np
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTextModel
|
from transformers import CLIPTextModel
|
||||||
import torch
|
import torch
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def make_unet_conversion_map() -> Dict[str, str]:
|
def make_unet_conversion_map() -> Dict[str, str]:
|
||||||
unet_conversion_map_layer = []
|
unet_conversion_map_layer = []
|
||||||
@@ -248,7 +251,7 @@ def create_network_from_weights(
|
|||||||
elif "lora_down" in key:
|
elif "lora_down" in key:
|
||||||
dim = value.size()[0]
|
dim = value.size()[0]
|
||||||
modules_dim[lora_name] = dim
|
modules_dim[lora_name] = dim
|
||||||
# print(lora_name, value.size(), dim)
|
# logger.info(f"{lora_name} {value.size()} {dim}")
|
||||||
|
|
||||||
# support old LoRA without alpha
|
# support old LoRA without alpha
|
||||||
for key in modules_dim.keys():
|
for key in modules_dim.keys():
|
||||||
@@ -291,12 +294,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
|
|
||||||
print(f"create LoRA network from weights")
|
logger.info("create LoRA network from weights")
|
||||||
|
|
||||||
# convert SDXL Stability AI's U-Net modules to Diffusers
|
# convert SDXL Stability AI's U-Net modules to Diffusers
|
||||||
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
||||||
if converted:
|
if converted:
|
||||||
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(
|
def create_modules(
|
||||||
@@ -331,7 +334,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
lora_name = lora_name.replace(".", "_")
|
lora_name = lora_name.replace(".", "_")
|
||||||
|
|
||||||
if lora_name not in modules_dim:
|
if lora_name not in modules_dim:
|
||||||
# print(f"skipped {lora_name} (not found in modules_dim)")
|
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
|
||||||
skipped.append(lora_name)
|
skipped.append(lora_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -362,18 +365,18 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
self.text_encoder_loras.extend(text_encoder_loras)
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
skipped_te += skipped
|
skipped_te += skipped
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
if len(skipped_te) > 0:
|
if len(skipped_te) > 0:
|
||||||
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||||
|
|
||||||
# extend U-Net target modules to include Conv2d 3x3
|
# extend U-Net target modules to include Conv2d 3x3
|
||||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_loras: List[LoRAModule]
|
self.unet_loras: List[LoRAModule]
|
||||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||||
if len(skipped_un) > 0:
|
if len(skipped_un) > 0:
|
||||||
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||||
|
|
||||||
# assertion
|
# assertion
|
||||||
names = set()
|
names = set()
|
||||||
@@ -420,11 +423,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
for lora in self.text_encoder_loras:
|
for lora in self.text_encoder_loras:
|
||||||
lora.apply_to(multiplier)
|
lora.apply_to(multiplier)
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
for lora in self.unet_loras:
|
for lora in self.unet_loras:
|
||||||
lora.apply_to(multiplier)
|
lora.apply_to(multiplier)
|
||||||
|
|
||||||
@@ -433,16 +436,16 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
lora.unapply_to()
|
lora.unapply_to()
|
||||||
|
|
||||||
def merge_to(self, multiplier=1.0):
|
def merge_to(self, multiplier=1.0):
|
||||||
print("merge LoRA weights to original weights")
|
logger.info("merge LoRA weights to original weights")
|
||||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||||
lora.merge_to(multiplier)
|
lora.merge_to(multiplier)
|
||||||
print(f"weights are merged")
|
logger.info(f"weights are merged")
|
||||||
|
|
||||||
def restore_from(self, multiplier=1.0):
|
def restore_from(self, multiplier=1.0):
|
||||||
print("restore LoRA weights from original weights")
|
logger.info("restore LoRA weights from original weights")
|
||||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||||
lora.restore_from(multiplier)
|
lora.restore_from(multiplier)
|
||||||
print(f"weights are restored")
|
logger.info(f"weights are restored")
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||||
@@ -463,7 +466,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
my_state_dict = self.state_dict()
|
my_state_dict = self.state_dict()
|
||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
if state_dict[key].size() != my_state_dict[key].size():
|
if state_dict[key].size() != my_state_dict[key].size():
|
||||||
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
||||||
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
||||||
|
|
||||||
return super().load_state_dict(state_dict, strict)
|
return super().load_state_dict(state_dict, strict)
|
||||||
@@ -490,7 +493,7 @@ if __name__ == "__main__":
|
|||||||
image_prefix = args.model_id.replace("/", "_") + "_"
|
image_prefix = args.model_id.replace("/", "_") + "_"
|
||||||
|
|
||||||
# load Diffusers model
|
# load Diffusers model
|
||||||
print(f"load model from {args.model_id}")
|
logger.info(f"load model from {args.model_id}")
|
||||||
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
||||||
if args.sdxl:
|
if args.sdxl:
|
||||||
# use_safetensors=True does not work with 0.18.2
|
# use_safetensors=True does not work with 0.18.2
|
||||||
@@ -503,7 +506,7 @@ if __name__ == "__main__":
|
|||||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
||||||
|
|
||||||
# load LoRA weights
|
# load LoRA weights
|
||||||
print(f"load LoRA weights from {args.lora_weights}")
|
logger.info(f"load LoRA weights from {args.lora_weights}")
|
||||||
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
@@ -512,10 +515,10 @@ if __name__ == "__main__":
|
|||||||
lora_sd = torch.load(args.lora_weights)
|
lora_sd = torch.load(args.lora_weights)
|
||||||
|
|
||||||
# create by LoRA weights and load weights
|
# create by LoRA weights and load weights
|
||||||
print(f"create LoRA network")
|
logger.info(f"create LoRA network")
|
||||||
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
||||||
|
|
||||||
print(f"load LoRA network weights")
|
logger.info(f"load LoRA network weights")
|
||||||
lora_network.load_state_dict(lora_sd)
|
lora_network.load_state_dict(lora_sd)
|
||||||
|
|
||||||
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
||||||
@@ -544,34 +547,34 @@ if __name__ == "__main__":
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
# create image with original weights
|
# create image with original weights
|
||||||
print(f"create image with original weights")
|
logger.info(f"create image with original weights")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "original.png")
|
image.save(image_prefix + "original.png")
|
||||||
|
|
||||||
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
||||||
print(f"apply LoRA network to the model")
|
logger.info(f"apply LoRA network to the model")
|
||||||
lora_network.apply_to(multiplier=1.0)
|
lora_network.apply_to(multiplier=1.0)
|
||||||
|
|
||||||
print(f"create image with applied LoRA")
|
logger.info(f"create image with applied LoRA")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "applied_lora.png")
|
image.save(image_prefix + "applied_lora.png")
|
||||||
|
|
||||||
# unapply LoRA network to the model
|
# unapply LoRA network to the model
|
||||||
print(f"unapply LoRA network to the model")
|
logger.info(f"unapply LoRA network to the model")
|
||||||
lora_network.unapply_to()
|
lora_network.unapply_to()
|
||||||
|
|
||||||
print(f"create image with unapplied LoRA")
|
logger.info(f"create image with unapplied LoRA")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "unapplied_lora.png")
|
image.save(image_prefix + "unapplied_lora.png")
|
||||||
|
|
||||||
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
||||||
print(f"merge LoRA network to the model")
|
logger.info(f"merge LoRA network to the model")
|
||||||
lora_network.merge_to(multiplier=1.0)
|
lora_network.merge_to(multiplier=1.0)
|
||||||
|
|
||||||
print(f"create image with LoRA")
|
logger.info(f"create image with LoRA")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "merged_lora.png")
|
image.save(image_prefix + "merged_lora.png")
|
||||||
@@ -579,31 +582,31 @@ if __name__ == "__main__":
|
|||||||
# restore (unmerge) LoRA weights: numerically unstable
|
# restore (unmerge) LoRA weights: numerically unstable
|
||||||
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
|
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
|
||||||
# 保存したstate_dictから元の重みを復元するのが確実
|
# 保存したstate_dictから元の重みを復元するのが確実
|
||||||
print(f"restore (unmerge) LoRA weights")
|
logger.info(f"restore (unmerge) LoRA weights")
|
||||||
lora_network.restore_from(multiplier=1.0)
|
lora_network.restore_from(multiplier=1.0)
|
||||||
|
|
||||||
print(f"create image without LoRA")
|
logger.info(f"create image without LoRA")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "unmerged_lora.png")
|
image.save(image_prefix + "unmerged_lora.png")
|
||||||
|
|
||||||
# restore original weights
|
# restore original weights
|
||||||
print(f"restore original weights")
|
logger.info(f"restore original weights")
|
||||||
pipe.unet.load_state_dict(org_unet_sd)
|
pipe.unet.load_state_dict(org_unet_sd)
|
||||||
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
||||||
if args.sdxl:
|
if args.sdxl:
|
||||||
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
||||||
|
|
||||||
print(f"create image with restored original weights")
|
logger.info(f"create image with restored original weights")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "restore_original.png")
|
image.save(image_prefix + "restore_original.png")
|
||||||
|
|
||||||
# use convenience function to merge LoRA weights
|
# use convenience function to merge LoRA weights
|
||||||
print(f"merge LoRA weights with convenience function")
|
logger.info(f"merge LoRA weights with convenience function")
|
||||||
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||||
|
|
||||||
print(f"create image with merged LoRA weights")
|
logger.info(f"create image with merged LoRA weights")
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||||
image.save(image_prefix + "convenience_merged_lora.png")
|
image.save(image_prefix + "convenience_merged_lora.png")
|
||||||
|
|||||||
@@ -14,7 +14,10 @@ from transformers import CLIPTextModel
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
@@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module):
|
|||||||
# if limit_rank:
|
# if limit_rank:
|
||||||
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
||||||
# if self.lora_dim != lora_dim:
|
# if self.lora_dim != lora_dim:
|
||||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||||
# else:
|
# else:
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
|
||||||
@@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
else:
|
else:
|
||||||
# conv2d 3x3
|
# conv2d 3x3
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
weight = weight + self.multiplier * conved * self.scale
|
weight = weight + self.multiplier * conved * self.scale
|
||||||
|
|
||||||
# set weight to org_module
|
# set weight to org_module
|
||||||
@@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
self.region_mask = None
|
self.region_mask = None
|
||||||
|
|
||||||
def default_forward(self, x):
|
def default_forward(self, x):
|
||||||
# print("default_forward", self.lora_name, x.size())
|
# logger.info("default_forward", self.lora_name, x.size())
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
# apply mask for LoRA result
|
# apply mask for LoRA result
|
||||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
mask = self.get_mask_for_x(lx)
|
mask = self.get_mask_for_x(lx)
|
||||||
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
# logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||||
lx = lx * mask
|
lx = lx * mask
|
||||||
|
|
||||||
x = self.org_forward(x)
|
x = self.org_forward(x)
|
||||||
@@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if has_real_uncond:
|
if has_real_uncond:
|
||||||
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||||
|
|
||||||
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
# logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||||
return query
|
return query
|
||||||
|
|
||||||
def sub_prompt_forward(self, x):
|
def sub_prompt_forward(self, x):
|
||||||
@@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
lx = x[emb_idx :: self.network.num_sub_prompts]
|
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||||
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||||
|
|
||||||
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
# logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||||
|
|
||||||
x = self.org_forward(x)
|
x = self.org_forward(x)
|
||||||
x[emb_idx :: self.network.num_sub_prompts] += lx
|
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||||
@@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def to_out_forward(self, x):
|
def to_out_forward(self, x):
|
||||||
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
# logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||||
|
|
||||||
if self.network.is_last_network:
|
if self.network.is_last_network:
|
||||||
masks = [None] * self.network.num_sub_prompts
|
masks = [None] * self.network.num_sub_prompts
|
||||||
@@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
)
|
)
|
||||||
self.network.shared[self.lora_name] = (lx, masks)
|
self.network.shared[self.lora_name] = (lx, masks)
|
||||||
|
|
||||||
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
# logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||||
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||||
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||||
|
|
||||||
@@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if has_real_uncond:
|
if has_real_uncond:
|
||||||
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||||
|
|
||||||
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
# logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||||
# for i in range(len(masks)):
|
# for i in range(len(masks)):
|
||||||
# if masks[i] is None:
|
# if masks[i] is None:
|
||||||
# masks[i] = torch.zeros_like(masks[-1])
|
# masks[i] = torch.zeros_like(masks[-1])
|
||||||
@@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule):
|
|||||||
x1 = x1 + lx1
|
x1 = x1 + lx1
|
||||||
out[self.network.batch_size + i] = x1
|
out[self.network.batch_size + i] = x1
|
||||||
|
|
||||||
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
# logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -526,7 +529,7 @@ def get_block_dims_and_alphas(
|
|||||||
len(block_dims) == num_total_blocks
|
len(block_dims) == num_total_blocks
|
||||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||||
else:
|
else:
|
||||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||||
block_dims = [network_dim] * num_total_blocks
|
block_dims = [network_dim] * num_total_blocks
|
||||||
|
|
||||||
if block_alphas is not None:
|
if block_alphas is not None:
|
||||||
@@ -535,7 +538,7 @@ def get_block_dims_and_alphas(
|
|||||||
len(block_alphas) == num_total_blocks
|
len(block_alphas) == num_total_blocks
|
||||||
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
||||||
else:
|
else:
|
||||||
print(
|
logger.warning(
|
||||||
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
||||||
)
|
)
|
||||||
block_alphas = [network_alpha] * num_total_blocks
|
block_alphas = [network_alpha] * num_total_blocks
|
||||||
@@ -555,13 +558,13 @@ def get_block_dims_and_alphas(
|
|||||||
else:
|
else:
|
||||||
if conv_alpha is None:
|
if conv_alpha is None:
|
||||||
conv_alpha = 1.0
|
conv_alpha = 1.0
|
||||||
print(
|
logger.warning(
|
||||||
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
||||||
)
|
)
|
||||||
conv_block_alphas = [conv_alpha] * num_total_blocks
|
conv_block_alphas = [conv_alpha] * num_total_blocks
|
||||||
else:
|
else:
|
||||||
if conv_dim is not None:
|
if conv_dim is not None:
|
||||||
print(
|
logger.warning(
|
||||||
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
||||||
)
|
)
|
||||||
conv_block_dims = [conv_dim] * num_total_blocks
|
conv_block_dims = [conv_dim] * num_total_blocks
|
||||||
@@ -601,7 +604,7 @@ def get_block_lr_weight(
|
|||||||
elif name == "zeros":
|
elif name == "zeros":
|
||||||
return [0.0 + base_lr] * max_len
|
return [0.0 + base_lr] * max_len
|
||||||
else:
|
else:
|
||||||
print(
|
logger.error(
|
||||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||||
% (name)
|
% (name)
|
||||||
)
|
)
|
||||||
@@ -613,14 +616,14 @@ def get_block_lr_weight(
|
|||||||
up_lr_weight = get_list(up_lr_weight)
|
up_lr_weight = get_list(up_lr_weight)
|
||||||
|
|
||||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||||
up_lr_weight = up_lr_weight[:max_len]
|
up_lr_weight = up_lr_weight[:max_len]
|
||||||
down_lr_weight = down_lr_weight[:max_len]
|
down_lr_weight = down_lr_weight[:max_len]
|
||||||
|
|
||||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||||
|
|
||||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||||
@@ -628,24 +631,24 @@ def get_block_lr_weight(
|
|||||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||||
|
|
||||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||||
print("apply block learning rate / 階層別学習率を適用します。")
|
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
||||||
if down_lr_weight != None:
|
if down_lr_weight != None:
|
||||||
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
||||||
else:
|
else:
|
||||||
print("down_lr_weight: all 1.0, すべて1.0")
|
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
||||||
|
|
||||||
if mid_lr_weight != None:
|
if mid_lr_weight != None:
|
||||||
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||||
print("mid_lr_weight:", mid_lr_weight)
|
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
||||||
else:
|
else:
|
||||||
print("mid_lr_weight: 1.0")
|
logger.info("mid_lr_weight: 1.0")
|
||||||
|
|
||||||
if up_lr_weight != None:
|
if up_lr_weight != None:
|
||||||
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
||||||
else:
|
else:
|
||||||
print("up_lr_weight: all 1.0, すべて1.0")
|
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
||||||
|
|
||||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||||
|
|
||||||
@@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
elif "lora_down" in key:
|
elif "lora_down" in key:
|
||||||
dim = value.size()[0]
|
dim = value.size()[0]
|
||||||
modules_dim[lora_name] = dim
|
modules_dim[lora_name] = dim
|
||||||
# print(lora_name, value.size(), dim)
|
# logger.info(lora_name, value.size(), dim)
|
||||||
|
|
||||||
# support old LoRA without alpha
|
# support old LoRA without alpha
|
||||||
for key in modules_dim.keys():
|
for key in modules_dim.keys():
|
||||||
@@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
print(f"create LoRA network from weights")
|
logger.info(f"create LoRA network from weights")
|
||||||
elif block_dims is not None:
|
elif block_dims is not None:
|
||||||
print(f"create LoRA network from block_dims")
|
logger.info(f"create LoRA network from block_dims")
|
||||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||||
print(f"block_dims: {block_dims}")
|
logger.info(f"block_dims: {block_dims}")
|
||||||
print(f"block_alphas: {block_alphas}")
|
logger.info(f"block_alphas: {block_alphas}")
|
||||||
if conv_block_dims is not None:
|
if conv_block_dims is not None:
|
||||||
print(f"conv_block_dims: {conv_block_dims}")
|
logger.info(f"conv_block_dims: {conv_block_dims}")
|
||||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||||
else:
|
else:
|
||||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||||
if self.conv_lora_dim is not None:
|
if self.conv_lora_dim is not None:
|
||||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(
|
def create_modules(
|
||||||
@@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for i, text_encoder in enumerate(text_encoders):
|
for i, text_encoder in enumerate(text_encoders):
|
||||||
if len(text_encoders) > 1:
|
if len(text_encoders) > 1:
|
||||||
index = i + 1
|
index = i + 1
|
||||||
print(f"create LoRA for Text Encoder {index}:")
|
logger.info(f"create LoRA for Text Encoder {index}:")
|
||||||
else:
|
else:
|
||||||
index = None
|
index = None
|
||||||
print(f"create LoRA for Text Encoder:")
|
logger.info(f"create LoRA for Text Encoder:")
|
||||||
|
|
||||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
self.text_encoder_loras.extend(text_encoder_loras)
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
skipped_te += skipped
|
skipped_te += skipped
|
||||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||||
|
|
||||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||||
@@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||||
|
|
||||||
skipped = skipped_te + skipped_un
|
skipped = skipped_te + skipped_un
|
||||||
if varbose and len(skipped) > 0:
|
if varbose and len(skipped) > 0:
|
||||||
print(
|
logger.warning(
|
||||||
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||||
)
|
)
|
||||||
for name in skipped:
|
for name in skipped:
|
||||||
print(f"\t{name}")
|
logger.info(f"\t{name}")
|
||||||
|
|
||||||
self.up_lr_weight: List[float] = None
|
self.up_lr_weight: List[float] = None
|
||||||
self.down_lr_weight: List[float] = None
|
self.down_lr_weight: List[float] = None
|
||||||
@@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.unet_loras = []
|
||||||
|
|
||||||
@@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
apply_unet = True
|
apply_unet = True
|
||||||
|
|
||||||
if apply_text_encoder:
|
if apply_text_encoder:
|
||||||
print("enable LoRA for text encoder")
|
logger.info("enable LoRA for text encoder")
|
||||||
else:
|
else:
|
||||||
self.text_encoder_loras = []
|
self.text_encoder_loras = []
|
||||||
|
|
||||||
if apply_unet:
|
if apply_unet:
|
||||||
print("enable LoRA for U-Net")
|
logger.info("enable LoRA for U-Net")
|
||||||
else:
|
else:
|
||||||
self.unet_loras = []
|
self.unet_loras = []
|
||||||
|
|
||||||
@@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||||
lora.merge_to(sd_for_lora, dtype, device)
|
lora.merge_to(sd_for_lora, dtype, device)
|
||||||
|
|
||||||
print(f"weights are merged")
|
logger.info(f"weights are merged")
|
||||||
|
|
||||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
||||||
def set_block_lr_weight(
|
def set_block_lr_weight(
|
||||||
@@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
device = ref_weight.device
|
device = ref_weight.device
|
||||||
|
|
||||||
def resize_add(mh, mw):
|
def resize_add(mh, mw):
|
||||||
# print(mh, mw, mh * mw)
|
# logger.info(mh, mw, mh * mw)
|
||||||
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||||
m = m.to(device, dtype=dtype)
|
m = m.to(device, dtype=dtype)
|
||||||
mask_dic[mh * mw] = m
|
mask_dic[mh * mw] = m
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ import torch
|
|||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||||
@@ -20,12 +24,12 @@ def interrogate(args):
|
|||||||
weights_dtype = torch.float16
|
weights_dtype = torch.float16
|
||||||
|
|
||||||
# いろいろ準備する
|
# いろいろ準備する
|
||||||
print(f"loading SD model: {args.sd_model}")
|
logger.info(f"loading SD model: {args.sd_model}")
|
||||||
args.pretrained_model_name_or_path = args.sd_model
|
args.pretrained_model_name_or_path = args.sd_model
|
||||||
args.vae = None
|
args.vae = None
|
||||||
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
||||||
|
|
||||||
print(f"loading LoRA: {args.model}")
|
logger.info(f"loading LoRA: {args.model}")
|
||||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||||
|
|
||||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||||
@@ -35,11 +39,11 @@ def interrogate(args):
|
|||||||
has_te_weight = True
|
has_te_weight = True
|
||||||
break
|
break
|
||||||
if not has_te_weight:
|
if not has_te_weight:
|
||||||
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
||||||
return
|
return
|
||||||
del vae
|
del vae
|
||||||
|
|
||||||
print("loading tokenizer")
|
logger.info("loading tokenizer")
|
||||||
if args.v2:
|
if args.v2:
|
||||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||||||
else:
|
else:
|
||||||
@@ -53,7 +57,7 @@ def interrogate(args):
|
|||||||
# トークンをひとつひとつ当たっていく
|
# トークンをひとつひとつ当たっていく
|
||||||
token_id_start = 0
|
token_id_start = 0
|
||||||
token_id_end = max(tokenizer.all_special_ids)
|
token_id_end = max(tokenizer.all_special_ids)
|
||||||
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
||||||
|
|
||||||
def get_all_embeddings(text_encoder):
|
def get_all_embeddings(text_encoder):
|
||||||
embs = []
|
embs = []
|
||||||
@@ -79,24 +83,24 @@ def interrogate(args):
|
|||||||
embs.extend(encoder_hidden_states)
|
embs.extend(encoder_hidden_states)
|
||||||
return torch.stack(embs)
|
return torch.stack(embs)
|
||||||
|
|
||||||
print("get original text encoder embeddings.")
|
logger.info("get original text encoder embeddings.")
|
||||||
orig_embs = get_all_embeddings(text_encoder)
|
orig_embs = get_all_embeddings(text_encoder)
|
||||||
|
|
||||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||||
info = network.load_state_dict(weights_sd, strict=False)
|
info = network.load_state_dict(weights_sd, strict=False)
|
||||||
print(f"Loading LoRA weights: {info}")
|
logger.info(f"Loading LoRA weights: {info}")
|
||||||
|
|
||||||
network.to(DEVICE, dtype=weights_dtype)
|
network.to(DEVICE, dtype=weights_dtype)
|
||||||
network.eval()
|
network.eval()
|
||||||
|
|
||||||
del unet
|
del unet
|
||||||
|
|
||||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||||
print("get text encoder embeddings with lora.")
|
logger.info("get text encoder embeddings with lora.")
|
||||||
lora_embs = get_all_embeddings(text_encoder)
|
lora_embs = get_all_embeddings(text_encoder)
|
||||||
|
|
||||||
# 比べる:とりあえず単純に差分の絶対値で
|
# 比べる:とりあえず単純に差分の絶対値で
|
||||||
print("comparing...")
|
logger.info("comparing...")
|
||||||
diffs = {}
|
diffs = {}
|
||||||
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
||||||
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file
|
|||||||
from library import sai_model_spec, train_util
|
from library import sai_model_spec, train_util
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
@@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
name_to_module[lora_name] = child_module
|
name_to_module[lora_name] = child_module
|
||||||
|
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
@@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
# find original module for this lora
|
# find original module for this lora
|
||||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||||
if module_name not in name_to_module:
|
if module_name not in name_to_module:
|
||||||
print(f"no module found for LoRA weight: {key}")
|
logger.info(f"no module found for LoRA weight: {key}")
|
||||||
continue
|
continue
|
||||||
module = name_to_module[module_name]
|
module = name_to_module[module_name]
|
||||||
# print(f"apply {key} to {module}")
|
# logger.info(f"apply {key} to {module}")
|
||||||
|
|
||||||
down_weight = lora_sd[key]
|
down_weight = lora_sd[key]
|
||||||
up_weight = lora_sd[up_key]
|
up_weight = lora_sd[up_key]
|
||||||
@@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
else:
|
else:
|
||||||
# conv2d 3x3
|
# conv2d 3x3
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
weight = weight + ratio * conved * scale
|
weight = weight + ratio * conved * scale
|
||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
module.weight = torch.nn.Parameter(weight)
|
||||||
@@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
v2 = None
|
v2 = None
|
||||||
base_model = None
|
base_model = None
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
if lora_metadata is not None:
|
if lora_metadata is not None:
|
||||||
@@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
if lora_module_name not in base_alphas:
|
if lora_module_name not in base_alphas:
|
||||||
base_alphas[lora_module_name] = alpha
|
base_alphas[lora_module_name] = alpha
|
||||||
|
|
||||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if "alpha" in key:
|
if "alpha" in key:
|
||||||
continue
|
continue
|
||||||
@@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||||
|
|
||||||
print("merged model")
|
logger.info("merged model")
|
||||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||||
|
|
||||||
# check all dims are same
|
# check all dims are same
|
||||||
dims_list = list(set(base_dims.values()))
|
dims_list = list(set(base_dims.values()))
|
||||||
@@ -239,7 +242,7 @@ def merge(args):
|
|||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
if args.sd_model is not None:
|
if args.sd_model is not None:
|
||||||
print(f"loading SD model: {args.sd_model}")
|
logger.info(f"loading SD model: {args.sd_model}")
|
||||||
|
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||||
|
|
||||||
@@ -264,18 +267,18 @@ def merge(args):
|
|||||||
)
|
)
|
||||||
if args.v2:
|
if args.v2:
|
||||||
# TODO read sai modelspec
|
# TODO read sai modelspec
|
||||||
print(
|
logger.warning(
|
||||||
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"saving SD model to: {args.save_to}")
|
logger.info(f"saving SD model to: {args.save_to}")
|
||||||
model_util.save_stable_diffusion_checkpoint(
|
model_util.save_stable_diffusion_checkpoint(
|
||||||
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
|
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||||
|
|
||||||
print(f"calculating hashes and creating metadata...")
|
logger.info(f"calculating hashes and creating metadata...")
|
||||||
|
|
||||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
metadata["sshs_model_hash"] = model_hash
|
metadata["sshs_model_hash"] = model_hash
|
||||||
@@ -289,12 +292,12 @@ def merge(args):
|
|||||||
)
|
)
|
||||||
if v2:
|
if v2:
|
||||||
# TODO read sai modelspec
|
# TODO read sai modelspec
|
||||||
print(
|
logger.warning(
|
||||||
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||||
)
|
)
|
||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
logger.info(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import torch
|
|||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||||
@@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
name_to_module[lora_name] = child_module
|
name_to_module[lora_name] = child_module
|
||||||
|
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd = load_state_dict(model, merge_dtype)
|
lora_sd = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
|||||||
# find original module for this lora
|
# find original module for this lora
|
||||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||||
if module_name not in name_to_module:
|
if module_name not in name_to_module:
|
||||||
print(f"no module found for LoRA weight: {key}")
|
logger.info(f"no module found for LoRA weight: {key}")
|
||||||
continue
|
continue
|
||||||
module = name_to_module[module_name]
|
module = name_to_module[module_name]
|
||||||
# print(f"apply {key} to {module}")
|
# logger.info(f"apply {key} to {module}")
|
||||||
|
|
||||||
down_weight = lora_sd[key]
|
down_weight = lora_sd[key]
|
||||||
up_weight = lora_sd[up_key]
|
up_weight = lora_sd[up_key]
|
||||||
@@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|||||||
alpha = None
|
alpha = None
|
||||||
dim = None
|
dim = None
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd = load_state_dict(model, merge_dtype)
|
lora_sd = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in lora_sd.keys():
|
for key in lora_sd.keys():
|
||||||
if 'alpha' in key:
|
if 'alpha' in key:
|
||||||
if key in merged_sd:
|
if key in merged_sd:
|
||||||
@@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
|||||||
dim = lora_sd[key].size()[0]
|
dim = lora_sd[key].size()[0]
|
||||||
merged_sd[key] = lora_sd[key] * ratio
|
merged_sd[key] = lora_sd[key] * ratio
|
||||||
|
|
||||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
logger.info(f"dim (rank): {dim}, alpha: {alpha}")
|
||||||
if alpha is None:
|
if alpha is None:
|
||||||
alpha = dim
|
alpha = dim
|
||||||
|
|
||||||
@@ -142,19 +145,21 @@ def merge(args):
|
|||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
if args.sd_model is not None:
|
if args.sd_model is not None:
|
||||||
print(f"loading SD model: {args.sd_model}")
|
logger.info(f"loading SD model: {args.sd_model}")
|
||||||
|
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||||
|
|
||||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||||
|
|
||||||
print(f"\nsaving SD model to: {args.save_to}")
|
logger.info("")
|
||||||
|
logger.info(f"saving SD model to: {args.save_to}")
|
||||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||||
args.sd_model, 0, 0, save_dtype, vae)
|
args.sd_model, 0, 0, save_dtype, vae)
|
||||||
else:
|
else:
|
||||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||||
|
|
||||||
print(f"\nsaving model to: {args.save_to}")
|
logger.info(f"")
|
||||||
|
logger.info(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ from transformers import CLIPTextModel
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
@@ -237,7 +240,7 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|
||||||
print(
|
logger.info(
|
||||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -258,7 +261,7 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
|
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
|
||||||
oft_name = prefix + "." + name + "." + child_name
|
oft_name = prefix + "." + name + "." + child_name
|
||||||
oft_name = oft_name.replace(".", "_")
|
oft_name = oft_name.replace(".", "_")
|
||||||
# print(oft_name)
|
# logger.info(oft_name)
|
||||||
|
|
||||||
oft = module_class(
|
oft = module_class(
|
||||||
oft_name,
|
oft_name,
|
||||||
@@ -279,7 +282,7 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
|
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
|
||||||
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
||||||
|
|
||||||
# assertion
|
# assertion
|
||||||
names = set()
|
names = set()
|
||||||
@@ -316,7 +319,7 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
|
|
||||||
# TODO refactor to common function with apply_to
|
# TODO refactor to common function with apply_to
|
||||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||||
print("enable OFT for U-Net")
|
logger.info("enable OFT for U-Net")
|
||||||
|
|
||||||
for oft in self.unet_ofts:
|
for oft in self.unet_ofts:
|
||||||
sd_for_lora = {}
|
sd_for_lora = {}
|
||||||
@@ -326,7 +329,7 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
oft.load_state_dict(sd_for_lora, False)
|
oft.load_state_dict(sd_for_lora, False)
|
||||||
oft.merge_to()
|
oft.merge_to()
|
||||||
|
|
||||||
print(f"weights are merged")
|
logger.info(f"weights are merged")
|
||||||
|
|
||||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||||
@@ -338,11 +341,11 @@ class OFTNetwork(torch.nn.Module):
|
|||||||
for oft in ofts:
|
for oft in ofts:
|
||||||
params.extend(oft.parameters())
|
params.extend(oft.parameters())
|
||||||
|
|
||||||
# print num of params
|
# logger.info num of params
|
||||||
num_params = 0
|
num_params = 0
|
||||||
for p in params:
|
for p in params:
|
||||||
num_params += p.numel()
|
num_params += p.numel()
|
||||||
print(f"OFT params: {num_params}")
|
logger.info(f"OFT params: {num_params}")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
param_data = {"params": enumerate_params(self.unet_ofts)}
|
param_data = {"params": enumerate_params(self.unet_ofts)}
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ from safetensors.torch import load_file, save_file, safe_open
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from library import train_util, model_util
|
from library import train_util, model_util
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MIN_SV = 1e-6
|
MIN_SV = 1e-6
|
||||||
|
|
||||||
@@ -206,7 +210,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
scale = network_alpha/network_dim
|
scale = network_alpha/network_dim
|
||||||
|
|
||||||
if dynamic_method:
|
if dynamic_method:
|
||||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||||
|
|
||||||
lora_down_weight = None
|
lora_down_weight = None
|
||||||
lora_up_weight = None
|
lora_up_weight = None
|
||||||
@@ -275,10 +279,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
del param_dict
|
del param_dict
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(verbose_str)
|
logger.info(verbose_str)
|
||||||
|
|
||||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||||
print("resizing complete")
|
logger.info("resizing complete")
|
||||||
return o_lora_sd, network_dim, new_alpha
|
return o_lora_sd, network_dim, new_alpha
|
||||||
|
|
||||||
|
|
||||||
@@ -304,10 +308,10 @@ def resize(args):
|
|||||||
if save_dtype is None:
|
if save_dtype is None:
|
||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
print("loading Model...")
|
logger.info("loading Model...")
|
||||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||||
|
|
||||||
print("Resizing Lora...")
|
logger.info("Resizing Lora...")
|
||||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||||
|
|
||||||
# update metadata
|
# update metadata
|
||||||
@@ -329,7 +333,7 @@ def resize(args):
|
|||||||
metadata["sshs_model_hash"] = model_hash
|
metadata["sshs_model_hash"] = model_hash
|
||||||
metadata["sshs_legacy_hash"] = legacy_hash
|
metadata["sshs_legacy_hash"] = legacy_hash
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
logger.info(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ from tqdm import tqdm
|
|||||||
from library import sai_model_spec, sdxl_model_util, train_util
|
from library import sai_model_spec, sdxl_model_util, train_util
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import lora
|
import lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def load_state_dict(file_name, dtype):
|
def load_state_dict(file_name, dtype):
|
||||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||||
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
name_to_module[lora_name] = child_module
|
name_to_module[lora_name] = child_module
|
||||||
|
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
if "lora_down" in key:
|
if "lora_down" in key:
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
@@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
# find original module for this lora
|
# find original module for this lora
|
||||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||||
if module_name not in name_to_module:
|
if module_name not in name_to_module:
|
||||||
print(f"no module found for LoRA weight: {key}")
|
logger.info(f"no module found for LoRA weight: {key}")
|
||||||
continue
|
continue
|
||||||
module = name_to_module[module_name]
|
module = name_to_module[module_name]
|
||||||
# print(f"apply {key} to {module}")
|
# logger.info(f"apply {key} to {module}")
|
||||||
|
|
||||||
down_weight = lora_sd[key]
|
down_weight = lora_sd[key]
|
||||||
up_weight = lora_sd[up_key]
|
up_weight = lora_sd[up_key]
|
||||||
@@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
|
|
||||||
# W <- W + U * D
|
# W <- W + U * D
|
||||||
weight = module.weight
|
weight = module.weight
|
||||||
# print(module_name, down_weight.size(), up_weight.size())
|
# logger.info(module_name, down_weight.size(), up_weight.size())
|
||||||
if len(weight.size()) == 2:
|
if len(weight.size()) == 2:
|
||||||
# linear
|
# linear
|
||||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||||
@@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
else:
|
else:
|
||||||
# conv2d 3x3
|
# conv2d 3x3
|
||||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||||
weight = weight + ratio * conved * scale
|
weight = weight + ratio * conved * scale
|
||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
module.weight = torch.nn.Parameter(weight)
|
||||||
@@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
v2 = None
|
v2 = None
|
||||||
base_model = None
|
base_model = None
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
if lora_metadata is not None:
|
if lora_metadata is not None:
|
||||||
@@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
if lora_module_name not in base_alphas:
|
if lora_module_name not in base_alphas:
|
||||||
base_alphas[lora_module_name] = alpha
|
base_alphas[lora_module_name] = alpha
|
||||||
|
|
||||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
if "alpha" in key:
|
if "alpha" in key:
|
||||||
continue
|
continue
|
||||||
@@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
|||||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||||
|
|
||||||
print("merged model")
|
logger.info("merged model")
|
||||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||||
|
|
||||||
# check all dims are same
|
# check all dims are same
|
||||||
dims_list = list(set(base_dims.values()))
|
dims_list = list(set(base_dims.values()))
|
||||||
@@ -243,7 +246,7 @@ def merge(args):
|
|||||||
save_dtype = merge_dtype
|
save_dtype = merge_dtype
|
||||||
|
|
||||||
if args.sd_model is not None:
|
if args.sd_model is not None:
|
||||||
print(f"loading SD model: {args.sd_model}")
|
logger.info(f"loading SD model: {args.sd_model}")
|
||||||
|
|
||||||
(
|
(
|
||||||
text_model1,
|
text_model1,
|
||||||
@@ -265,14 +268,14 @@ def merge(args):
|
|||||||
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
|
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"saving SD model to: {args.save_to}")
|
logger.info(f"saving SD model to: {args.save_to}")
|
||||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||||
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||||
|
|
||||||
print(f"calculating hashes and creating metadata...")
|
logger.info(f"calculating hashes and creating metadata...")
|
||||||
|
|
||||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
metadata["sshs_model_hash"] = model_hash
|
metadata["sshs_model_hash"] = model_hash
|
||||||
@@ -286,7 +289,7 @@ def merge(args):
|
|||||||
)
|
)
|
||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
logger.info(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,12 @@ import torch
|
|||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from library import sai_model_spec, train_util
|
from library import sai_model_spec, train_util
|
||||||
|
import library.model_util as model_util
|
||||||
|
import lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
@@ -38,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata):
|
|||||||
|
|
||||||
|
|
||||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||||
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
v2 = None
|
v2 = None
|
||||||
base_model = None
|
base_model = None
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
print(f"loading: {model}")
|
logger.info(f"loading: {model}")
|
||||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||||
|
|
||||||
if lora_metadata is not None:
|
if lora_metadata is not None:
|
||||||
@@ -53,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||||
|
|
||||||
# merge
|
# merge
|
||||||
print(f"merging...")
|
logger.info(f"merging...")
|
||||||
for key in tqdm(list(lora_sd.keys())):
|
for key in tqdm(list(lora_sd.keys())):
|
||||||
if "lora_down" not in key:
|
if "lora_down" not in key:
|
||||||
continue
|
continue
|
||||||
@@ -70,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
out_dim = up_weight.size()[0]
|
out_dim = up_weight.size()[0]
|
||||||
conv2d = len(down_weight.size()) == 4
|
conv2d = len(down_weight.size()) == 4
|
||||||
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
||||||
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
# logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||||
|
|
||||||
# make original weight if not exist
|
# make original weight if not exist
|
||||||
if lora_module_name not in merged_sd:
|
if lora_module_name not in merged_sd:
|
||||||
@@ -107,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
|||||||
merged_sd[lora_module_name] = weight
|
merged_sd[lora_module_name] = weight
|
||||||
|
|
||||||
# extract from merged weights
|
# extract from merged weights
|
||||||
print("extract new lora...")
|
logger.info("extract new lora...")
|
||||||
merged_lora_sd = {}
|
merged_lora_sd = {}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||||
@@ -185,7 +190,7 @@ def merge(args):
|
|||||||
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
|
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"calculating hashes and creating metadata...")
|
logger.info(f"calculating hashes and creating metadata...")
|
||||||
|
|
||||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
metadata["sshs_model_hash"] = model_hash
|
metadata["sshs_model_hash"] = model_hash
|
||||||
@@ -200,12 +205,12 @@ def merge(args):
|
|||||||
)
|
)
|
||||||
if v2:
|
if v2:
|
||||||
# TODO read sai modelspec
|
# TODO read sai modelspec
|
||||||
print(
|
logger.warning(
|
||||||
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||||
)
|
)
|
||||||
metadata.update(sai_metadata)
|
metadata.update(sai_metadata)
|
||||||
|
|
||||||
print(f"saving model to: {args.save_to}")
|
logger.info(f"saving model to: {args.save_to}")
|
||||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,5 +29,7 @@ huggingface-hub==0.20.1
|
|||||||
# protobuf==3.20.3
|
# protobuf==3.20.3
|
||||||
# open clip for SDXL
|
# open clip for SDXL
|
||||||
open-clip-torch==2.20.0
|
open-clip-torch==2.20.0
|
||||||
|
# For logging
|
||||||
|
rich==13.7.0
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
-e .
|
-e .
|
||||||
|
|||||||
167
sdxl_gen_img.py
167
sdxl_gen_img.py
@@ -55,6 +55,10 @@ from networks.lora import LoRANetwork
|
|||||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
from networks.control_net_lllite import ControlNetLLLite
|
from networks.control_net_lllite import ControlNetLLLite
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
@@ -76,12 +80,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|||||||
|
|
||||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||||
if mem_eff_attn:
|
if mem_eff_attn:
|
||||||
print("Enable memory efficient attention for U-Net")
|
logger.info("Enable memory efficient attention for U-Net")
|
||||||
|
|
||||||
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
|
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
|
||||||
unet.set_use_memory_efficient_attention(False, True)
|
unet.set_use_memory_efficient_attention(False, True)
|
||||||
elif xformers:
|
elif xformers:
|
||||||
print("Enable xformers for U-Net")
|
logger.info("Enable xformers for U-Net")
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -89,7 +93,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
|||||||
|
|
||||||
unet.set_use_memory_efficient_attention(True, False)
|
unet.set_use_memory_efficient_attention(True, False)
|
||||||
elif sdpa:
|
elif sdpa:
|
||||||
print("Enable SDPA for U-Net")
|
logger.info("Enable SDPA for U-Net")
|
||||||
unet.set_use_memory_efficient_attention(False, False)
|
unet.set_use_memory_efficient_attention(False, False)
|
||||||
unet.set_use_sdpa(True)
|
unet.set_use_sdpa(True)
|
||||||
|
|
||||||
@@ -106,7 +110,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_memory_efficient():
|
def replace_vae_attn_to_memory_efficient():
|
||||||
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
||||||
flash_func = FlashAttentionFunction
|
flash_func = FlashAttentionFunction
|
||||||
|
|
||||||
def forward_flash_attn(self, hidden_states, **kwargs):
|
def forward_flash_attn(self, hidden_states, **kwargs):
|
||||||
@@ -162,7 +166,7 @@ def replace_vae_attn_to_memory_efficient():
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_xformers():
|
def replace_vae_attn_to_xformers():
|
||||||
print("VAE: Attention.forward has been replaced to xformers")
|
logger.info("VAE: Attention.forward has been replaced to xformers")
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
def forward_xformers(self, hidden_states, **kwargs):
|
def forward_xformers(self, hidden_states, **kwargs):
|
||||||
@@ -218,7 +222,7 @@ def replace_vae_attn_to_xformers():
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_sdpa():
|
def replace_vae_attn_to_sdpa():
|
||||||
print("VAE: Attention.forward has been replaced to sdpa")
|
logger.info("VAE: Attention.forward has been replaced to sdpa")
|
||||||
|
|
||||||
def forward_sdpa(self, hidden_states, **kwargs):
|
def forward_sdpa(self, hidden_states, **kwargs):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -352,7 +356,7 @@ class PipelineLike:
|
|||||||
token_replacements = self.token_replacements_list[tokenizer_index]
|
token_replacements = self.token_replacements_list[tokenizer_index]
|
||||||
|
|
||||||
def replace_tokens(tokens):
|
def replace_tokens(tokens):
|
||||||
# print("replace_tokens", tokens, "=>", token_replacements)
|
# logger.info("replace_tokens", tokens, "=>", token_replacements)
|
||||||
if isinstance(tokens, torch.Tensor):
|
if isinstance(tokens, torch.Tensor):
|
||||||
tokens = tokens.tolist()
|
tokens = tokens.tolist()
|
||||||
|
|
||||||
@@ -444,7 +448,7 @@ class PipelineLike:
|
|||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
if not do_classifier_free_guidance and negative_scale is not None:
|
if not do_classifier_free_guidance and negative_scale is not None:
|
||||||
print(f"negative_scale is ignored if guidance scalle <= 1.0")
|
logger.info(f"negative_scale is ignored if guidance scalle <= 1.0")
|
||||||
negative_scale = None
|
negative_scale = None
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
# get unconditional embeddings for classifier free guidance
|
||||||
@@ -548,7 +552,7 @@ class PipelineLike:
|
|||||||
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
|
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
|
||||||
|
|
||||||
if init_image is not None and self.clip_vision_model is not None:
|
if init_image is not None and self.clip_vision_model is not None:
|
||||||
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
|
logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
|
||||||
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
|
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
|
||||||
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
|
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
|
||||||
|
|
||||||
@@ -715,7 +719,7 @@ class PipelineLike:
|
|||||||
if not enabled or ratio >= 1.0:
|
if not enabled or ratio >= 1.0:
|
||||||
continue
|
continue
|
||||||
if ratio < i / len(timesteps):
|
if ratio < i / len(timesteps):
|
||||||
print(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
|
logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
|
||||||
control_net.set_cond_image(None)
|
control_net.set_cond_image(None)
|
||||||
each_control_net_enabled[j] = False
|
each_control_net_enabled[j] = False
|
||||||
|
|
||||||
@@ -935,7 +939,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
|
|||||||
if word.strip() == "BREAK":
|
if word.strip() == "BREAK":
|
||||||
# pad until next multiple of tokenizer's max token length
|
# pad until next multiple of tokenizer's max token length
|
||||||
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
|
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
|
||||||
print(f"BREAK pad_len: {pad_len}")
|
logger.info(f"BREAK pad_len: {pad_len}")
|
||||||
for i in range(pad_len):
|
for i in range(pad_len):
|
||||||
# v2のときEOSをつけるべきかどうかわからないぜ
|
# v2のときEOSをつけるべきかどうかわからないぜ
|
||||||
# if i == 0:
|
# if i == 0:
|
||||||
@@ -965,7 +969,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
|
|||||||
tokens.append(text_token)
|
tokens.append(text_token)
|
||||||
weights.append(text_weight)
|
weights.append(text_weight)
|
||||||
if truncated:
|
if truncated:
|
||||||
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||||
return tokens, weights
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
@@ -1238,7 +1242,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
|
|||||||
elif len(count_range) == 2:
|
elif len(count_range) == 2:
|
||||||
count_range = [int(count_range[0]), int(count_range[1])]
|
count_range = [int(count_range[0]), int(count_range[1])]
|
||||||
else:
|
else:
|
||||||
print(f"invalid count range: {count_range}")
|
logger.warning(f"invalid count range: {count_range}")
|
||||||
count_range = [1, 1]
|
count_range = [1, 1]
|
||||||
if count_range[0] > count_range[1]:
|
if count_range[0] > count_range[1]:
|
||||||
count_range = [count_range[1], count_range[0]]
|
count_range = [count_range[1], count_range[0]]
|
||||||
@@ -1308,7 +1312,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
|
|||||||
|
|
||||||
|
|
||||||
# def load_clip_l14_336(dtype):
|
# def load_clip_l14_336(dtype):
|
||||||
# print(f"loading CLIP: {CLIP_ID_L14_336}")
|
# logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
|
||||||
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
|
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
|
||||||
# return text_encoder
|
# return text_encoder
|
||||||
|
|
||||||
@@ -1378,7 +1382,7 @@ def main(args):
|
|||||||
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
|
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
|
||||||
|
|
||||||
# tokenizerを読み込む
|
# tokenizerを読み込む
|
||||||
print("loading tokenizer")
|
logger.info("loading tokenizer")
|
||||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||||
|
|
||||||
# schedulerを用意する
|
# schedulerを用意する
|
||||||
@@ -1452,7 +1456,7 @@ def main(args):
|
|||||||
self.sampler_noises = noises
|
self.sampler_noises = noises
|
||||||
|
|
||||||
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
|
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
|
||||||
# print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
|
# logger.info("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
|
||||||
if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
|
if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
|
||||||
noise = self.sampler_noises[self.sampler_noise_index]
|
noise = self.sampler_noises[self.sampler_noise_index]
|
||||||
if shape != noise.shape:
|
if shape != noise.shape:
|
||||||
@@ -1461,7 +1465,7 @@ def main(args):
|
|||||||
noise = None
|
noise = None
|
||||||
|
|
||||||
if noise == None:
|
if noise == None:
|
||||||
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
|
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
|
||||||
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
|
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
|
||||||
|
|
||||||
self.sampler_noise_index += 1
|
self.sampler_noise_index += 1
|
||||||
@@ -1493,7 +1497,7 @@ def main(args):
|
|||||||
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
# ↓以下は結局PipeでFalseに設定されるので意味がなかった
|
||||||
# # clip_sample=Trueにする
|
# # clip_sample=Trueにする
|
||||||
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
# if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||||
# print("set clip_sample to True")
|
# logger.info("set clip_sample to True")
|
||||||
# scheduler.config.clip_sample = True
|
# scheduler.config.clip_sample = True
|
||||||
|
|
||||||
# deviceを決定する
|
# deviceを決定する
|
||||||
@@ -1522,7 +1526,7 @@ def main(args):
|
|||||||
|
|
||||||
vae_dtype = dtype
|
vae_dtype = dtype
|
||||||
if args.no_half_vae:
|
if args.no_half_vae:
|
||||||
print("set vae_dtype to float32")
|
logger.info("set vae_dtype to float32")
|
||||||
vae_dtype = torch.float32
|
vae_dtype = torch.float32
|
||||||
vae.to(vae_dtype).to(device)
|
vae.to(vae_dtype).to(device)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -1547,10 +1551,10 @@ def main(args):
|
|||||||
network_merge = args.network_merge_n_models
|
network_merge = args.network_merge_n_models
|
||||||
else:
|
else:
|
||||||
network_merge = 0
|
network_merge = 0
|
||||||
print(f"network_merge: {network_merge}")
|
logger.info(f"network_merge: {network_merge}")
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
logger.info(f"import network module: {network_module}")
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
@@ -1568,7 +1572,7 @@ def main(args):
|
|||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
|
|
||||||
network_weight = args.network_weights[i]
|
network_weight = args.network_weights[i]
|
||||||
print("load network weights from:", network_weight)
|
logger.info(f"load network weights from: {network_weight}")
|
||||||
|
|
||||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
@@ -1576,7 +1580,7 @@ def main(args):
|
|||||||
with safe_open(network_weight, framework="pt") as f:
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
logger.info(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
||||||
@@ -1586,20 +1590,20 @@ def main(args):
|
|||||||
|
|
||||||
mergeable = network.is_mergeable()
|
mergeable = network.is_mergeable()
|
||||||
if network_merge and not mergeable:
|
if network_merge and not mergeable:
|
||||||
print("network is not mergiable. ignore merge option.")
|
logger.warning("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
if not mergeable or i >= network_merge:
|
if not mergeable or i >= network_merge:
|
||||||
# not merging
|
# not merging
|
||||||
network.apply_to([text_encoder1, text_encoder2], unet)
|
network.apply_to([text_encoder1, text_encoder2], unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
logger.info(f"weights are loaded: {info}")
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
if network_pre_calc:
|
if network_pre_calc:
|
||||||
print("backup original weights")
|
logger.info("backup original weights")
|
||||||
network.backup_weights()
|
network.backup_weights()
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
@@ -1613,7 +1617,7 @@ def main(args):
|
|||||||
# upscalerの指定があれば取得する
|
# upscalerの指定があれば取得する
|
||||||
upscaler = None
|
upscaler = None
|
||||||
if args.highres_fix_upscaler:
|
if args.highres_fix_upscaler:
|
||||||
print("import upscaler module:", args.highres_fix_upscaler)
|
logger.info(f"import upscaler module: {args.highres_fix_upscaler}")
|
||||||
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
||||||
|
|
||||||
us_kwargs = {}
|
us_kwargs = {}
|
||||||
@@ -1622,7 +1626,7 @@ def main(args):
|
|||||||
key, value = net_arg.split("=")
|
key, value = net_arg.split("=")
|
||||||
us_kwargs[key] = value
|
us_kwargs[key] = value
|
||||||
|
|
||||||
print("create upscaler")
|
logger.info("create upscaler")
|
||||||
upscaler = imported_module.create_upscaler(**us_kwargs)
|
upscaler = imported_module.create_upscaler(**us_kwargs)
|
||||||
upscaler.to(dtype).to(device)
|
upscaler.to(dtype).to(device)
|
||||||
|
|
||||||
@@ -1639,7 +1643,7 @@ def main(args):
|
|||||||
# control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
# control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
|
||||||
if args.control_net_lllite_models:
|
if args.control_net_lllite_models:
|
||||||
for i, model_file in enumerate(args.control_net_lllite_models):
|
for i, model_file in enumerate(args.control_net_lllite_models):
|
||||||
print(f"loading ControlNet-LLLite: {model_file}")
|
logger.info(f"loading ControlNet-LLLite: {model_file}")
|
||||||
|
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
@@ -1670,7 +1674,7 @@ def main(args):
|
|||||||
control_nets.append((control_net, ratio))
|
control_nets.append((control_net, ratio))
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
print(f"set optimizing: channels last")
|
logger.info(f"set optimizing: channels last")
|
||||||
text_encoder1.to(memory_format=torch.channels_last)
|
text_encoder1.to(memory_format=torch.channels_last)
|
||||||
text_encoder2.to(memory_format=torch.channels_last)
|
text_encoder2.to(memory_format=torch.channels_last)
|
||||||
vae.to(memory_format=torch.channels_last)
|
vae.to(memory_format=torch.channels_last)
|
||||||
@@ -1694,7 +1698,7 @@ def main(args):
|
|||||||
args.clip_skip,
|
args.clip_skip,
|
||||||
)
|
)
|
||||||
pipe.set_control_nets(control_nets)
|
pipe.set_control_nets(control_nets)
|
||||||
print("pipeline is ready.")
|
logger.info("pipeline is ready.")
|
||||||
|
|
||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
@@ -1736,7 +1740,7 @@ def main(args):
|
|||||||
|
|
||||||
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
|
token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings)
|
||||||
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
|
token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings)
|
||||||
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
|
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
|
||||||
assert (
|
assert (
|
||||||
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
|
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
|
||||||
), f"token ids1 is not ordered"
|
), f"token ids1 is not ordered"
|
||||||
@@ -1766,7 +1770,7 @@ def main(args):
|
|||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
logger.info(f"reading prompts from {args.from_file}")
|
||||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||||
prompt_list = f.read().splitlines()
|
prompt_list = f.read().splitlines()
|
||||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
|
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
|
||||||
@@ -1795,7 +1799,7 @@ def main(args):
|
|||||||
for p in paths:
|
for p in paths:
|
||||||
image = Image.open(p)
|
image = Image.open(p)
|
||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
print(f"convert image to RGB from {image.mode}: {p}")
|
logger.info(f"convert image to RGB from {image.mode}: {p}")
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
|
||||||
@@ -1811,14 +1815,14 @@ def main(args):
|
|||||||
return resized
|
return resized
|
||||||
|
|
||||||
if args.image_path is not None:
|
if args.image_path is not None:
|
||||||
print(f"load image for img2img: {args.image_path}")
|
logger.info(f"load image for img2img: {args.image_path}")
|
||||||
init_images = load_images(args.image_path)
|
init_images = load_images(args.image_path)
|
||||||
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
||||||
print(f"loaded {len(init_images)} images for img2img")
|
logger.info(f"loaded {len(init_images)} images for img2img")
|
||||||
|
|
||||||
# CLIP Vision
|
# CLIP Vision
|
||||||
if args.clip_vision_strength is not None:
|
if args.clip_vision_strength is not None:
|
||||||
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
|
logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
|
||||||
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
|
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
|
||||||
vision_model.to(device, dtype)
|
vision_model.to(device, dtype)
|
||||||
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
|
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
|
||||||
@@ -1826,22 +1830,22 @@ def main(args):
|
|||||||
pipe.clip_vision_model = vision_model
|
pipe.clip_vision_model = vision_model
|
||||||
pipe.clip_vision_processor = processor
|
pipe.clip_vision_processor = processor
|
||||||
pipe.clip_vision_strength = args.clip_vision_strength
|
pipe.clip_vision_strength = args.clip_vision_strength
|
||||||
print(f"CLIP Vision model loaded.")
|
logger.info(f"CLIP Vision model loaded.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
init_images = None
|
init_images = None
|
||||||
|
|
||||||
if args.mask_path is not None:
|
if args.mask_path is not None:
|
||||||
print(f"load mask for inpainting: {args.mask_path}")
|
logger.info(f"load mask for inpainting: {args.mask_path}")
|
||||||
mask_images = load_images(args.mask_path)
|
mask_images = load_images(args.mask_path)
|
||||||
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
|
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
|
||||||
print(f"loaded {len(mask_images)} mask images for inpainting")
|
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
|
||||||
else:
|
else:
|
||||||
mask_images = None
|
mask_images = None
|
||||||
|
|
||||||
# promptがないとき、画像のPngInfoから取得する
|
# promptがないとき、画像のPngInfoから取得する
|
||||||
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
|
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
|
||||||
print("get prompts from images' metadata")
|
logger.info("get prompts from images' metadata")
|
||||||
for img in init_images:
|
for img in init_images:
|
||||||
if "prompt" in img.text:
|
if "prompt" in img.text:
|
||||||
prompt = img.text["prompt"]
|
prompt = img.text["prompt"]
|
||||||
@@ -1870,17 +1874,17 @@ def main(args):
|
|||||||
h = int(h * args.highres_fix_scale + 0.5)
|
h = int(h * args.highres_fix_scale + 0.5)
|
||||||
|
|
||||||
if init_images is not None:
|
if init_images is not None:
|
||||||
print(f"resize img2img source images to {w}*{h}")
|
logger.info(f"resize img2img source images to {w}*{h}")
|
||||||
init_images = resize_images(init_images, (w, h))
|
init_images = resize_images(init_images, (w, h))
|
||||||
if mask_images is not None:
|
if mask_images is not None:
|
||||||
print(f"resize img2img mask images to {w}*{h}")
|
logger.info(f"resize img2img mask images to {w}*{h}")
|
||||||
mask_images = resize_images(mask_images, (w, h))
|
mask_images = resize_images(mask_images, (w, h))
|
||||||
|
|
||||||
regional_network = False
|
regional_network = False
|
||||||
if networks and mask_images:
|
if networks and mask_images:
|
||||||
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||||
regional_network = True
|
regional_network = True
|
||||||
print("use mask as region")
|
logger.info("use mask as region")
|
||||||
|
|
||||||
size = None
|
size = None
|
||||||
for i, network in enumerate(networks):
|
for i, network in enumerate(networks):
|
||||||
@@ -1905,14 +1909,14 @@ def main(args):
|
|||||||
|
|
||||||
prev_image = None # for VGG16 guided
|
prev_image = None # for VGG16 guided
|
||||||
if args.guide_image_path is not None:
|
if args.guide_image_path is not None:
|
||||||
print(f"load image for ControlNet guidance: {args.guide_image_path}")
|
logger.info(f"load image for ControlNet guidance: {args.guide_image_path}")
|
||||||
guide_images = []
|
guide_images = []
|
||||||
for p in args.guide_image_path:
|
for p in args.guide_image_path:
|
||||||
guide_images.extend(load_images(p))
|
guide_images.extend(load_images(p))
|
||||||
|
|
||||||
print(f"loaded {len(guide_images)} guide images for guidance")
|
logger.info(f"loaded {len(guide_images)} guide images for guidance")
|
||||||
if len(guide_images) == 0:
|
if len(guide_images) == 0:
|
||||||
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
logger.warning(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
||||||
guide_images = None
|
guide_images = None
|
||||||
else:
|
else:
|
||||||
guide_images = None
|
guide_images = None
|
||||||
@@ -1938,7 +1942,7 @@ def main(args):
|
|||||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||||
|
|
||||||
for gen_iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
iter_seed = random.randint(0, 0x7FFFFFFF)
|
iter_seed = random.randint(0, 0x7FFFFFFF)
|
||||||
|
|
||||||
# バッチ処理の関数
|
# バッチ処理の関数
|
||||||
@@ -1950,7 +1954,7 @@ def main(args):
|
|||||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||||
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
||||||
|
|
||||||
print("process 1st stage")
|
logger.info("process 1st stage")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for _, base, ext in batch:
|
for _, base, ext in batch:
|
||||||
|
|
||||||
@@ -1995,7 +1999,7 @@ def main(args):
|
|||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
print("process 2nd stage")
|
logger.info("process 2nd stage")
|
||||||
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
||||||
|
|
||||||
if upscaler:
|
if upscaler:
|
||||||
@@ -2161,7 +2165,7 @@ def main(args):
|
|||||||
n.restore_weights()
|
n.restore_weights()
|
||||||
for n in networks:
|
for n in networks:
|
||||||
n.pre_calculation()
|
n.pre_calculation()
|
||||||
print("pre-calculation... done")
|
logger.info("pre-calculation... done")
|
||||||
|
|
||||||
images = pipe(
|
images = pipe(
|
||||||
prompts,
|
prompts,
|
||||||
@@ -2240,7 +2244,7 @@ def main(args):
|
|||||||
cv2.waitKey()
|
cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
|
logger.error("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
|
||||||
@@ -2253,7 +2257,8 @@ def main(args):
|
|||||||
# interactive
|
# interactive
|
||||||
valid = False
|
valid = False
|
||||||
while not valid:
|
while not valid:
|
||||||
print("\nType prompt:")
|
logger.info("")
|
||||||
|
logger.info("Type prompt:")
|
||||||
try:
|
try:
|
||||||
raw_prompt = input()
|
raw_prompt = input()
|
||||||
except EOFError:
|
except EOFError:
|
||||||
@@ -2302,74 +2307,74 @@ def main(args):
|
|||||||
|
|
||||||
prompt_args = raw_prompt.strip().split(" --")
|
prompt_args = raw_prompt.strip().split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||||
|
|
||||||
for parg in prompt_args[1:]:
|
for parg in prompt_args[1:]:
|
||||||
try:
|
try:
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
width = int(m.group(1))
|
width = int(m.group(1))
|
||||||
print(f"width: {width}")
|
logger.info(f"width: {width}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
height = int(m.group(1))
|
height = int(m.group(1))
|
||||||
print(f"height: {height}")
|
logger.info(f"height: {height}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_width = int(m.group(1))
|
original_width = int(m.group(1))
|
||||||
print(f"original width: {original_width}")
|
logger.info(f"original width: {original_width}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_height = int(m.group(1))
|
original_height = int(m.group(1))
|
||||||
print(f"original height: {original_height}")
|
logger.info(f"original height: {original_height}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_width_negative = int(m.group(1))
|
original_width_negative = int(m.group(1))
|
||||||
print(f"original width negative: {original_width_negative}")
|
logger.info(f"original width negative: {original_width_negative}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_height_negative = int(m.group(1))
|
original_height_negative = int(m.group(1))
|
||||||
print(f"original height negative: {original_height_negative}")
|
logger.info(f"original height negative: {original_height_negative}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
crop_top = int(m.group(1))
|
crop_top = int(m.group(1))
|
||||||
print(f"crop top: {crop_top}")
|
logger.info(f"crop top: {crop_top}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
crop_left = int(m.group(1))
|
crop_left = int(m.group(1))
|
||||||
print(f"crop left: {crop_left}")
|
logger.info(f"crop left: {crop_left}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # steps
|
if m: # steps
|
||||||
steps = max(1, min(1000, int(m.group(1))))
|
steps = max(1, min(1000, int(m.group(1))))
|
||||||
print(f"steps: {steps}")
|
logger.info(f"steps: {steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||||
if m: # seed
|
if m: # seed
|
||||||
seeds = [int(d) for d in m.group(1).split(",")]
|
seeds = [int(d) for d in m.group(1).split(",")]
|
||||||
print(f"seeds: {seeds}")
|
logger.info(f"seeds: {seeds}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # scale
|
if m: # scale
|
||||||
scale = float(m.group(1))
|
scale = float(m.group(1))
|
||||||
print(f"scale: {scale}")
|
logger.info(f"scale: {scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||||
@@ -2378,25 +2383,25 @@ def main(args):
|
|||||||
negative_scale = None
|
negative_scale = None
|
||||||
else:
|
else:
|
||||||
negative_scale = float(m.group(1))
|
negative_scale = float(m.group(1))
|
||||||
print(f"negative scale: {negative_scale}")
|
logger.info(f"negative scale: {negative_scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # strength
|
if m: # strength
|
||||||
strength = float(m.group(1))
|
strength = float(m.group(1))
|
||||||
print(f"strength: {strength}")
|
logger.info(f"strength: {strength}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # negative prompt
|
||||||
negative_prompt = m.group(1)
|
negative_prompt = m.group(1)
|
||||||
print(f"negative prompt: {negative_prompt}")
|
logger.info(f"negative prompt: {negative_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||||
if m: # clip prompt
|
if m: # clip prompt
|
||||||
clip_prompt = m.group(1)
|
clip_prompt = m.group(1)
|
||||||
print(f"clip prompt: {clip_prompt}")
|
logger.info(f"clip prompt: {clip_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
@@ -2404,47 +2409,47 @@ def main(args):
|
|||||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||||
while len(network_muls) < len(networks):
|
while len(network_muls) < len(networks):
|
||||||
network_muls.append(network_muls[-1])
|
network_muls.append(network_muls[-1])
|
||||||
print(f"network mul: {network_muls}")
|
logger.info(f"network mul: {network_muls}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink depth 1
|
if m: # deep shrink depth 1
|
||||||
ds_depth_1 = int(m.group(1))
|
ds_depth_1 = int(m.group(1))
|
||||||
print(f"deep shrink depth 1: {ds_depth_1}")
|
logger.info(f"deep shrink depth 1: {ds_depth_1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink timesteps 1
|
if m: # deep shrink timesteps 1
|
||||||
ds_timesteps_1 = int(m.group(1))
|
ds_timesteps_1 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink depth 2
|
if m: # deep shrink depth 2
|
||||||
ds_depth_2 = int(m.group(1))
|
ds_depth_2 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink depth 2: {ds_depth_2}")
|
logger.info(f"deep shrink depth 2: {ds_depth_2}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink timesteps 2
|
if m: # deep shrink timesteps 2
|
||||||
ds_timesteps_2 = int(m.group(1))
|
ds_timesteps_2 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink ratio
|
if m: # deep shrink ratio
|
||||||
ds_ratio = float(m.group(1))
|
ds_ratio = float(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink ratio: {ds_ratio}")
|
logger.info(f"deep shrink ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
logger.error(f"{ex}")
|
||||||
|
|
||||||
# override Deep Shrink
|
# override Deep Shrink
|
||||||
if ds_depth_1 is not None:
|
if ds_depth_1 is not None:
|
||||||
@@ -2462,7 +2467,7 @@ def main(args):
|
|||||||
if len(predefined_seeds) > 0:
|
if len(predefined_seeds) > 0:
|
||||||
seed = predefined_seeds.pop(0)
|
seed = predefined_seeds.pop(0)
|
||||||
else:
|
else:
|
||||||
print("predefined seeds are exhausted")
|
logger.error("predefined seeds are exhausted")
|
||||||
seed = None
|
seed = None
|
||||||
elif args.iter_same_seed:
|
elif args.iter_same_seed:
|
||||||
seeds = iter_seed
|
seeds = iter_seed
|
||||||
@@ -2472,7 +2477,7 @@ def main(args):
|
|||||||
if seed is None:
|
if seed is None:
|
||||||
seed = random.randint(0, 0x7FFFFFFF)
|
seed = random.randint(0, 0x7FFFFFFF)
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
print(f"seed: {seed}")
|
logger.info(f"seed: {seed}")
|
||||||
|
|
||||||
# prepare init image, guide image and mask
|
# prepare init image, guide image and mask
|
||||||
init_image = mask_image = guide_image = None
|
init_image = mask_image = guide_image = None
|
||||||
@@ -2488,7 +2493,7 @@ def main(args):
|
|||||||
width = width - width % 32
|
width = width - width % 32
|
||||||
height = height - height % 32
|
height = height - height % 32
|
||||||
if width != init_image.size[0] or height != init_image.size[1]:
|
if width != init_image.size[0] or height != init_image.size[1]:
|
||||||
print(
|
logger.warning(
|
||||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2548,7 +2553,7 @@ def main(args):
|
|||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ from safetensors.torch import load_file
|
|||||||
|
|
||||||
from library import model_util, sdxl_model_util
|
from library import model_util, sdxl_model_util
|
||||||
import networks.lora as lora
|
import networks.lora as lora
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
||||||
# scheduler: The settings around here seem to be the same as SD1/2
|
# scheduler: The settings around here seem to be the same as SD1/2
|
||||||
@@ -140,7 +144,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
vae_dtype = DTYPE
|
vae_dtype = DTYPE
|
||||||
if DTYPE == torch.float16:
|
if DTYPE == torch.float16:
|
||||||
print("use float32 for vae")
|
logger.info("use float32 for vae")
|
||||||
vae_dtype = torch.float32
|
vae_dtype = torch.float32
|
||||||
vae.to(DEVICE, dtype=vae_dtype)
|
vae.to(DEVICE, dtype=vae_dtype)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -187,7 +191,7 @@ if __name__ == "__main__":
|
|||||||
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||||
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||||
# print("emb1", emb1.shape)
|
# logger.info("emb1", emb1.shape)
|
||||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||||
|
|
||||||
@@ -217,7 +221,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||||
# print("hidden_states2", text_embedding2_penu.shape)
|
# logger.info("hidden_states2", text_embedding2_penu.shape)
|
||||||
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
||||||
|
|
||||||
# 連結して終了 concat and finish
|
# 連結して終了 concat and finish
|
||||||
@@ -226,7 +230,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# cond
|
# cond
|
||||||
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
||||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||||
|
|
||||||
# uncond
|
# uncond
|
||||||
@@ -323,4 +327,4 @@ if __name__ == "__main__":
|
|||||||
seed = int(seed)
|
seed = int(seed)
|
||||||
generate_image(prompt, prompt2, negative_prompt, seed)
|
generate_image(prompt, prompt2, negative_prompt, seed)
|
||||||
|
|
||||||
print("Done!")
|
logger.info("Done!")
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ from diffusers import DDPMScheduler
|
|||||||
from library import sdxl_model_util
|
from library import sdxl_model_util
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
import library.config_util as config_util
|
import library.config_util as config_util
|
||||||
import library.sdxl_train_util as sdxl_train_util
|
import library.sdxl_train_util as sdxl_train_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
@@ -117,18 +121,18 @@ def train(args):
|
|||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "in_json"]
|
ignored = ["train_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Using DreamBooth method.")
|
logger.info("Using DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -139,7 +143,7 @@ def train(args):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Training with captions.")
|
logger.info("Training with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -169,7 +173,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group, True)
|
train_util.debug_dataset(train_dataset_group, True)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print(
|
logger.error(
|
||||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -185,7 +189,7 @@ def train(args):
|
|||||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
@@ -537,7 +541,7 @@ def train(args):
|
|||||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
# print("text encoder outputs verified")
|
# logger.info("text encoder outputs verified")
|
||||||
|
|
||||||
# get size embeddings
|
# get size embeddings
|
||||||
orig_size = batch["original_sizes_hw"]
|
orig_size = batch["original_sizes_hw"]
|
||||||
@@ -724,7 +728,7 @@ def train(args):
|
|||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
)
|
)
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -45,7 +45,10 @@ from library.custom_train_functions import (
|
|||||||
apply_debiased_estimation,
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||||
@@ -78,11 +81,11 @@ def train(args):
|
|||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -114,7 +117,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print(
|
logger.error(
|
||||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -124,7 +127,7 @@ def train(args):
|
|||||||
train_dataset_group.is_latent_cacheable()
|
train_dataset_group.is_latent_cacheable()
|
||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
else:
|
else:
|
||||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||||
|
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
assert (
|
assert (
|
||||||
@@ -132,7 +135,7 @@ def train(args):
|
|||||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
@@ -231,8 +234,8 @@ def train(args):
|
|||||||
accelerator.print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
trainable_params = list(unet.prepare_params())
|
trainable_params = list(unet.prepare_params())
|
||||||
print(f"trainable params count: {len(trainable_params)}")
|
logger.info(f"trainable params count: {len(trainable_params)}")
|
||||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||||
|
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
@@ -324,7 +327,7 @@ def train(args):
|
|||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
@@ -548,7 +551,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -41,7 +41,10 @@ from library.custom_train_functions import (
|
|||||||
apply_debiased_estimation,
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
import networks.control_net_lllite as control_net_lllite
|
import networks.control_net_lllite as control_net_lllite
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||||
@@ -74,11 +77,11 @@ def train(args):
|
|||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -110,7 +113,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print(
|
logger.error(
|
||||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -120,7 +123,7 @@ def train(args):
|
|||||||
train_dataset_group.is_latent_cacheable()
|
train_dataset_group.is_latent_cacheable()
|
||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
else:
|
else:
|
||||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
logger.warning("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||||
|
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
assert (
|
assert (
|
||||||
@@ -128,7 +131,7 @@ def train(args):
|
|||||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
@@ -199,8 +202,8 @@ def train(args):
|
|||||||
accelerator.print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
trainable_params = list(network.prepare_optimizer_params())
|
trainable_params = list(network.prepare_optimizer_params())
|
||||||
print(f"trainable params count: {len(trainable_params)}")
|
logger.info(f"trainable params count: {len(trainable_params)}")
|
||||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||||
|
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
@@ -297,7 +300,7 @@ def train(args):
|
|||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
@@ -516,7 +519,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ init_ipex()
|
|||||||
|
|
||||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||||
import train_network
|
import train_network
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -60,7 +63,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
# メモリ消費を減らす
|
# メモリ消費を減らす
|
||||||
print("move vae and unet to cpu to save memory")
|
logger.info("move vae and unet to cpu to save memory")
|
||||||
org_vae_device = vae.device
|
org_vae_device = vae.device
|
||||||
org_unet_device = unet.device
|
org_unet_device = unet.device
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
@@ -85,7 +88,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
print("move vae and unet back to original device")
|
logger.info("move vae and unet back to original device")
|
||||||
vae.to(org_vae_device)
|
vae.to(org_vae_device)
|
||||||
unet.to(org_unet_device)
|
unet.to(org_unet_device)
|
||||||
else:
|
else:
|
||||||
@@ -143,7 +146,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||||
# print("text encoder outputs verified")
|
# logger.info("text encoder outputs verified")
|
||||||
|
|
||||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,10 @@ from library.config_util import (
|
|||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
@@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "in_json"]
|
ignored = ["train_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Using DreamBooth method.")
|
logger.info("Using DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Training with captions.")
|
logger.info("Training with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
@@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
print("load model")
|
logger.info("load model")
|
||||||
if args.sdxl:
|
if args.sdxl:
|
||||||
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||||
else:
|
else:
|
||||||
@@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
|
|
||||||
if args.skip_existing:
|
if args.skip_existing:
|
||||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
|
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
|
||||||
print(f"Skipping {image_info.latents_npz} because it already exists.")
|
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
image_infos.append(image_info)
|
image_infos.append(image_info)
|
||||||
|
|||||||
@@ -16,7 +16,10 @@ from library.config_util import (
|
|||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
@@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "in_json"]
|
ignored = ["train_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Using DreamBooth method.")
|
logger.info("Using DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Training with captions.")
|
logger.info("Training with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
print("load model")
|
logger.info("load model")
|
||||||
if args.sdxl:
|
if args.sdxl:
|
||||||
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||||
text_encoders = [text_encoder1, text_encoder2]
|
text_encoders = [text_encoder1, text_encoder2]
|
||||||
@@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
|
|
||||||
if args.skip_existing:
|
if args.skip_existing:
|
||||||
if os.path.exists(image_info.text_encoder_outputs_npz):
|
if os.path.exists(image_info.text_encoder_outputs_npz):
|
||||||
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
image_info.input_ids1 = input_ids1
|
image_info.input_ids1 = input_ids1
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def canny(args):
|
def canny(args):
|
||||||
img = cv2.imread(args.input)
|
img = cv2.imread(args.input)
|
||||||
@@ -10,7 +14,7 @@ def canny(args):
|
|||||||
# canny_img = 255 - canny_img
|
# canny_img = 255 - canny_img
|
||||||
|
|
||||||
cv2.imwrite(args.output, canny_img)
|
cv2.imwrite(args.output, canny_img)
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import torch
|
|||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def convert(args):
|
def convert(args):
|
||||||
# 引数を確認する
|
# 引数を確認する
|
||||||
@@ -30,7 +33,7 @@ def convert(args):
|
|||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||||
print(f"loading {msg}: {args.model_to_load}")
|
logger.info(f"loading {msg}: {args.model_to_load}")
|
||||||
|
|
||||||
if is_load_ckpt:
|
if is_load_ckpt:
|
||||||
v2_model = args.v2
|
v2_model = args.v2
|
||||||
@@ -48,13 +51,13 @@ def convert(args):
|
|||||||
if args.v1 == args.v2:
|
if args.v1 == args.v2:
|
||||||
# 自動判定する
|
# 自動判定する
|
||||||
v2_model = unet.config.cross_attention_dim == 1024
|
v2_model = unet.config.cross_attention_dim == 1024
|
||||||
print("checking model version: model is " + ("v2" if v2_model else "v1"))
|
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||||
else:
|
else:
|
||||||
v2_model = not args.v1
|
v2_model = not args.v1
|
||||||
|
|
||||||
# 変換して保存する
|
# 変換して保存する
|
||||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
|
||||||
|
|
||||||
if is_save_ckpt:
|
if is_save_ckpt:
|
||||||
original_model = args.model_to_load if is_load_ckpt else None
|
original_model = args.model_to_load if is_load_ckpt else None
|
||||||
@@ -70,15 +73,15 @@ def convert(args):
|
|||||||
save_dtype=save_dtype,
|
save_dtype=save_dtype,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
)
|
)
|
||||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
logger.info(f"model saved. total converted state_dict keys: {key_count}")
|
||||||
else:
|
else:
|
||||||
print(
|
logger.info(
|
||||||
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||||
)
|
)
|
||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||||
)
|
)
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ import os
|
|||||||
from anime_face_detector import create_detector
|
from anime_face_detector import create_detector
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
KP_REYE = 11
|
KP_REYE = 11
|
||||||
KP_LEYE = 19
|
KP_LEYE = 19
|
||||||
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
|
|||||||
|
|
||||||
def detect_faces(detector, image, min_size):
|
def detect_faces(detector, image, min_size):
|
||||||
preds = detector(image) # bgr
|
preds = detector(image) # bgr
|
||||||
# print(len(preds))
|
# logger.info(len(preds))
|
||||||
|
|
||||||
faces = []
|
faces = []
|
||||||
for pred in preds:
|
for pred in preds:
|
||||||
@@ -78,7 +82,7 @@ def process(args):
|
|||||||
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
|
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
|
||||||
|
|
||||||
# アニメ顔検出モデルを読み込む
|
# アニメ顔検出モデルを読み込む
|
||||||
print("loading face detector.")
|
logger.info("loading face detector.")
|
||||||
detector = create_detector('yolov3')
|
detector = create_detector('yolov3')
|
||||||
|
|
||||||
# cropの引数を解析する
|
# cropの引数を解析する
|
||||||
@@ -97,7 +101,7 @@ def process(args):
|
|||||||
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
|
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
|
||||||
|
|
||||||
# 画像を処理する
|
# 画像を処理する
|
||||||
print("processing.")
|
logger.info("processing.")
|
||||||
output_extension = ".png"
|
output_extension = ".png"
|
||||||
|
|
||||||
os.makedirs(args.dst_dir, exist_ok=True)
|
os.makedirs(args.dst_dir, exist_ok=True)
|
||||||
@@ -111,7 +115,7 @@ def process(args):
|
|||||||
if len(image.shape) == 2:
|
if len(image.shape) == 2:
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||||
if image.shape[2] == 4:
|
if image.shape[2] == 4:
|
||||||
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
||||||
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
|
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
|
||||||
|
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
@@ -144,11 +148,11 @@ def process(args):
|
|||||||
# 顔サイズを基準にリサイズする
|
# 顔サイズを基準にリサイズする
|
||||||
scale = args.resize_face_size / face_size
|
scale = args.resize_face_size / face_size
|
||||||
if scale < cur_crop_width / w:
|
if scale < cur_crop_width / w:
|
||||||
print(
|
logger.warning(
|
||||||
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
||||||
scale = cur_crop_width / w
|
scale = cur_crop_width / w
|
||||||
if scale < cur_crop_height / h:
|
if scale < cur_crop_height / h:
|
||||||
print(
|
logger.warning(
|
||||||
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
||||||
scale = cur_crop_height / h
|
scale = cur_crop_height / h
|
||||||
elif crop_h_ratio is not None:
|
elif crop_h_ratio is not None:
|
||||||
@@ -157,10 +161,10 @@ def process(args):
|
|||||||
else:
|
else:
|
||||||
# 切り出しサイズ指定あり
|
# 切り出しサイズ指定あり
|
||||||
if w < cur_crop_width:
|
if w < cur_crop_width:
|
||||||
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
||||||
scale = cur_crop_width / w
|
scale = cur_crop_width / w
|
||||||
if h < cur_crop_height:
|
if h < cur_crop_height:
|
||||||
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
||||||
scale = cur_crop_height / h
|
scale = cur_crop_height / h
|
||||||
if args.resize_fit:
|
if args.resize_fit:
|
||||||
scale = max(cur_crop_width / w, cur_crop_height / h)
|
scale = max(cur_crop_width / w, cur_crop_height / h)
|
||||||
@@ -198,7 +202,7 @@ def process(args):
|
|||||||
face_img = face_img[y:y + cur_crop_height]
|
face_img = face_img[y:y + cur_crop_height]
|
||||||
|
|
||||||
# # debug
|
# # debug
|
||||||
# print(path, cx, cy, angle)
|
# logger.info(path, cx, cy, angle)
|
||||||
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
||||||
# cv2.imshow("image", crp)
|
# cv2.imshow("image", crp)
|
||||||
# if cv2.waitKey() == 27:
|
# if cv2.waitKey() == 27:
|
||||||
|
|||||||
@@ -14,7 +14,10 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class ResidualBlock(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||||
@@ -216,7 +219,7 @@ class Upscaler(nn.Module):
|
|||||||
upsampled_images = upsampled_images / 127.5 - 1.0
|
upsampled_images = upsampled_images / 127.5 - 1.0
|
||||||
|
|
||||||
# convert upsample images to latents with batch size
|
# convert upsample images to latents with batch size
|
||||||
# print("Encoding upsampled (LANCZOS4) images...")
|
# logger.info("Encoding upsampled (LANCZOS4) images...")
|
||||||
upsampled_latents = []
|
upsampled_latents = []
|
||||||
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
||||||
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
||||||
@@ -227,7 +230,7 @@ class Upscaler(nn.Module):
|
|||||||
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
||||||
|
|
||||||
# upscale (refine) latents with this model with batch size
|
# upscale (refine) latents with this model with batch size
|
||||||
print("Upscaling latents...")
|
logger.info("Upscaling latents...")
|
||||||
upscaled_latents = []
|
upscaled_latents = []
|
||||||
for i in range(0, upsampled_latents.shape[0], batch_size):
|
for i in range(0, upsampled_latents.shape[0], batch_size):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -242,7 +245,7 @@ def create_upscaler(**kwargs):
|
|||||||
weights = kwargs["weights"]
|
weights = kwargs["weights"]
|
||||||
model = Upscaler()
|
model = Upscaler()
|
||||||
|
|
||||||
print(f"Loading weights from {weights}...")
|
logger.info(f"Loading weights from {weights}...")
|
||||||
if os.path.splitext(weights)[1] == ".safetensors":
|
if os.path.splitext(weights)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
@@ -261,14 +264,14 @@ def upscale_images(args: argparse.Namespace):
|
|||||||
|
|
||||||
# load VAE with Diffusers
|
# load VAE with Diffusers
|
||||||
assert args.vae_path is not None, "VAE path is required"
|
assert args.vae_path is not None, "VAE path is required"
|
||||||
print(f"Loading VAE from {args.vae_path}...")
|
logger.info(f"Loading VAE from {args.vae_path}...")
|
||||||
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
||||||
vae.to(DEVICE, dtype=us_dtype)
|
vae.to(DEVICE, dtype=us_dtype)
|
||||||
|
|
||||||
# prepare model
|
# prepare model
|
||||||
print("Preparing model...")
|
logger.info("Preparing model...")
|
||||||
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
||||||
# print("Loading weights from", args.weights)
|
# logger.info(f"Loading weights from {args.weights}")
|
||||||
# upscaler.load_state_dict(torch.load(args.weights))
|
# upscaler.load_state_dict(torch.load(args.weights))
|
||||||
upscaler.eval()
|
upscaler.eval()
|
||||||
upscaler.to(DEVICE, dtype=us_dtype)
|
upscaler.to(DEVICE, dtype=us_dtype)
|
||||||
@@ -303,14 +306,14 @@ def upscale_images(args: argparse.Namespace):
|
|||||||
image_debug.save(dest_file_name)
|
image_debug.save(dest_file_name)
|
||||||
|
|
||||||
# upscale
|
# upscale
|
||||||
print("Upscaling...")
|
logger.info("Upscaling...")
|
||||||
upscaled_latents = upscaler.upscale(
|
upscaled_latents = upscaler.upscale(
|
||||||
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
||||||
)
|
)
|
||||||
upscaled_latents /= 0.18215
|
upscaled_latents /= 0.18215
|
||||||
|
|
||||||
# decode with batch
|
# decode with batch
|
||||||
print("Decoding...")
|
logger.info("Decoding...")
|
||||||
upscaled_images = []
|
upscaled_images = []
|
||||||
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import torch
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def is_unet_key(key):
|
def is_unet_key(key):
|
||||||
# VAE or TextEncoder, the last one is for SDXL
|
# VAE or TextEncoder, the last one is for SDXL
|
||||||
@@ -45,10 +48,10 @@ def merge(args):
|
|||||||
# check if all models are safetensors
|
# check if all models are safetensors
|
||||||
for model in args.models:
|
for model in args.models:
|
||||||
if not model.endswith("safetensors"):
|
if not model.endswith("safetensors"):
|
||||||
print(f"Model {model} is not a safetensors model")
|
logger.info(f"Model {model} is not a safetensors model")
|
||||||
exit()
|
exit()
|
||||||
if not os.path.isfile(model):
|
if not os.path.isfile(model):
|
||||||
print(f"Model {model} does not exist")
|
logger.info(f"Model {model} does not exist")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
|
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
|
||||||
@@ -65,7 +68,7 @@ def merge(args):
|
|||||||
|
|
||||||
if merged_sd is None:
|
if merged_sd is None:
|
||||||
# load first model
|
# load first model
|
||||||
print(f"Loading model {model}, ratio = {ratio}...")
|
logger.info(f"Loading model {model}, ratio = {ratio}...")
|
||||||
merged_sd = {}
|
merged_sd = {}
|
||||||
with safe_open(model, framework="pt", device=args.device) as f:
|
with safe_open(model, framework="pt", device=args.device) as f:
|
||||||
for key in tqdm(f.keys()):
|
for key in tqdm(f.keys()):
|
||||||
@@ -81,11 +84,11 @@ def merge(args):
|
|||||||
value = ratio * value.to(dtype) # first model's value * ratio
|
value = ratio * value.to(dtype) # first model's value * ratio
|
||||||
merged_sd[key] = value
|
merged_sd[key] = value
|
||||||
|
|
||||||
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# load other models
|
# load other models
|
||||||
print(f"Loading model {model}, ratio = {ratio}...")
|
logger.info(f"Loading model {model}, ratio = {ratio}...")
|
||||||
|
|
||||||
with safe_open(model, framework="pt", device=args.device) as f:
|
with safe_open(model, framework="pt", device=args.device) as f:
|
||||||
model_keys = f.keys()
|
model_keys = f.keys()
|
||||||
@@ -93,7 +96,7 @@ def merge(args):
|
|||||||
_, new_key = replace_text_encoder_key(key)
|
_, new_key = replace_text_encoder_key(key)
|
||||||
if new_key not in merged_sd:
|
if new_key not in merged_sd:
|
||||||
if args.show_skipped and new_key not in first_model_keys:
|
if args.show_skipped and new_key not in first_model_keys:
|
||||||
print(f"Skip: {new_key}")
|
logger.info(f"Skip: {new_key}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
value = f.get_tensor(key)
|
value = f.get_tensor(key)
|
||||||
@@ -104,7 +107,7 @@ def merge(args):
|
|||||||
for key in merged_sd.keys():
|
for key in merged_sd.keys():
|
||||||
if key in model_keys:
|
if key in model_keys:
|
||||||
continue
|
continue
|
||||||
print(f"Key {key} not in model {model}, use first model's value")
|
logger.warning(f"Key {key} not in model {model}, use first model's value")
|
||||||
if key in supplementary_key_ratios:
|
if key in supplementary_key_ratios:
|
||||||
supplementary_key_ratios[key] += ratio
|
supplementary_key_ratios[key] += ratio
|
||||||
else:
|
else:
|
||||||
@@ -112,7 +115,7 @@ def merge(args):
|
|||||||
|
|
||||||
# add supplementary keys' value (including VAE and TextEncoder)
|
# add supplementary keys' value (including VAE and TextEncoder)
|
||||||
if len(supplementary_key_ratios) > 0:
|
if len(supplementary_key_ratios) > 0:
|
||||||
print("add first model's value")
|
logger.info("add first model's value")
|
||||||
with safe_open(args.models[0], framework="pt", device=args.device) as f:
|
with safe_open(args.models[0], framework="pt", device=args.device) as f:
|
||||||
for key in tqdm(f.keys()):
|
for key in tqdm(f.keys()):
|
||||||
_, new_key = replace_text_encoder_key(key)
|
_, new_key = replace_text_encoder_key(key)
|
||||||
@@ -120,7 +123,7 @@ def merge(args):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if is_unet_key(new_key): # not VAE or TextEncoder
|
if is_unet_key(new_key): # not VAE or TextEncoder
|
||||||
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||||
|
|
||||||
value = f.get_tensor(key) # original key
|
value = f.get_tensor(key) # original key
|
||||||
|
|
||||||
@@ -134,7 +137,7 @@ def merge(args):
|
|||||||
if not output_file.endswith(".safetensors"):
|
if not output_file.endswith(".safetensors"):
|
||||||
output_file = output_file + ".safetensors"
|
output_file = output_file + ".safetensors"
|
||||||
|
|
||||||
print(f"Saving to {output_file}...")
|
logger.info(f"Saving to {output_file}...")
|
||||||
|
|
||||||
# convert to save_dtype
|
# convert to save_dtype
|
||||||
for k in merged_sd.keys():
|
for k in merged_sd.keys():
|
||||||
@@ -142,7 +145,7 @@ def merge(args):
|
|||||||
|
|
||||||
save_file(merged_sd, output_file)
|
save_file(merged_sd, output_file)
|
||||||
|
|
||||||
print("Done!")
|
logger.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ from safetensors.torch import load_file
|
|||||||
from library.original_unet import UNet2DConditionModel, SampleOutput
|
from library.original_unet import UNet2DConditionModel, SampleOutput
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ControlNetInfo(NamedTuple):
|
class ControlNetInfo(NamedTuple):
|
||||||
unet: Any
|
unet: Any
|
||||||
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
|
|||||||
|
|
||||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
||||||
# state dictを読み込む
|
# state dictを読み込む
|
||||||
print(f"ControlNet: loading control SD model : {model}")
|
logger.info(f"ControlNet: loading control SD model : {model}")
|
||||||
|
|
||||||
if model_util.is_safetensors(model):
|
if model_util.is_safetensors(model):
|
||||||
ctrl_sd_sd = load_file(model)
|
ctrl_sd_sd = load_file(model)
|
||||||
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
|
|||||||
|
|
||||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||||
is_difference = "difference" in ctrl_sd_sd
|
is_difference = "difference" in ctrl_sd_sd
|
||||||
print("ControlNet: loading difference:", is_difference)
|
logger.info(f"ControlNet: loading difference: {is_difference}")
|
||||||
|
|
||||||
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
||||||
# またTransfer Controlの元weightとなる
|
# またTransfer Controlの元weightとなる
|
||||||
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
|
|||||||
# ControlNetのU-Netを作成する
|
# ControlNetのU-Netを作成する
|
||||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||||
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
||||||
print("ControlNet: loading Control U-Net:", info)
|
logger.info(f"ControlNet: loading Control U-Net: {info}")
|
||||||
|
|
||||||
# U-Net以外のControlNetを作成する
|
# U-Net以外のControlNetを作成する
|
||||||
# TODO support middle only
|
# TODO support middle only
|
||||||
ctrl_net = ControlNet()
|
ctrl_net = ControlNet()
|
||||||
info = ctrl_net.load_state_dict(zero_conv_sd)
|
info = ctrl_net.load_state_dict(zero_conv_sd)
|
||||||
print("ControlNet: loading ControlNet:", info)
|
logger.info(f"ControlNet: loading ControlNet: {info}")
|
||||||
|
|
||||||
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
||||||
ctrl_net.to(unet.device, dtype=unet.dtype)
|
ctrl_net.to(unet.device, dtype=unet.dtype)
|
||||||
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
|
|||||||
|
|
||||||
return canny
|
return canny
|
||||||
|
|
||||||
print("Unsupported prep type:", prep_type)
|
logger.info(f"Unsupported prep type: {prep_type}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -174,7 +177,7 @@ def call_unet_and_control_net(
|
|||||||
cnet_idx = step % cnet_cnt
|
cnet_idx = step % cnet_cnt
|
||||||
cnet_info = control_nets[cnet_idx]
|
cnet_info = control_nets[cnet_idx]
|
||||||
|
|
||||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
# logger.info(f"{current_ratio} {cnet_info.prep} {cnet_info.weight} {cnet_info.ratio}")
|
||||||
if cnet_info.ratio < current_ratio:
|
if cnet_info.ratio < current_ratio:
|
||||||
return original_unet(sample, timestep, encoder_hidden_states)
|
return original_unet(sample, timestep, encoder_hidden_states)
|
||||||
|
|
||||||
@@ -192,7 +195,7 @@ def call_unet_and_control_net(
|
|||||||
# ControlNet
|
# ControlNet
|
||||||
cnet_outs_list = []
|
cnet_outs_list = []
|
||||||
for i, cnet_info in enumerate(control_nets):
|
for i, cnet_info in enumerate(control_nets):
|
||||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
# logger.info(f"{current_ratio} {cnet_info.prep} {cnet_info.weight} {cnet_info.ratio}")
|
||||||
if cnet_info.ratio < current_ratio:
|
if cnet_info.ratio < current_ratio:
|
||||||
continue
|
continue
|
||||||
guided_hint = guided_hints[i]
|
guided_hint = guided_hints[i]
|
||||||
@@ -232,7 +235,7 @@ def unet_forward(
|
|||||||
upsample_size = None
|
upsample_size = None
|
||||||
|
|
||||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||||
print("Forward upsample size to force interpolation output size.")
|
logger.info("Forward upsample size to force interpolation output size.")
|
||||||
forward_upsample_size = True
|
forward_upsample_size = True
|
||||||
|
|
||||||
# 1. time
|
# 1. time
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import shutil
|
|||||||
import math
|
import math
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
||||||
# Split the max_resolution string by "," and strip any whitespaces
|
# Split the max_resolution string by "," and strip any whitespaces
|
||||||
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
||||||
|
|
||||||
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
||||||
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||||
|
|
||||||
# If other files with same basename, copy them with resolution suffix
|
# If other files with same basename, copy them with resolution suffix
|
||||||
if copy_associated_files:
|
if copy_associated_files:
|
||||||
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
|||||||
continue
|
continue
|
||||||
for max_resolution in max_resolutions:
|
for max_resolution in max_resolutions:
|
||||||
new_asoc_file = base + '+' + max_resolution + ext
|
new_asoc_file = base + '+' + max_resolution + ext
|
||||||
print(f"Copy {asoc_file} as {new_asoc_file}")
|
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
|
||||||
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model", type=str, required=True)
|
parser.add_argument("--model", type=str, required=True)
|
||||||
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
|
|||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
|
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
print("No metadata found")
|
logger.error("No metadata found")
|
||||||
else:
|
else:
|
||||||
# metadata is json dict, but not pretty printed
|
# metadata is json dict, but not pretty printed
|
||||||
# sort by key and pretty print
|
# sort by key and pretty print
|
||||||
print(json.dumps(metadata, indent=4, sort_keys=True))
|
print(json.dumps(metadata, indent=4, sort_keys=True))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ from library.custom_train_functions import (
|
|||||||
pyramid_noise_like,
|
pyramid_noise_like,
|
||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
)
|
)
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||||
@@ -69,11 +72,11 @@ def train(args):
|
|||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -103,7 +106,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print(
|
logger.error(
|
||||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -114,7 +117,7 @@ def train(args):
|
|||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
@@ -310,7 +313,7 @@ def train(args):
|
|||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
@@ -567,7 +570,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, controlnet, force_sync_upload=True)
|
save_model(ckpt_name, controlnet, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
16
train_db.py
16
train_db.py
@@ -35,6 +35,10 @@ from library.custom_train_functions import (
|
|||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
apply_debiased_estimation,
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# perlin_noise,
|
# perlin_noise,
|
||||||
|
|
||||||
@@ -54,11 +58,11 @@ def train(args):
|
|||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "reg_data_dir"]
|
ignored = ["train_data_dir", "reg_data_dir"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -93,13 +97,13 @@ def train(args):
|
|||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
|
|
||||||
if args.gradient_accumulation_steps > 1:
|
if args.gradient_accumulation_steps > 1:
|
||||||
print(
|
logger.warning(
|
||||||
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
||||||
)
|
)
|
||||||
print(
|
logger.warning(
|
||||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -449,7 +453,7 @@ def train(args):
|
|||||||
train_util.save_sd_model_on_train_end(
|
train_util.save_sd_model_on_train_end(
|
||||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||||
)
|
)
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -41,7 +41,10 @@ from library.custom_train_functions import (
|
|||||||
add_v_prediction_like_loss,
|
add_v_prediction_like_loss,
|
||||||
apply_debiased_estimation,
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class NetworkTrainer:
|
class NetworkTrainer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -153,18 +156,18 @@ class NetworkTrainer:
|
|||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
print(f"Loading dataset config from {args.dataset_config}")
|
logger.info(f"Loading dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.warning(
|
||||||
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Using DreamBooth method.")
|
logger.info("Using DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -175,7 +178,7 @@ class NetworkTrainer:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Training with captions.")
|
logger.info("Training with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -204,7 +207,7 @@ class NetworkTrainer:
|
|||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print(
|
logger.error(
|
||||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -217,7 +220,7 @@ class NetworkTrainer:
|
|||||||
self.assert_extra_args(args, train_dataset_group)
|
self.assert_extra_args(args, train_dataset_group)
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("preparing accelerator")
|
logger.info("preparing accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
@@ -310,7 +313,7 @@ class NetworkTrainer:
|
|||||||
if hasattr(network, "prepare_network"):
|
if hasattr(network, "prepare_network"):
|
||||||
network.prepare_network(args)
|
network.prepare_network(args)
|
||||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||||
print(
|
logger.warning(
|
||||||
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
||||||
)
|
)
|
||||||
args.scale_weight_norms = False
|
args.scale_weight_norms = False
|
||||||
@@ -938,7 +941,7 @@ class NetworkTrainer:
|
|||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -32,6 +32,10 @@ from library.custom_train_functions import (
|
|||||||
add_v_prediction_like_loss,
|
add_v_prediction_like_loss,
|
||||||
apply_debiased_estimation,
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
"a photo of a {}",
|
"a photo of a {}",
|
||||||
@@ -178,7 +182,7 @@ class TextualInversionTrainer:
|
|||||||
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
@@ -288,7 +292,7 @@ class TextualInversionTrainer:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Train with captions.")
|
logger.info("Train with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -736,7 +740,7 @@ class TextualInversionTrainer:
|
|||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ from library.custom_train_functions import (
|
|||||||
)
|
)
|
||||||
import library.original_unet as original_unet
|
import library.original_unet as original_unet
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
from library.utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
"a photo of a {}",
|
"a photo of a {}",
|
||||||
@@ -99,7 +103,7 @@ def train(args):
|
|||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||||
print(
|
logger.warning(
|
||||||
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
@@ -114,7 +118,7 @@ def train(args):
|
|||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
@@ -127,7 +131,7 @@ def train(args):
|
|||||||
if args.init_word is not None:
|
if args.init_word is not None:
|
||||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||||
print(
|
logger.warning(
|
||||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -141,7 +145,7 @@ def train(args):
|
|||||||
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||||
|
|
||||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
print(f"tokens are added: {token_ids}")
|
logger.info(f"tokens are added: {token_ids}")
|
||||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
@@ -169,7 +173,7 @@ def train(args):
|
|||||||
|
|
||||||
tokenizer.add_tokens(token_strings_XTI)
|
tokenizer.add_tokens(token_strings_XTI)
|
||||||
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
||||||
print(f"tokens are added (XTI): {token_ids_XTI}")
|
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
|
||||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
@@ -178,7 +182,7 @@ def train(args):
|
|||||||
if init_token_ids is not None:
|
if init_token_ids is not None:
|
||||||
for i, token_id in enumerate(token_ids_XTI):
|
for i, token_id in enumerate(token_ids_XTI):
|
||||||
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
|
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
|
||||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
|
|
||||||
# load weights
|
# load weights
|
||||||
if args.weights is not None:
|
if args.weights is not None:
|
||||||
@@ -186,22 +190,22 @@ def train(args):
|
|||||||
assert len(token_ids) == len(
|
assert len(token_ids) == len(
|
||||||
embeddings
|
embeddings
|
||||||
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||||
# print(token_ids, embeddings.size())
|
# logger.info(token_ids, embeddings.size())
|
||||||
for token_id, embedding in zip(token_ids_XTI, embeddings):
|
for token_id, embedding in zip(token_ids_XTI, embeddings):
|
||||||
token_embeds[token_id] = embedding
|
token_embeds[token_id] = embedding
|
||||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
print(f"weighs loaded")
|
logger.info(f"weighs loaded")
|
||||||
|
|
||||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
logger.info(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
@@ -209,14 +213,14 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
use_dreambooth_method = args.in_json is None
|
use_dreambooth_method = args.in_json is None
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Use DreamBooth method.")
|
logger.info("Use DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Train with captions.")
|
logger.info("Train with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -240,7 +244,7 @@ def train(args):
|
|||||||
|
|
||||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
if use_template:
|
if use_template:
|
||||||
print(f"use template for training captions. is object: {args.use_object_template}")
|
logger.info(f"use template for training captions. is object: {args.use_object_template}")
|
||||||
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||||
replace_to = " ".join(token_strings)
|
replace_to = " ".join(token_strings)
|
||||||
captions = []
|
captions = []
|
||||||
@@ -264,7 +268,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
||||||
return
|
return
|
||||||
if len(train_dataset_group) == 0:
|
if len(train_dataset_group) == 0:
|
||||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||||
return
|
return
|
||||||
|
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
@@ -297,7 +301,7 @@ def train(args):
|
|||||||
text_encoder.gradient_checkpointing_enable()
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
logger.info("prepare optimizer, data loader etc.")
|
||||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
|
||||||
@@ -318,7 +322,7 @@ def train(args):
|
|||||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
)
|
)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
logger.info(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
@@ -332,7 +336,7 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
# Freeze all parameters except for the token embeddings in text encoder
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
@@ -370,15 +374,15 @@ def train(args):
|
|||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
print("running training / 学習開始")
|
logger.info("running training / 学習開始")
|
||||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
logger.info(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
global_step = 0
|
global_step = 0
|
||||||
@@ -403,7 +407,8 @@ def train(args):
|
|||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
logger.info("")
|
||||||
|
logger.info(f"saving checkpoint: {ckpt_file}")
|
||||||
save_weights(ckpt_file, embs, save_dtype)
|
save_weights(ckpt_file, embs, save_dtype)
|
||||||
if args.huggingface_repo_id is not None:
|
if args.huggingface_repo_id is not None:
|
||||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||||
@@ -411,12 +416,13 @@ def train(args):
|
|||||||
def remove_model(old_ckpt_name):
|
def remove_model(old_ckpt_name):
|
||||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||||
if os.path.exists(old_ckpt_file):
|
if os.path.exists(old_ckpt_file):
|
||||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
logger.info(f"removing old checkpoint: {old_ckpt_file}")
|
||||||
os.remove(old_ckpt_file)
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
logger.info("")
|
||||||
|
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
@@ -586,7 +592,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||||
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
||||||
|
|
||||||
print("model saved.")
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def save_weights(file, updated_embs, save_dtype):
|
def save_weights(file, updated_embs, save_dtype):
|
||||||
|
|||||||
Reference in New Issue
Block a user