mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge branch 'original-u-net' into dev
This commit is contained in:
118
README.md
118
README.md
@@ -75,8 +75,6 @@ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_set
|
||||
accelerate config
|
||||
```
|
||||
|
||||
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
|
||||
|
||||
Answers to accelerate config:
|
||||
|
||||
```txt
|
||||
@@ -94,6 +92,30 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
### Experimental: Use PyTorch 2.0
|
||||
|
||||
In this case, you need to install PyTorch 2.0 and xformers 0.0.20. Instead of the above, please type the following:
|
||||
|
||||
```powershell
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git
|
||||
cd sd-scripts
|
||||
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install xformers==0.0.20
|
||||
|
||||
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
||||
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Answers to accelerate config should be the same as above.
|
||||
|
||||
### about PyTorch and xformers
|
||||
|
||||
Other versions of PyTorch and xformers seem to have problems with training.
|
||||
@@ -140,6 +162,98 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
### 15 Jun. 2023, 2023/06/15
|
||||
|
||||
- Prodigy optimizer is supported in each training script. It is a member of D-Adaptation and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds!
|
||||
- Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`.
|
||||
- Arbitrary Dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions.
|
||||
- Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script.
|
||||
- Please refer to `MinimalDataset` for implementation. I will prepare a sample later.
|
||||
- The following features have been added to the generation script.
|
||||
- Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. Fix. Please try it if the image is disturbed by some ControlNet such as Canny.
|
||||
- Added Variants similar to sd-dynamic-propmpts in the prompt.
|
||||
- If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected.
|
||||
- If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected.
|
||||
- If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `.
|
||||
- You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`.
|
||||
- It can also be specified for the prompt option.
|
||||
- If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots.
|
||||
- You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3.
|
||||
- There is no weighting function.
|
||||
|
||||
- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。
|
||||
- `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します。
|
||||
- 各学習スクリプトで任意のDatasetをサポートしました(XTIを除く)。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます。
|
||||
- Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。
|
||||
- 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します。
|
||||
- 生成スクリプトに以下の機能追加を行いました。
|
||||
- Highres. Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。
|
||||
- プロンプトでsd-dynamic-propmptsに似たVariantをサポートしました。
|
||||
- `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。
|
||||
- `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます。
|
||||
- `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。
|
||||
- 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません。
|
||||
- プロンプトオプションに対しても指定可能です。
|
||||
- `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます)。X/Y plotの作成に便利かもしれません。
|
||||
- `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます。
|
||||
- Weightingの機能はありません。
|
||||
|
||||
### 8 Jun. 2023, 2023/06/08
|
||||
|
||||
- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.
|
||||
- 重みづけキャプションでの学習時(`--weighted_captions`指定時)および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。
|
||||
|
||||
### 6 Jun. 2023, 2023/06/06
|
||||
|
||||
- Fix `train_network.py` to probably work with older versions of LyCORIS.
|
||||
- `gen_img_diffusers.py` now supports `BREAK` syntax.
|
||||
- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。
|
||||
- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。
|
||||
|
||||
### 3 Jun. 2023, 2023/06/03
|
||||
|
||||
- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
|
||||
- Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details.
|
||||
- Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`.
|
||||
- The networks other than LoRA in this repository (such as LyCORIS) do not support this option.
|
||||
|
||||
- Three types of dropout have been added to `train_network.py` and LoRA network.
|
||||
- Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0.
|
||||
- `--network_dropout` is a normal dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Proposed in [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
|
||||
- `--network_dropout=0.1` specifies the dropout probability to `0.1`.
|
||||
- Note that the specification method is different from LyCORIS.
|
||||
- For LoRA network, `--network_args` can specify `rank_dropout` to dropout each rank with specified probability. Also `module_dropout` can be specified to dropout each module with specified probability.
|
||||
- Specify as `--network_args "rank_dropout=0.2" "module_dropout=0.1"`.
|
||||
- `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time.
|
||||
- Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified.
|
||||
- `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet.
|
||||
- The networks other than LoRA in this repository (such as LyCORIS) do not support these options.
|
||||
|
||||
- Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script.
|
||||
- By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.
|
||||
- See [this article](https://xrg.hatenablog.com/entry/2023/06/02/202418) by xrg for details (written in Japanese). Thanks to xrg for the great suggestion!
|
||||
|
||||
- Max Norm Regularizationが`train_network.py`で使えるようになりました。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) AI-Casanova氏に感謝します。
|
||||
- Max Norm Regularizationは、ネットワークの重みのノルムを制限することで、ネットワークの学習を安定させる手法です。LoRAの過学習の抑制、他のLoRAと併用した時の安定性の向上が期待できるかもしれません。詳細はPRを参照してください。
|
||||
- `--scale_weight_norms=1.0`のように `--scale_weight_norms` で指定してください。`1.0`から試すと良いようです。
|
||||
- LyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。
|
||||
|
||||
- `train_network.py` およびLoRAに計三種類のdropoutを追加しました。
|
||||
- dropoutはネットワークの一部の出力をランダムに0にすることで、過学習の抑制、ネットワークの性能向上等を図る手法です。
|
||||
- `--network_dropout` はニューロン単位の通常のdropoutです。LoRAの場合、downの出力に対して適用されます。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) で提案されました。AI-Casanova氏に感謝します。
|
||||
- `--network_dropout=0.1` などとすることで、dropoutの確率を指定できます。
|
||||
- LyCORISとは指定方法が異なりますのでご注意ください。
|
||||
- LoRAの場合、`--network_args`に`rank_dropout`を指定することで各rankを指定確率でdropoutします。また同じくLoRAの場合、`--network_args`に`module_dropout`を指定することで各モジュールを指定確率でdropoutします。
|
||||
- `--network_args "rank_dropout=0.2" "module_dropout=0.1"` のように指定します。
|
||||
- `--network_dropout`、`rank_dropout` 、 `module_dropout` は同時に指定できます。
|
||||
- それぞれの値は0.1~0.3程度から試してみると良いかもしれません。0.5を超える値は指定しない方が良いでしょう。
|
||||
- `rank_dropout`および`module_dropout`は当リポジトリ独自の手法です。有効性の検証はまだ行っていません。
|
||||
- これらのdropoutはLyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。
|
||||
|
||||
- 各学習スクリプトにv-prediction lossをnoise predictionと同様の値にスケールするオプション`--scale_v_pred_loss_like_noise_pred`を追加しました。
|
||||
- タイムステップに応じてlossをスケールすることで、 大域的なノイズの予測と局所的なノイズの予測の重みが同じになり、ディテールの改善が期待できるかもしれません。
|
||||
- 詳細はxrg氏のこちらの記事をご参照ください:[noise_predictionモデルとv_predictionモデルの損失 - 勾配降下党青年局](https://xrg.hatenablog.com/entry/2023/06/02/202418) xrg氏の素晴らしい記事に感謝します。
|
||||
|
||||
### 31 May 2023, 2023/05/31
|
||||
|
||||
- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin!
|
||||
|
||||
210
XTI_hijack.py
210
XTI_hijack.py
@@ -2,132 +2,123 @@ import torch
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
def unet_forward_XTI(self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
from library.original_unet import SampleOutput
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
def unet_forward_XTI(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Dict, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a dict instead of a plain tuple.
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
Returns:
|
||||
`SampleOutput` or `tuple`:
|
||||
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
||||
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
||||
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
# 64で割り切れないときはupsamplerにサイズを伝える
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
# logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
||||
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
||||
# time_projでキャストしておけばいいんじゃね?
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
down_i = 0
|
||||
for downsample_block in self.down_blocks:
|
||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||
# まあこちらのほうがわかりやすいかもしれない
|
||||
if downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
|
||||
)
|
||||
down_i += 2
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
down_i = 0
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
|
||||
)
|
||||
down_i += 2
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
|
||||
# 5. up
|
||||
up_i = 7
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
# 5. up
|
||||
up_i = 7
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
# if we have not reached the final block and need to forward the upsample size, we do it here
|
||||
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
if upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
up_i += 3
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||
)
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
up_i += 3
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||
)
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return SampleOutput(sample=sample)
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
def downblock_forward_XTI(
|
||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||
@@ -166,6 +157,7 @@ def downblock_forward_XTI(
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def upblock_forward_XTI(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -199,11 +191,11 @@ def upblock_forward_XTI(
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
||||
|
||||
|
||||
i += 1
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
return hidden_states
|
||||
|
||||
@@ -622,6 +622,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
- DAdaptAdanIP : 引数は同上
|
||||
- DAdaptLion : 引数は同上
|
||||
- DAdaptSGD : 引数は同上
|
||||
- Prodigy : https://github.com/konstmish/prodigy
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任意のオプティマイザ
|
||||
|
||||
|
||||
@@ -555,9 +555,10 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
- DAdaptAdam : 参数同上
|
||||
- DAdaptAdaGrad : 参数同上
|
||||
- DAdaptAdan : 参数同上
|
||||
- DAdaptAdanIP : 引数は同上
|
||||
- DAdaptAdanIP : 参数同上
|
||||
- DAdaptLion : 参数同上
|
||||
- DAdaptSGD : 参数同上
|
||||
- Prodigy : https://github.com/konstmish/prodigy
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任何优化器
|
||||
|
||||
|
||||
93
fine_tune.py
93
fine_tune.py
@@ -19,7 +19,14 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -33,33 +40,37 @@ def train(args):
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -82,7 +93,7 @@ def train(args):
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
@@ -132,7 +143,7 @@ def train(args):
|
||||
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
||||
accelerator.print("Disable Diffusers' xformers")
|
||||
set_diffusers_xformers_flag(unet, False)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -259,6 +270,7 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
@@ -325,11 +337,16 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma:
|
||||
# do not mean over batch dimension for snr weight
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
@@ -370,15 +387,15 @@ def train(args):
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
@@ -413,8 +430,8 @@ def train(args):
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
@@ -422,8 +439,8 @@ def train(args):
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import glob
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
@@ -11,6 +12,7 @@ import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from blip.blip import blip_decoder
|
||||
import library.train_util as train_util
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,20 +5,37 @@ import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
sigma = sqrt_one_minus_alphas_cumprod
|
||||
all_snr = (alpha / sigma) ** 2
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
@@ -29,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_v_pred_loss_like_noise_pred",
|
||||
action="store_true",
|
||||
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
@@ -243,11 +265,6 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
@@ -262,7 +279,12 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
@@ -434,46 +456,3 @@ def perlin_noise(noise, device, octaves):
|
||||
noise += noise_perlin # broadcast for each batch
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
"""
|
||||
|
||||
|
||||
def max_norm(state_dict, max_norm_value, device):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
norms = []
|
||||
keys_scaled = 0
|
||||
|
||||
for key in state_dict.keys():
|
||||
if "lora_down" in key and "weight" in key:
|
||||
downkeys.append(key)
|
||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
dim = down.shape[0]
|
||||
scale = alpha / dim
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown *= scale
|
||||
|
||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||
desired = torch.clamp(norm, max=max_norm_value)
|
||||
ratio = desired.cpu() / norm.cpu()
|
||||
sqrt_ratio = ratio**0.5
|
||||
if ratio != 1:
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
@@ -517,6 +517,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
# clip_skip: int,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
import diffusers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
NUM_TRAIN_TIMESTEPS = 1000
|
||||
@@ -126,17 +128,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
if diffusers.__version__ < "0.17.0":
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
else:
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
@@ -191,8 +206,16 @@ def assign_to_checkpoint(
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
reshaping = False
|
||||
if diffusers.__version__ < "0.17.0":
|
||||
if "proj_attn.weight" in new_path:
|
||||
reshaping = True
|
||||
else:
|
||||
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
||||
reshaping = True
|
||||
|
||||
if reshaping:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
@@ -361,7 +384,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
|
||||
# SDのv2では1*1のconv2dがlinearに変わっている
|
||||
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
||||
if v2 and not config.get('use_linear_projection', False):
|
||||
if v2 and not config.get("use_linear_projection", False):
|
||||
linear_transformer_to_conv(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
@@ -877,14 +900,24 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "query."),
|
||||
("k.", "key."),
|
||||
("v.", "value."),
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
if diffusers.__version__ < "0.17.0":
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "query."),
|
||||
("k.", "key."),
|
||||
("v.", "value."),
|
||||
("proj_out.", "proj_attn."),
|
||||
]
|
||||
else:
|
||||
vae_conversion_map_attn = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("norm.", "group_norm."),
|
||||
("q.", "to_q."),
|
||||
("k.", "to_k."),
|
||||
("v.", "to_v."),
|
||||
("proj_out.", "to_out.0."),
|
||||
]
|
||||
|
||||
mapping = {k: k for k in vae_state_dict.keys()}
|
||||
for k, v in mapping.items():
|
||||
@@ -901,7 +934,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
# print(f"Reshaping {k} for SD format")
|
||||
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
|
||||
return new_state_dict
|
||||
@@ -998,10 +1031,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
else:
|
||||
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
||||
|
||||
logging.set_verbosity_error() # don't show annoying warning
|
||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
# logging.set_verbosity_error() # don't show annoying warning
|
||||
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||
# logging.set_verbosity_warning()
|
||||
# print(f"config: {text_model.config}")
|
||||
cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=768,
|
||||
torch_dtype="float32",
|
||||
)
|
||||
text_model = CLIPTextModel._from_config(cfg)
|
||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||
print("loading text encoder:", info)
|
||||
|
||||
|
||||
1593
library/original_unet.py
Normal file
1593
library/original_unet.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -36,7 +36,6 @@ from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
import transformers
|
||||
import diffusers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
@@ -52,6 +51,7 @@ from diffusers import (
|
||||
KDPM2DiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
)
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
@@ -65,6 +65,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
from library.attention_processors import FlashAttnProcessor
|
||||
from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
@@ -1828,6 +1829,76 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
return image_paths
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
self.datasets = [self]
|
||||
self.batch_size = 1 # update in subclass
|
||||
|
||||
self.subsets = [self]
|
||||
self.num_repeats = 1 # update in subclass if needed
|
||||
self.img_count = 1 # update in subclass if needed
|
||||
self.bucket_info = {}
|
||||
self.is_reg = False
|
||||
self.image_dir = "dummy" # for metadata
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# override to avoid shuffling buckets
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""
|
||||
The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
|
||||
|
||||
Returns: example like this:
|
||||
|
||||
for i in range(batch_size):
|
||||
image_key = ... # whatever hashable
|
||||
image_keys.append(image_key)
|
||||
|
||||
image = ... # PIL Image
|
||||
img_tensor = self.image_transforms(img)
|
||||
images.append(img_tensor)
|
||||
|
||||
caption = ... # str
|
||||
input_ids = self.get_input_ids(caption)
|
||||
input_ids_list.append(input_ids)
|
||||
|
||||
captions.append(caption)
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
input_ids_list = torch.stack(input_ids_list, dim=0)
|
||||
example = {
|
||||
"images": images,
|
||||
"input_ids": input_ids_list,
|
||||
"captions": captions, # for debug_dataset
|
||||
"latents": None,
|
||||
"image_keys": image_keys, # for debug_dataset
|
||||
"loss_weights": torch.ones(batch_size, dtype=torch.float32),
|
||||
}
|
||||
return example
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||
dataset_class = args.dataset_class.split(".")[-1]
|
||||
module = importlib.import_module(module)
|
||||
dataset_class = getattr(module, dataset_class)
|
||||
train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
|
||||
return train_dataset_group
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region モジュール入れ替え部
|
||||
@@ -1941,59 +2012,73 @@ def get_git_revision_hash() -> str:
|
||||
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
replace_attentions_for_hypernetwork()
|
||||
# unet is not used currently, but it is here for future use
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
# def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
# replace_attentions_for_hypernetwork()
|
||||
# # unet is not used currently, but it is here for future use
|
||||
# unet.enable_xformers_memory_efficient_attention()
|
||||
# return
|
||||
# if mem_eff_attn:
|
||||
# unet.set_attn_processor(FlashAttnProcessor())
|
||||
# elif xformers:
|
||||
# unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
|
||||
# def replace_unet_cross_attn_to_xformers():
|
||||
# print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
# try:
|
||||
# import xformers.ops
|
||||
# except ImportError:
|
||||
# raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
# def forward_xformers(self, x, context=None, mask=None):
|
||||
# h = self.heads
|
||||
# q_in = self.to_q(x)
|
||||
|
||||
# context = default(context, x)
|
||||
# context = context.to(x.dtype)
|
||||
|
||||
# if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
# context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
# context_k = context_k.to(x.dtype)
|
||||
# context_v = context_v.to(x.dtype)
|
||||
# else:
|
||||
# context_k = context
|
||||
# context_v = context
|
||||
|
||||
# k_in = self.to_k(context_k)
|
||||
# v_in = self.to_v(context_v)
|
||||
|
||||
# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
# del q_in, k_in, v_in
|
||||
|
||||
# q = q.contiguous()
|
||||
# k = k.contiguous()
|
||||
# v = v.contiguous()
|
||||
# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
# out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
|
||||
# # diffusers 0.7.0~
|
||||
# out = self.to_out[0](out)
|
||||
# out = self.to_out[1](out)
|
||||
# return out
|
||||
|
||||
# diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
def replace_unet_modules(unet:UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||
if mem_eff_attn:
|
||||
unet.set_attn_processor(FlashAttnProcessor())
|
||||
print("Enable memory efficient attention for U-Net")
|
||||
unet.set_use_memory_efficient_attention(False, True)
|
||||
elif xformers:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
def forward_xformers(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
|
||||
context = default(context, x)
|
||||
context = context.to(x.dtype)
|
||||
|
||||
if hasattr(self, "hypernetwork") and self.hypernetwork is not None:
|
||||
context_k, context_v = self.hypernetwork.forward(x, context)
|
||||
context_k = context_k.to(x.dtype)
|
||||
context_v = context_v.to(x.dtype)
|
||||
else:
|
||||
context_k = context
|
||||
context_v = context
|
||||
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
||||
|
||||
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
||||
|
||||
# diffusers 0.7.0~
|
||||
out = self.to_out[0](out)
|
||||
out = self.to_out[1](out)
|
||||
return out
|
||||
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
print("Enable xformers for U-Net")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
raise ImportError("No xformers / xformersがインストールされていないようです")
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
elif sdpa:
|
||||
print("Enable SDPA for U-Net")
|
||||
unet.set_use_sdpa(True)
|
||||
|
||||
"""
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
@@ -2242,6 +2327,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument("--sdpa", action="store_true", help="use sdpa for CrossAttention (requires PyTorch 2.0) / CrossAttentionにsdpaを使う(PyTorch 2.0が必要)")
|
||||
parser.add_argument(
|
||||
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
|
||||
)
|
||||
@@ -2428,6 +2514,11 @@ def verify_training_args(args: argparse.Namespace):
|
||||
if args.adaptive_noise_scale is not None and args.noise_offset is None:
|
||||
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
|
||||
|
||||
if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
|
||||
raise ValueError(
|
||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||
)
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
@@ -2506,7 +2597,6 @@ def add_dataset_arguments(
|
||||
default=1,
|
||||
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token_warmup_step",
|
||||
type=float,
|
||||
@@ -2514,6 +2604,13 @@ def add_dataset_arguments(
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset_class",
|
||||
type=str,
|
||||
default=None,
|
||||
help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
|
||||
)
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
@@ -2788,15 +2885,7 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.SGD
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.startswith("DAdapt".lower()):
|
||||
# DAdaptation family
|
||||
# check dadaptation is installed
|
||||
try:
|
||||
import dadaptation
|
||||
import dadaptation.experimental as experimental
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
|
||||
elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
|
||||
# check lr and lr_count, and print warning
|
||||
actual_lr = lr
|
||||
lr_count = 1
|
||||
@@ -2809,40 +2898,60 @@ def get_optimizer(args, trainable_params):
|
||||
|
||||
if actual_lr <= 0.1:
|
||||
print(
|
||||
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
|
||||
f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
|
||||
)
|
||||
print("recommend option: lr=1.0 / 推奨は1.0です")
|
||||
if lr_count > 1:
|
||||
print(
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
|
||||
)
|
||||
|
||||
# set optimizer
|
||||
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
|
||||
optimizer_class = experimental.DAdaptAdamPreprint
|
||||
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdaGrad".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdaGrad
|
||||
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdam".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdan".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdan
|
||||
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdanIP".lower():
|
||||
optimizer_class = experimental.DAdaptAdanIP
|
||||
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptLion".lower():
|
||||
optimizer_class = dadaptation.DAdaptLion
|
||||
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptSGD".lower():
|
||||
optimizer_class = dadaptation.DAdaptSGD
|
||||
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
if optimizer_type.startswith("DAdapt".lower()):
|
||||
# DAdaptation family
|
||||
# check dadaptation is installed
|
||||
try:
|
||||
import dadaptation
|
||||
import dadaptation.experimental as experimental
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
# set optimizer
|
||||
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
|
||||
optimizer_class = experimental.DAdaptAdamPreprint
|
||||
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdaGrad".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdaGrad
|
||||
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdam".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdan".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdan
|
||||
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdanIP".lower():
|
||||
optimizer_class = experimental.DAdaptAdanIP
|
||||
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptLion".lower():
|
||||
optimizer_class = dadaptation.DAdaptLion
|
||||
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptSGD".lower():
|
||||
optimizer_class = dadaptation.DAdaptSGD
|
||||
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
else:
|
||||
# Prodigy
|
||||
# check Prodigy is installed
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
|
||||
|
||||
print(f"use Prodigy optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "Adafactor".lower():
|
||||
# 引数を確認して適宜補正する
|
||||
@@ -3093,23 +3202,9 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=log_with,
|
||||
logging_dir=logging_dir,
|
||||
project_dir=logging_dir,
|
||||
)
|
||||
|
||||
# accelerateの互換性問題を解決する
|
||||
accelerator_0_15 = True
|
||||
try:
|
||||
accelerator.unwrap_model("dummy", True)
|
||||
print("Using accelerator 0.15.0 or above.")
|
||||
except TypeError:
|
||||
accelerator_0_15 = False
|
||||
|
||||
def unwrap_model(model):
|
||||
if accelerator_0_15:
|
||||
return accelerator.unwrap_model(model, True)
|
||||
return accelerator.unwrap_model(model)
|
||||
|
||||
return accelerator, unwrap_model
|
||||
return accelerator
|
||||
|
||||
|
||||
def prepare_dtype(args: argparse.Namespace):
|
||||
@@ -3146,11 +3241,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
print(
|
||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
||||
)
|
||||
raise ex
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
# Diffusers U-Net to original U-Net
|
||||
# TODO *.ckpt/*.safetensorsのv2と同じ形式にここで変換すると良さそう
|
||||
# print(f"unet config: {unet.config}")
|
||||
original_unet = UNet2DConditionModel(
|
||||
unet.config.sample_size,
|
||||
unet.config.attention_head_dim,
|
||||
unet.config.cross_attention_dim,
|
||||
unet.config.use_linear_projection,
|
||||
unet.config.upcast_attention,
|
||||
)
|
||||
original_unet.load_state_dict(unet.state_dict())
|
||||
unet = original_unet
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||
@@ -3580,6 +3690,7 @@ def sample_images(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.clip_skip = args.clip_skip # Pipelineのコンストラクタにckip_skipを追加できないので後から設定する
|
||||
pipeline.to(device)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
@@ -3769,4 +3880,4 @@ class collater_class:
|
||||
# set epoch and step
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
return examples[0]
|
||||
|
||||
141
networks/lora.py
141
networks/lora.py
@@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module):
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
):
|
||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
@@ -61,6 +71,8 @@ class LoRAModule(torch.nn.Module):
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
@@ -68,18 +80,51 @@ class LoRAModule(torch.nn.Module):
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
if self.dropout:
|
||||
return (
|
||||
self.org_forward(x)
|
||||
+ self.lora_up(torch.nn.functional.dropout(self.lora_down(x), p=self.dropout)) * self.multiplier * self.scale
|
||||
)
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return org_forwarded
|
||||
|
||||
lx = self.lora_down(x)
|
||||
|
||||
# normal dropout
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
elif len(lx.size()) == 4:
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||
else:
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, dropout)
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||
|
||||
self.org_module_ref = [org_module] # 後から参照できるように
|
||||
self.enabled = True
|
||||
@@ -355,7 +400,7 @@ def parse_block_lr_kwargs(nw_kwargs):
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
@@ -395,6 +440,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
conv_block_dims = None
|
||||
conv_block_alphas = None
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
rank_dropout = float(rank_dropout)
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoder,
|
||||
@@ -402,7 +455,9 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
dropout=dropout,
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
block_dims=block_dims,
|
||||
@@ -679,6 +734,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
conv_lora_dim=None,
|
||||
conv_alpha=None,
|
||||
block_dims=None,
|
||||
@@ -706,18 +763,22 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.conv_lora_dim = conv_lora_dim
|
||||
self.conv_alpha = conv_alpha
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
print(f"create LoRA network from block_dims, neuron dropout: p={self.dropout}")
|
||||
print(f"create LoRA network from block_dims")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
print(f"block_dims: {block_dims}")
|
||||
print(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
print(f"conv_block_dims: {conv_block_dims}")
|
||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, neuron dropout: p={self.dropout}")
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
if self.conv_lora_dim is not None:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
@@ -764,7 +825,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
)
|
||||
loras.append(lora)
|
||||
return loras, skipped
|
||||
|
||||
@@ -1056,3 +1126,46 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
|
||||
def apply_max_norm_regularization(self, max_norm_value, device):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
norms = []
|
||||
keys_scaled = 0
|
||||
|
||||
state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if "lora_down" in key and "weight" in key:
|
||||
downkeys.append(key)
|
||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
dim = down.shape[0]
|
||||
scale = alpha / dim
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown *= scale
|
||||
|
||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||
desired = torch.clamp(norm, max=max_norm_value)
|
||||
ratio = desired.cpu() / norm.cpu()
|
||||
sqrt_ratio = ratio**0.5
|
||||
if ratio != 1:
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
@@ -5,6 +5,7 @@ ftfy==6.1.1
|
||||
albumentations==1.3.0
|
||||
opencv-python==4.7.0.68
|
||||
einops==0.6.0
|
||||
diffusers[torch]==0.17.0
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.35.0
|
||||
tensorboard==2.10.1
|
||||
@@ -14,13 +15,12 @@ altair==4.2.2
|
||||
easygui==0.98.3
|
||||
toml==0.10.2
|
||||
voluptuous==0.13.1
|
||||
huggingface-hub==0.13.3
|
||||
# for BLIP captioning
|
||||
requests==2.28.2
|
||||
timm==0.6.12
|
||||
fairscale==0.4.13
|
||||
# for WD14 captioning
|
||||
# tensorflow<2.11
|
||||
tensorflow==2.10.1
|
||||
huggingface-hub==0.13.3
|
||||
# tensorflow==2.10.1
|
||||
# for kohya_ss library
|
||||
.
|
||||
|
||||
63
train_db.py
63
train_db.py
@@ -23,8 +23,10 @@ import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
# perlin_noise,
|
||||
@@ -41,26 +43,30 @@ def train(args):
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -90,7 +96,7 @@ def train(args):
|
||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
||||
)
|
||||
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
@@ -114,7 +120,7 @@ def train(args):
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -237,6 +243,7 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
@@ -324,6 +331,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -364,15 +373,15 @@ def train(args):
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
@@ -412,8 +421,8 @@ def train(args):
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
accelerator.unwrap_model(text_encoder),
|
||||
accelerator.unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
@@ -421,8 +430,8 @@ def train(args):
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
111
train_network.py
111
train_network.py
@@ -27,9 +27,10 @@ import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
max_norm,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
|
||||
@@ -55,7 +56,7 @@ def generate_step_logs(
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value of unet.
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
@@ -65,7 +66,7 @@ def generate_step_logs(
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
@@ -90,42 +91,50 @@ def train(args):
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
||||
if use_user_config:
|
||||
print(f"Loading dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
||||
if use_user_config:
|
||||
print(f"Loading dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -148,7 +157,7 @@ def train(args):
|
||||
|
||||
# acceleratorを準備する
|
||||
print("preparing accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -158,7 +167,7 @@ def train(args):
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
# 差分追加学習のためにモデルを読み込む
|
||||
import sys
|
||||
@@ -211,13 +220,18 @@ def train(args):
|
||||
else:
|
||||
# LyCORIS will work with this...
|
||||
network = network_module.create_network(
|
||||
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, dropout=args.network_dropout, **net_kwargs
|
||||
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||
print(
|
||||
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
||||
)
|
||||
args.scale_weight_norms = False
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only
|
||||
@@ -315,7 +329,7 @@ def train(args):
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents:
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -552,6 +566,8 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
@@ -655,6 +671,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -668,7 +686,9 @@ def train(args):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
|
||||
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
@@ -687,7 +707,7 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, unwrap_model(network), global_step, epoch)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
@@ -709,7 +729,7 @@ def train(args):
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**max_mean_logs)
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||
@@ -729,7 +749,7 @@ def train(args):
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
if is_main_process and saving:
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||
save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1)
|
||||
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
@@ -747,7 +767,7 @@ def train(args):
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
|
||||
if is_main_process:
|
||||
network = unwrap_model(network)
|
||||
network = accelerator.unwrap_model(network)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
@@ -837,7 +857,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
nargs="*",
|
||||
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,13 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
@@ -89,7 +95,7 @@ def train(args):
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
@@ -144,43 +150,46 @@ def train(args):
|
||||
accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
||||
if args.dataset_config is not None:
|
||||
accelerator.print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
accelerator.print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
||||
if args.dataset_config is not None:
|
||||
accelerator.print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
accelerator.print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
use_dreambooth_method = args.in_json is None
|
||||
if use_dreambooth_method:
|
||||
accelerator.print("Use DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
accelerator.print("Train with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
use_dreambooth_method = args.in_json is None
|
||||
if use_dreambooth_method:
|
||||
accelerator.print("Use DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Train with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -222,7 +231,7 @@ def train(args):
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -282,7 +291,7 @@ def train(args):
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.requires_grad_(True)
|
||||
@@ -335,6 +344,7 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
@@ -409,12 +419,14 @@ def train(args):
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
@@ -428,7 +440,7 @@ def train(args):
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
with torch.no_grad():
|
||||
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
@@ -445,7 +457,9 @@ def train(args):
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
updated_embs = (
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
)
|
||||
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||
@@ -461,7 +475,7 @@ def train(args):
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
@@ -481,7 +495,7 @@ def train(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
@@ -505,7 +519,7 @@ def train(args):
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
import library
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
@@ -20,7 +21,14 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
import library.original_unet as original_unet
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
|
||||
imagenet_templates_small = [
|
||||
@@ -88,6 +96,9 @@ def train(args):
|
||||
print(
|
||||
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||
)
|
||||
assert (
|
||||
args.dataset_class is None
|
||||
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
@@ -98,7 +109,7 @@ def train(args):
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
@@ -256,10 +267,10 @@ def train(args):
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
||||
original_unet.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||
original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -319,7 +330,7 @@ def train(args):
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.requires_grad_(True)
|
||||
@@ -372,6 +383,7 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
@@ -451,11 +463,13 @@ def train(args):
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
|
||||
loss = loss * loss_weights
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -470,7 +484,7 @@ def train(args):
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
with torch.no_grad():
|
||||
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
@@ -487,7 +501,13 @@ def train(args):
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
updated_embs = (
|
||||
accelerator.unwrap_model(text_encoder)
|
||||
.get_input_embeddings()
|
||||
.weight[token_ids_XTI]
|
||||
.data.detach()
|
||||
.clone()
|
||||
)
|
||||
|
||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||
@@ -503,7 +523,9 @@ def train(args):
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()): # tracking d*lr value
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
@@ -523,7 +545,7 @@ def train(args):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||
@@ -548,7 +570,7 @@ def train(args):
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user