Compare commits

..

27 Commits

Author SHA1 Message Date
Kohya S
5050971ac6 Merge pull request #388 from kohya-ss/dev
add weighted caption for training
2023-04-08 22:00:46 +09:00
Kohya S
08c54dcf22 update readme 2023-04-08 21:58:22 +09:00
Kohya S
6a5f87d874 disable weighted captions in TI/XTI training 2023-04-08 21:45:57 +09:00
Kohya S
a876f2d3fb format by black 2023-04-08 21:36:35 +09:00
Kohya S
a75f5898e6 Merge pull request #336 from AI-Casanova/weighted_captions
Proof of Concept: Weighted captions
2023-04-08 21:31:05 +09:00
AI-Casanova
dbab72153f Clean up custom_train_functions.py
Removed commented out lines from earlier bugfix.
2023-04-08 00:44:56 -05:00
AI-Casanova
0d54609435 Merge branch 'kohya-ss:main' into weighted_captions 2023-04-07 14:55:40 -05:00
Kohya S
b5c60d7d62 Merge pull request #381 from kohya-ss/dev
feature to upload to huggingface etc.
2023-04-06 08:20:07 +09:00
Kohya S
defefd79c5 Merge branch 'main' into dev 2023-04-06 08:16:31 +09:00
Kohya S
27834df444 update readme 2023-04-06 08:16:02 +09:00
Kohya S
5c020bed49 Add attension couple+reginal LoRA 2023-04-06 08:11:54 +09:00
Kohya S
c775ec1255 Add about using LoRA with Diffusers standard pipe 2023-04-06 08:10:41 +09:00
AI-Casanova
7527436549 Merge branch 'kohya-ss:main' into weighted_captions 2023-04-05 17:07:15 -05:00
Kohya S
541539a144 change method name, repo is private in default etc 2023-04-05 23:16:49 +09:00
Kohya S
74220bb52c Merge pull request #348 from ddPn08/dev
Added a function to upload to Huggingface and resume from Huggingface.
2023-04-05 21:47:36 +09:00
AI-Casanova
1892c82a60 Reinstantiate weighted captions after a necessary revert to Main 2023-04-02 19:43:34 +00:00
ddPn08
3f339cda6f small fix 2023-04-02 23:21:17 +09:00
ddPn08
16ba1cec69 change async uploading to optional 2023-04-02 17:45:26 +09:00
ddPn08
8bfa50e283 small fix 2023-04-02 17:39:23 +09:00
ddPn08
c4a11e5a5a fix help 2023-04-02 17:39:23 +09:00
ddPn08
3cc4939dd3 Implement huggingface upload for all scripts 2023-04-02 17:39:22 +09:00
ddPn08
b5c7937f8d don't run when not needed 2023-04-02 17:39:21 +09:00
ddPn08
b5ff4e816f resume from huggingface repository 2023-04-02 17:39:21 +09:00
ddPn08
a7d302e196 write a random seed to metadata 2023-04-02 17:39:20 +09:00
ddPn08
45381b188c small fix 2023-04-02 17:39:20 +09:00
ddPn08
054fb3308c use access token 2023-04-02 17:39:19 +09:00
ddPn08
d42431d73a Added feature to upload to huggingface 2023-04-02 17:39:10 +09:00
14 changed files with 1013 additions and 208 deletions

120
README.md
View File

@@ -127,91 +127,55 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
- 4 Apr. 2023, 2023/4/4:
- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
- The learning rate and dim (rank) of each block may not work with other modules (LyCORIS, etc.) because the module needs to be changed.
### 8 Apr. 2021, 2021/4/8:
- Fix some bugs and add some features.
- Fix an issue that `.json` format dataset config files cannot be read. [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) Thanks to rockerBOO!
- Raise an error when an invalid `--lr_warmup_steps` option is specified (when warmup is not valid for the specified scheduler). [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) Thanks to shirayu!
- Add `min_snr_gamma` to metadata in `train_network.py`. [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) Thanks to rockerBOO!
- Fix the data type handling in `fine_tune.py`. This may fix an error that occurs in some environments when using xformers, npz format cache, and mixed_precision.
- Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution!
- Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI.
- This option is also applicable to token strings of the DreamBooth method.
- The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible.
- If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses.
- Add options to `train_network.py` to specify block weights for learning rates. [PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) Thanks to u-haru for the great contribution!
- Specify the weights of 25 blocks for the full model.
- No LoRA corresponds to the first block, but 25 blocks are specified for compatibility with 'LoRA block weight' etc. Also, if you do not expand to conv2d3x3, some blocks do not have LoRA, but please specify 25 values for the argument for consistency.
- Specify the following arguments with `--network_args`.
- `down_lr_weight` : Specify the learning rate weight of the down blocks of U-Net. The following can be specified.
- The weight for each block: Specify 12 numbers such as `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"`.
- Specify from preset: Specify such as `"down_lr_weight=sine"` (the weights by sine curve). sine, cosine, linear, reverse_linear, zeros can be specified. Also, if you add `+number` such as `"down_lr_weight=cosine+.25"`, the specified number is added (such as 0.25~1.25).
- `mid_lr_weight` : Specify the learning rate weight of the mid block of U-Net. Specify one number such as `"down_lr_weight=0.5"`.
- `up_lr_weight` : Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.
- If you omit the some arguments, the 1.0 is used. Also, if you set the weight to 0, the LoRA modules of that block are not created.
- `block_lr_zero_threshold` : If the weight is not more than this value, the LoRA module is not created. The default is 0.
- Add options to `train_network.py` to specify block dims (ranks) for variable rank.
- Specify 25 values for the full model of 25 blocks. Some blocks do not have LoRA, but specify 25 values always.
- Specify the following arguments with `--network_args`.
- `block_dims` : Specify the dim (rank) of each block. Specify 25 numbers such as `"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`.
- `block_alphas` : Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.
- `conv_block_dims` : Expand LoRA to Conv2d 3x3 and specify the dim (rank) of each block.
- `conv_block_alphas` : Specify the alpha of each block when expanding LoRA to Conv2d 3x3. If omitted, the value of conv_alpha is used.
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
- 階層別学習率、階層別dim(rank)についてはモジュール側の変更が必要なため、当リポジトリ内のnetworkモジュール以外LyCORISなどでは現在は動作しないと思われます。
- 重みづけキャプションによる学習に対応しました。 AI-Casanova 氏の素晴らしい貢献に感謝します。
- 詳細はこちらをご確認ください。[PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- `--weighted_captions` オプションを指定してください。Textual InversionおよびXTIを除く学習スクリプトで使用可能です。
- キャプションだけでなく DreamBooth 手法の token string でも有効です。
- 重みづけキャプションの記法はWeb UIとほぼ同じで、`(abc)`や`[abc]`、`(abc:1.23)`などが使用できます。入れ子も可能です。
- 括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。
- いくつかのバグ修正、機能追加を行いました。
- `.json`形式のdataset設定ファイルを読み込めない不具合を修正しました。 [issue #351](https://github.com/kohya-ss/sd-scripts/issues/351) rockerBOO 氏に感謝します。
- 無効な`--lr_warmup_steps` オプション指定したスケジューラでwarmupが無効な場合を指定している場合にエラーを出すようにしました。 [PR #364](https://github.com/kohya-ss/sd-scripts/pull/364) shirayu 氏に感謝します。
- `train_network.py` で `min_snr_gamma` をメタデータに追加しました。 [PR #373](https://github.com/kohya-ss/sd-scripts/pull/373) rockerBOO 氏に感謝します。
- `fine_tune.py` でデータ型の取り扱いが誤っていたのを修正しました。一部の環境でxformersを使い、npz形式のキャッシュ、mixed_precisionで学習した時にエラーとなる不具合が解消されるかもしれません。
- 階層別学習率を `train_network.py` で指定できるようになりました。[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) u-haru 氏の多大な貢献に感謝します。
- フルモデルの25個のブロックの重みを指定できます。
- 最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
-`--network_args` で以下の引数を指定してください。
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定しますサインカーブで重みを指定します。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します0.25~1.25になります)。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
### 6 Apr. 2023, 2023/4/6:
- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
- 階層別dim (rank)を `train_network.py` で指定できるようになりました。
- フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
- `--network_args` で以下の引数を指定してください。
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。
- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
- 階層別学習率コマンドライン指定例 / Examples of block learning rate command line specification:
` --network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"`
` --network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"`
- 階層別学習率tomlファイル指定例 / Examples of block learning rate toml file specification
`network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]`
`network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]`
- Added a feature to upload model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348)
- When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time as saving the model.
- Please note that the access token is handled with caution. Please refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens).
- For example, specify other arguments as follows.
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
- If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted or `private` (or anything other than `public`) is specified, it will be private.
- If you specify `--save_state` and `--save_state_to_huggingface`, the state will also be uploaded.
- If you specify `--resume` and `--resume_from_huggingface`, the state will be downloaded from HuggingFace and resumed.
- In this case, the `--resume` option is `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
- If you specify `--async_upload`, the upload will be done asynchronously.
- Added the documentation for applying LoRA to generate with the standard pipeline of Diffusers. [training LoRA](./train_network_README-ja.md#diffusersのpipelineで生成する) (Japanese only)
- Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`.
- If you use ` AND ` to separate the prompts, each sub-prompt is sequentially applied to LoRA. `--mask_path` is treated as a mask image. The number of sub-prompts and the number of LoRA must match.
- 階層別dim (rank)コマンドライン指定例 / Examples of block dim (rank) command line specification:
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"`
` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`
` --network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"`
- 階層別dim (rank)tomlファイル指定例 / Examples of block dim (rank) toml file specification
`network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]`
`network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]`
- モデルおよびstateをHuggingFaceにアップロードする機能を各スクリプトに追加しました。 [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) ddPn08 氏の貢献に感謝します。
- `--huggingface_repo_id`が指定されているとモデル保存時に同時にHuggingFaceにアップロードします。
- アクセストークンの取り扱いに注意してください。[HuggingFaceのドキュメント](https://huggingface.co/docs/hub/security-tokens)を参照してください。
- 他の引数をたとえば以下のように指定してください。
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
- `--huggingface_repo_visibility`に`public`を指定するとリポジトリが公開されます。省略時または`private`(など`public`以外)を指定すると非公開になります。
- `--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。
- `--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します。
- その時の `--resume`オプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
- `--async_upload`オプションを指定するとアップロードを非同期で行います。
- [LoRAの文書](./train_network_README-ja.md#diffusersのpipelineで生成する)に、LoRAを適用してDiffusersの標準的なパイプラインで生成する方法を追記しました。
- `gen_img_diffusers.py` で Attention Couple および領域別LoRAに対応しました。
- プロンプトを` AND `で区切ると各サブプロンプトが順にLoRAに適用されます。`--mask_path` がマスク画像として扱われます。サブプロンプトの数とLoRAの数は一致している必要があります。
## Sample image generation during training

View File

@@ -21,7 +21,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
def train(args):
@@ -231,9 +231,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -284,10 +282,19 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
@@ -427,4 +434,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)

View File

@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import library.train_util as train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
@@ -634,6 +635,7 @@ class PipelineLike:
img2img_noise=None,
clip_prompts=None,
clip_guide_images=None,
networks: Optional[List[LoRANetwork]] = None,
**kwargs,
):
r"""
@@ -717,6 +719,7 @@ class PipelineLike:
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
reginonal_network = " AND " in prompt[0]
vae_batch_size = (
batch_size
@@ -1010,6 +1013,11 @@ class PipelineLike:
# predict the noise residual
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
i,
num_latent_input,
@@ -1019,7 +1027,7 @@ class PipelineLike:
i / len(timesteps),
latent_model_input,
t,
text_embeddings,
text_emb_last,
).sample
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
if isinstance(prompt, str):
prompt = [prompt]
# split the prompts with "AND". each prompt must have the same number of splits
new_prompts = []
for p in prompt:
new_prompts.extend(p.split(" AND "))
prompt = new_prompts
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None:
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
negative_scale: float
strength: float
network_muls: Tuple[float]
num_sub_prompts: int
class BatchData(NamedTuple):
@@ -2276,16 +2291,20 @@ def main(args):
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
if not args.network_merge:
mergiable = hasattr(network, "merge_to")
if args.network_merge and not mergiable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergiable:
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
if args.opt_channels_last:
@@ -2349,12 +2368,12 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Extended Textual Inversion および Textual Inversionを処理する
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds = []
for embeds_file in args.textual_inversion_embeddings:
@@ -2558,16 +2577,22 @@ def main(args):
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))
regional_network = False
if networks and mask_images:
# mask を領域情報として流用する、現在は1枚だけ対応
# TODO 複数のnetwork classの混在時の考慮
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True
print("use mask as region")
# import cv2
# for i in range(3):
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
# cv2.waitKey()
# cv2.destroyAllWindows()
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
size = None
for i, network in enumerate(networks):
if i < 3:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
prev_image = None # for VGG16 guided
@@ -2623,7 +2648,14 @@ def main(args):
height_1st = height_1st - height_1st % 32
ext_1st = BatchDataExt(
width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
width_1st,
height_1st,
args.highres_fix_steps,
ext.scale,
ext.negative_scale,
ext.strength,
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
@@ -2651,7 +2683,7 @@ def main(args):
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image),
(width, height, steps, scale, negative_scale, strength, network_muls),
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
@@ -2743,8 +2775,11 @@ def main(args):
# generate
if networks:
shared = {}
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
images = pipe(
prompts,
@@ -2969,11 +3004,26 @@ def main(args):
print("Use previous image as guide image.")
guide_image = prev_image
if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
num_sub_prompts = None
b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt(
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
width,
height,
steps,
scale,
negative_scale,
strength,
tuple(network_muls) if network_muls else None,
num_sub_prompts,
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
)
# parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
return parser

View File

@@ -1,18 +1,344 @@
import torch
import argparse
import re
from typing import List, Optional, Union
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
parser.add_argument(
"--min_snr_gamma",
type=float,
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
action="store_true",
default=False,
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
def get_unweighted_text_embeddings(
tokenizer,
text_encoder,
text_input: torch.Tensor,
chunk_length: int,
clip_skip: int,
eos: int,
pad: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1]
else: # v2
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
text_input_chunk[j, 1] = eos
if clip_skip is None or clip_skip == 1:
text_embedding = text_encoder(text_input_chunk)[0]
else:
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = text_encoder(text_input)[0]
return text_embeddings
def get_weighted_text_embeddings(
tokenizer,
text_encoder,
prompt: Union[str, List[str]],
device,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
clip_skip=None,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
ending token in each of the chunk in the middle.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = tokenizer.bos_token_id
eos = tokenizer.eos_token_id
pad = tokenizer.pad_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=tokenizer.model_max_length,
)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
tokenizer,
text_encoder,
prompt_tokens,
tokenizer.model_max_length,
clip_skip,
eos,
pad,
no_boseos_middle=no_boseos_middle,
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
# assign weights to the prompts and normalize in the sense of mean
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
return text_embeddings

View File

@@ -0,0 +1,78 @@
from typing import *
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
def exists_repo(
repo_id: str, repo_type: str, revision: str = "main", token: str = None
):
api = HfApi(
token=token,
)
try:
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
return True
except:
return False
def upload(
args: argparse.Namespace,
src: Union[str, Path, bytes, BinaryIO],
dest_suffix: str = "",
force_sync_upload: bool = False,
):
repo_id = args.huggingface_repo_id
repo_type = args.huggingface_repo_type
token = args.huggingface_token
path_in_repo = args.huggingface_path_in_repo + dest_suffix
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
api = HfApi(token=token)
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
is_folder = (type(src) == str and os.path.isdir(src)) or (
isinstance(src, Path) and src.is_dir()
)
def uploader():
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)
else:
uploader()
def list_dir(
repo_id: str,
subfolder: str,
repo_type: str,
revision: str = "main",
token: str = None,
):
api = HfApi(
token=token,
)
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
file_list = [
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
]
return file_list

View File

@@ -2,6 +2,7 @@
import argparse
import ast
import asyncio
import importlib
import json
import pathlib
@@ -49,6 +50,7 @@ from diffusers import (
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np
from PIL import Image
@@ -58,6 +60,7 @@ from torch import einsum
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,7 +490,7 @@ class BaseDataset(torch.utils.data.Dataset):
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1: # 初回に上書きする
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
@@ -950,10 +953,10 @@ class BaseDataset(torch.utils.data.Dataset):
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["captions"] = captions
return example
@@ -1441,7 +1444,6 @@ def glob_images_pathlib(dir_path, recursive):
# endregion
# region モジュール入れ替え部
"""
高速化のためのモジュール入れ替え
@@ -1896,6 +1898,38 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
parser.add_argument(
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
)
parser.add_argument(
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
)
parser.add_argument(
"--huggingface_path_in_repo",
type=str,
default=None,
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
)
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
parser.add_argument(
"--huggingface_repo_visibility",
type=str,
default=None,
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定'public'で公開、'private'またはNoneで非公開",
)
parser.add_argument(
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
)
parser.add_argument(
"--resume_from_huggingface",
action="store_true",
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
)
parser.add_argument(
"--async_upload",
action="store_true",
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
)
parser.add_argument(
"--save_precision",
type=str,
@@ -2261,6 +2295,57 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
# region utils
def resume_from_local_or_hf_if_specified(accelerator, args):
if not args.resume:
return
if not args.resume_from_huggingface:
print(f"resume training from local state: {args.resume}")
accelerator.load_state(args.resume)
return
print(f"resume training from huggingface state: {args.resume}")
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
path_in_repo = "/".join(args.resume.split("/")[2:])
revision = None
repo_type = None
if ":" in path_in_repo:
divided = path_in_repo.split(":")
if len(divided) == 2:
path_in_repo, revision = divided
repo_type = "model"
else:
path_in_repo, revision, repo_type = divided
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
list_files = huggingface_util.list_dir(
repo_id=repo_id,
subfolder=path_in_repo,
revision=revision,
token=args.huggingface_token,
repo_type=repo_type,
)
async def download(filename) -> str:
def task():
return hf_hub_download(
repo_id=repo_id,
filename=filename,
revision=revision,
repo_type=repo_type,
token=args.huggingface_token,
)
return await asyncio.get_event_loop().run_in_executor(None, task)
loop = asyncio.get_event_loop()
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
if len(results) == 0:
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
dirname = os.path.dirname(results[0])
accelerator.load_state(dirname)
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
@@ -2645,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype, device='cpu'):
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
@@ -2772,6 +2857,8 @@ def save_sd_model_on_epoch_end(
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_sd(old_epoch_no):
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
@@ -2791,6 +2878,8 @@ def save_sd_model_on_epoch_end(
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name)
def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2808,7 +2897,11 @@ def save_sd_model_on_epoch_end(
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
if last_n_epochs is not None:
@@ -2819,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading last state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
def save_sd_model_on_train_end(
args: argparse.Namespace,
src_path: str,
@@ -2843,6 +2947,8 @@ def save_sd_model_on_train_end(
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
else:
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
@@ -2851,13 +2957,8 @@ def save_sd_model_on_train_end(
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
# scheduler:
@@ -3089,7 +3190,7 @@ class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
@@ -3102,4 +3203,4 @@ class collater_class:
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]

6
library/utils.py Normal file
View File

@@ -0,0 +1,6 @@
import threading
from typing import *
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()

View File

@@ -10,7 +10,6 @@ import numpy as np
import torch
import re
from library import train_util
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -61,8 +60,6 @@ class LoRAModule(torch.nn.Module):
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.region = None
self.region_mask = None
def apply_to(self):
self.org_forward = self.org_module.forward
@@ -105,39 +102,187 @@ class LoRAModule(torch.nn.Module):
self.region_mask = None
def forward(self, x):
if self.region is None:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
# regional LoRA FIXME same as additional-network extension
if x.size()[1] % 77 == 0:
# print(f"LoRA for context: {self.lora_name}")
self.region = None
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
# calculate region mask first time
if self.region_mask is None:
if len(x.size()) == 4:
h, w = x.size()[2:4]
else:
seq_len = x.size()[1]
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
h = int(self.region.size()[0] / ratio + 0.5)
w = seq_len // h
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
r = self.region.to(x.device)
if r.dtype == torch.bfloat16:
r = r.to(torch.float)
r = r.unsqueeze(0).unsqueeze(1)
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
r = r.to(x.dtype)
# check regional or not by lora_name
self.text_encoder = False
if lora_name.startswith("lora_te_"):
self.regional = False
self.use_sub_prompt = True
self.text_encoder = True
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
self.regional = False
self.use_sub_prompt = True
elif "time_emb" in lora_name:
self.regional = False
self.use_sub_prompt = False
else:
self.regional = True
self.use_sub_prompt = False
if len(x.size()) == 3:
r = torch.reshape(r, (1, x.size()[1], -1))
self.network: LoRANetwork = None
self.region_mask = r
def set_network(self, network):
self.network = network
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
if self.network is None or self.network.sub_prompt_index is None:
return self.default_forward(x)
if not self.regional and not self.use_sub_prompt:
return self.default_forward(x)
if self.regional:
return self.regional_forward(x)
else:
return self.sub_prompt_forward(x)
def get_mask_for_x(self, x):
# calculate size from shape of x
if len(x.size()) == 4:
h, w = x.size()[2:4]
area = h * w
else:
area = x.size()[1]
mask = self.network.mask_dic[area]
if mask is None:
raise ValueError(f"mask is None for resolution {area}")
if len(x.size()) != 4:
mask = torch.reshape(mask, (1, -1, 1))
return mask
def regional_forward(self, x):
if "attn2_to_out" in self.lora_name:
return self.to_out_forward(x)
if self.network.mask_dic is None: # sub_prompt_index >= 3
return self.default_forward(x)
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
lx = lx * mask
x = self.org_forward(x)
x = x + lx
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
x = self.postp_to_q(x)
return x
def postp_to_q(self, x):
# repeat x to num_sub_prompts
has_real_uncond = x.size()[0] // self.network.batch_size == 3
qc = self.network.batch_size # uncond
qc += self.network.batch_size * self.network.num_sub_prompts # cond
if has_real_uncond:
qc += self.network.batch_size # real_uncond
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
query[: self.network.batch_size] = x[: self.network.batch_size]
for i in range(self.network.batch_size):
qi = self.network.batch_size + i * self.network.num_sub_prompts
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
return query
def sub_prompt_forward(self, x):
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
return self.org_forward(x)
emb_idx = self.network.sub_prompt_index
if not self.text_encoder:
emb_idx += self.network.batch_size
# apply sub prompt of X
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
self.network.shared[self.lora_name] = (None, masks)
else:
lx, masks = self.network.shared[self.lora_name]
# call own LoRA
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
if self.network.is_last_network:
lx = torch.zeros(
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
# if not last network, return x and masks
x = self.org_forward(x)
if not self.network.is_last_network:
return x
lx, masks = self.network.shared.pop(self.lora_name)
# if last network, combine separated x with mask weighted sum
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# for i in range(len(masks)):
# if masks[i] is None:
# masks[i] = torch.zeros_like(masks[-1])
mask = torch.cat(masks)
mask_sum = torch.sum(mask, dim=0) + 1e-4
for i in range(self.network.batch_size):
# 1枚の画像ごとに処理する
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
lx1 = lx1 * mask
lx1 = torch.sum(lx1, dim=0)
xi = self.network.batch_size + i * self.network.num_sub_prompts
x1 = x[xi : xi + self.network.num_sub_prompts]
x1 = x1 * mask
x1 = torch.sum(x1, dim=0)
x1 = x1 / mask_sum
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
return out
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
@@ -421,7 +566,7 @@ def get_block_index(lora_name: str) -> int:
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
@@ -450,7 +595,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
if key not in modules_alpha:
modules_alpha = modules_dim[key]
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
)
return network, weights_sd
@@ -479,6 +628,7 @@ class LoRANetwork(torch.nn.Module):
conv_block_alphas=None,
modules_dim=None,
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
) -> None:
"""
@@ -554,7 +704,7 @@ class LoRANetwork(torch.nn.Module):
skipped.append(lora_name)
continue
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
loras.append(lora)
return loras, skipped
@@ -570,7 +720,7 @@ class LoRANetwork(torch.nn.Module):
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
if varbose and len(skipped) > 0:
print(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
@@ -600,7 +750,7 @@ class LoRANetwork(torch.nn.Module):
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
@@ -750,6 +900,7 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
@@ -762,17 +913,45 @@ class LoRANetwork(torch.nn.Module):
else:
torch.save(state_dict, file)
@staticmethod
def set_regions(networks, image):
image = image.astype(np.float32) / 255.0
for i, network in enumerate(networks[:3]):
# NOTE: consider averaging overwrapping area
region = image[:, :, i]
if region.max() == 0:
continue
region = torch.tensor(region)
network.set_region(region)
# mask is a tensor with values from 0 to 1
def set_region(self, sub_prompt_index, is_last_network, mask):
if mask.max() == 0:
mask = torch.ones_like(mask)
def set_region(self, region):
for lora in self.unet_loras:
lora.set_region(region)
self.mask = mask
self.sub_prompt_index = sub_prompt_index
self.is_last_network = is_last_network
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
self.shared = shared
# create masks
mask = self.mask
mask_dic = {}
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
dtype = ref_weight.dtype
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m
h = height // 8
w = width // 8
for _ in range(4):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)
h = (h + 1) // 2
w = (w + 1) // 2
self.mask_dic = mask_dic

View File

@@ -21,6 +21,6 @@ fairscale==0.4.13
# for WD14 captioning
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
huggingface-hub==0.13.3
# for kohya_ss library
.

View File

@@ -23,8 +23,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
def train(args):
train_util.verify_training_args(args)
@@ -202,9 +201,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -273,10 +270,19 @@ def train(args):
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -426,4 +432,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)

View File

@@ -24,8 +24,9 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
# TODO 他のスクリプトと共通化する
@@ -71,8 +72,9 @@ def train(args):
use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None
if args.seed is not None:
set_seed(args.seed)
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
@@ -308,9 +310,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -562,9 +562,17 @@ def train(args):
with torch.set_grad_enabled(train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
@@ -650,6 +658,8 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -689,6 +699,8 @@ def train(args):
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")
@@ -745,4 +757,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)

View File

@@ -188,6 +188,73 @@ gen_img_diffusers.pyに、--network_module、--network_weightsの各オプショ
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
## Diffusersのpipelineで生成する
以下の例を参考にしてください。必要なファイルはnetworks/lora.pyのみです。Diffusersのバージョンは0.10.2以外では動作しない可能性があります。
```python
import torch
from diffusers import StableDiffusionPipeline
from networks.lora import LoRAModule, create_network_from_weights
from safetensors.torch import load_file
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
device = "cuda"
# create pipe
print(f"creating pipe from {model_id_or_dir}...")
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet
# load lora networks
print(f"loading lora networks...")
lora_path1 = r"lora1.safetensors"
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
network1.apply_to(text_encoder, unet)
network1.load_state_dict(sd)
network1.to(device, dtype=torch.float16)
# # You can merge weights instead of apply_to+load_state_dict. network.set_multiplier does not work
# network.merge_to(text_encoder, unet, sd)
lora_path2 = r"lora2.safetensors"
sd = load_file(lora_path2)
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
network2.apply_to(text_encoder, unet)
network2.load_state_dict(sd)
network2.to(device, dtype=torch.float16)
lora_path3 = r"lora3.safetensors"
sd = load_file(lora_path3)
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
network3.apply_to(text_encoder, unet)
network3.load_state_dict(sd)
network3.to(device, dtype=torch.float16)
# prompts
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
# exec pipe
print("generating image...")
with torch.autocast("cuda"):
image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
# if not merged, you can use set_multiplier
# network1.set_multiplier(0.8)
# and generate image again...
# save image
image.save(r"by_diffusers..png")
```
## 二つのモデルの差分からLoRAモデルを作成する
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。

View File

@@ -13,6 +13,7 @@ import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
@@ -304,9 +305,7 @@ def train(args):
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -452,6 +451,8 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -492,6 +493,8 @@ def train(args):
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")
@@ -546,7 +549,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
parser.add_argument(
"--save_model_as",

View File

@@ -13,6 +13,7 @@ import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
@@ -340,9 +341,7 @@ def train(args):
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -493,6 +492,8 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -534,6 +535,8 @@ def train(args):
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")
@@ -600,7 +603,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
parser.add_argument(
"--save_model_as",