Merge pull request #427 from kohya-ss/dev

fix lora_interrogator, wd14 tagger for '^_^' etc
This commit is contained in:
Kohya S
2023-04-19 21:57:34 +09:00
committed by GitHub
3 changed files with 35 additions and 16 deletions

View File

@@ -127,10 +127,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### 17 Apr. 2023, 2023/4/17:
- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
- `finetune`フォルダ内の各スクリプトに再起的にフォルダを処理するオプション`--recursive`を追加しました。詳細は [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) を参照してください。Linaqruf 氏に感謝します。
### 19 Apr. 2023, 2023/4/19:
- Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
- Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`.
- `lora_interrogator.py`が動作しなくなっていたのを修正しました。詳細は [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) を参照ください。A2va氏およびheyalexchoi氏に感謝します。
- `tag_images_by_wd14_tagger.py`で`_`を含むタグの取り扱いを修正しました。
### Naming of LoRA
@@ -164,6 +165,11 @@ LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additio
LoRA-C3Liarを使いWeb UIで生成するには拡張を使用してください。
### 17 Apr. 2023, 2023/4/17:
- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
- `finetune`フォルダ内の各スクリプトに再起的にフォルダを処理するオプション`--recursive`を追加しました。詳細は [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) を参照してください。Linaqruf 氏に感謝します。
### 14 Apr. 2023, 2023/4/14:
- Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`.
- `train_network.py`で、DyLoRAを`--network_weight`オプションで読み込むとエラーになる不具合を修正しました。

View File

@@ -141,17 +141,19 @@ def main(args):
character_tag_text = ""
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i].replace("_", " ") if args.remove_underscore else general_tags[i]
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = (
character_tags[i - len(general_tags)].replace("_", " ")
if args.remove_underscore
else character_tags[i - len(general_tags)]
)
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name

View File

@@ -2,6 +2,7 @@
from tqdm import tqdm
from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
@@ -16,16 +17,20 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args):
weights_dtype = torch.float16
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
for key in weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
@@ -40,9 +45,9 @@ def interrogate(args):
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
text_encoder.to(DEVICE)
text_encoder.to(DEVICE, dtype=weights_dtype)
text_encoder.eval()
unet.to(DEVICE)
unet.to(DEVICE, dtype=weights_dtype)
unet.eval() # U-Netは呼び出さないので不要だけど
# トークンをひとつひとつ当たっていく
@@ -78,9 +83,14 @@ def interrogate(args):
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}")
network.to(DEVICE, dtype=weights_dtype)
network.eval()
del unet
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
@@ -107,6 +117,7 @@ def interrogate(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,