Merge pull request #1067 from kohya-ss/dev

Dev
This commit is contained in:
Kohya S
2024-01-23 21:04:16 +09:00
committed by GitHub
22 changed files with 795 additions and 597 deletions

View File

@@ -249,6 +249,64 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Jan 23, 2024 / 2024/1/23: v0.8.2
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
- Fixed a bug in multi-GPU Textual Inversion training.
- 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。
- PyTorch 2.1 以降が必要です。
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0
... subset settings ...
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0
... subset settings ...
```
### Jan 17, 2024 / 2024/1/17: v0.8.1
- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).

View File

@@ -1,11 +1,7 @@
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

View File

@@ -11,15 +11,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

View File

@@ -66,15 +66,10 @@ import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,

File diff suppressed because it is too large Load Diff

24
library/ipex_interop.py Normal file
View File

@@ -0,0 +1,24 @@
import torch
def init_ipex():
"""
Try to import `intel_extension_for_pytorch`, and apply
the hijacks using `library.ipex.ipex_init`.
If IPEX is not installed, this function does nothing.
"""
try:
import intel_extension_for_pytorch as ipex # noqa
except ImportError:
return
try:
from library.ipex import ipex_init
if torch.xpu.is_available():
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
except Exception as e:
print("failed to initialize ipex:", e)

View File

@@ -5,15 +5,9 @@ import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel

View File

@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback():
return None
self.unet.to(unet_dtype)
return latents
def latents_to_image(self, latents):

View File

@@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset):
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
) -> None:
super().__init__()
@@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.network_multiplier = network_multiplier
self.debug_dataset = debug_dataset
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
@@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset):
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
loss_weights.append(
self.prior_loss_weight if image_info.is_reg else 1.0
) # in case of fine tuning, is_reg is always False
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
@@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset):
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
example["flippeds"] = flippeds
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
self.batch_size = batch_size
@@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: float,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
db_subsets = []
for subset in subsets:
@@ -2039,6 +2048,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
print(
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if "network_multipliers" in example:
print(f"network multiplier: {example['network_multipliers'][j]}")
if show_input_ids:
print(f"input ids: {iid}")
@@ -2105,8 +2116,8 @@ def glob_images_pathlib(dir_path, recursive):
class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
@@ -2850,14 +2861,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--dynamo_backend",
type=str,
default="inductor",
"--dynamo_backend",
type=str,
default="inductor",
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor"
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
@@ -2904,6 +2915,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--ddp_timeout",
type=int,
@@ -3886,7 +3898,7 @@ def prepare_accelerator(args: argparse.Namespace):
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
# torch.compile のオプション。 NO の場合は torch.compile は使わない
dynamo_backend = "NO"
if args.torch_compile:

View File

@@ -43,6 +43,9 @@ def svd(
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
):
def str_to_dtype(p):
if p == "float":
@@ -57,28 +60,51 @@ def svd(
if v_parameterization is None:
v_parameterization = v2
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision)
work_device = "cpu"
# load models
if not sdxl:
print(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype)
print(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype is not None:
text_encoder_t = text_encoder_t.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
print(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype is not None:
text_encoder_o1 = text_encoder_o1.to(load_dtype)
text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype)
print(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype is not None:
text_encoder_t1 = text_encoder_t1.to(load_dtype)
text_encoder_t2 = text_encoder_t2.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha
@@ -100,38 +126,54 @@ def svd(
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff
# clear target Text Encoder to save memory
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
diffs = {} # clear diffs
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
if args.device:
diff = diff.to(args.device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
diffs[lora_name] = diff
# clear LoRA network, target U-Net to save memory
del lora_network_o
del lora_network_t
del unet_t
# make LoRA with svd
print("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
if args.device:
mat = mat.to(args.device)
mat = mat.to(torch.float) # calc by float
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
@@ -171,8 +213,8 @@ def svd(
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument(
"--load_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
)
parser.add_argument(
"--save_precision",
type=str,
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--load_original_model_to",
type=str,
default=None,
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
parser.add_argument(
"--load_tuned_model_to",
type=str,
default=None,
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
return parser

View File

@@ -18,15 +18,10 @@ import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,

View File

@@ -9,13 +9,11 @@ import random
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler

View File

@@ -11,15 +11,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util

View File

@@ -14,13 +14,11 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate

View File

@@ -11,13 +11,11 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel

View File

@@ -1,15 +1,10 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network

View File

@@ -3,13 +3,9 @@ import os
import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

View File

@@ -12,15 +12,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel

View File

@@ -12,15 +12,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

View File

@@ -14,15 +14,10 @@ from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
@@ -310,6 +305,7 @@ class NetworkTrainer:
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
@@ -390,16 +386,33 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
@@ -421,9 +434,6 @@ class NetworkTrainer:
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -754,7 +764,17 @@ class NetworkTrainer:
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are same, use single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError("multipliers for each sample is not supported yet")
# print(f"set multiplier: {multipliers}")
network.set_multiplier(multipliers)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
@@ -778,10 +798,24 @@ class NetworkTrainer:
args, noise_scheduler, latents
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
@@ -808,10 +842,11 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()

View File

@@ -8,15 +8,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
@@ -505,7 +500,7 @@ class TextualInversionTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -730,14 +725,13 @@ class TextualInversionTrainer:
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = accelerator.unwrap_model(text_encoder)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
accelerator.end_training()
if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)

View File

@@ -8,13 +8,11 @@ from multiprocessing import Value
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler