Better implementation for te autocast (#895)

* Better implementation for te

* Fix some misunderstanding

* Same as unet, add explicit conversion

* Better cache TE and TE lr

* Fix with list

* Add timeout settings

* Fix arg style
Kohaku-Blueleaf
2023-10-28 14:49:59 +08:00
committed by GitHub
parent 202f2c3292
commit 1cefb2a753
4 changed files with 41 additions and 30 deletions

View File

@@ -3,6 +3,7 @@
import argparse
import ast
import asyncio
import datetime
import importlib
import json
import pathlib
@@ -18,7 +19,7 @@ from typing import (
Tuple,
Union,
)
from accelerate import Accelerator
from accelerate import Accelerator, InitProcessGroupKwargs
import gc
import glob
import math
@@ -2855,6 +2856,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument(
"--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
)
parser.add_argument(
"--clip_skip",
type=int,
@@ -3786,6 +3790,7 @@ def prepare_accelerator(args: argparse.Namespace):
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
)
return accelerator
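
A minimal sketch of what the new --ddp_timeout plumbing does, assuming a bare Accelerator call in isolation rather than the full prepare_accelerator: the minutes value is wrapped in a timedelta and handed to InitProcessGroupKwargs, which accelerate forwards to torch.distributed.init_process_group when DDP is initialized.

import argparse
import datetime

from accelerate import Accelerator, InitProcessGroupKwargs

parser = argparse.ArgumentParser()
parser.add_argument("--ddp_timeout", type=int, default=30, help="DDP timeout (min)")
args = parser.parse_args([])  # defaults only, for illustration

# the timedelta becomes the process-group timeout, so long single-rank work
# (e.g. caching on the main process) no longer trips the default limit
ipg_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))
accelerator = Accelerator(kwargs_handlers=[ipg_handler])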

View File

@@ -287,6 +287,8 @@ def train(args):
training_models.append(text_encoder2)
# set require_grad=True later
else:
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
text_encoder1.eval()
@@ -295,7 +297,7 @@ def train(args):
# TextEncoderの出力をキャッシュする (cache the Text Encoder outputs)
if args.cache_text_encoder_outputs:
# Text Encoders are eval and no grad
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
@@ -315,25 +317,23 @@ def train(args):
m.requires_grad_(True)
if block_lrs is None:
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
# calculate number of trainable parameters
n_params = 0
for p in params:
n_params += p.numel()
params_to_optimize = [
{"params": list(training_models[0].parameters()), "lr": args.learning_rate},
]
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({
"params": list(m.parameters()),
"lr": args.learning_rate_te or args.learning_rate
})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
@@ -396,8 +396,6 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
(unet,) = train_util.transform_models_if_DDP([unet])
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する (move to CPU when caching the Text Encoder outputs)
if args.cache_text_encoder_outputs:
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

View File

@@ -70,14 +70,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if torch.cuda.is_available():
torch.cuda.empty_cache()
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
# When TE is not trained, it will not be prepared, so we need to use explicit autocast
with accelerator.autocast():
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
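
Because the text encoders are not passed through accelerator.prepare when only the U-Net is trained, mixed precision is not applied to them automatically, so the caching forward pass is wrapped in an explicit accelerator.autocast() and the weights are converted to weight_dtype by hand. A rough sketch of the pattern with a placeholder encoder (the nn.Linear and dummy input are not the repo's classes):

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

text_encoder = nn.Linear(768, 768)  # stand-in for a frozen text encoder
text_encoder.to(accelerator.device, dtype=weight_dtype)  # explicit convert, as in the diff
text_encoder.requires_grad_(False)
text_encoder.eval()

dummy_input = torch.randn(1, 768, device=accelerator.device, dtype=weight_dtype)

# never prepared by accelerate, so autocast has to be entered explicitly
with torch.no_grad(), accelerator.autocast():
    cached_output = text_encoder(dummy_input)

# afterwards move back to CPU in fp32; fp16 weights are not usable on CPU here
text_encoder.to("cpu", dtype=torch.float32)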

View File

@@ -109,6 +109,9 @@ class NetworkTrainer:
def is_text_encoder_outputs_cached(self, args):
return False
def is_train_text_encoder(self, args):
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
@@ -310,7 +313,7 @@ class NetworkTrainer:
args.scale_weight_norms = False
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
@@ -403,6 +406,8 @@ class NetworkTrainer:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
elif train_text_encoder:
if len(text_encoders) > 1:
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -767,7 +772,7 @@ class NetworkTrainer:
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
with torch.set_grad_enabled(train_text_encoder):
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
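
The last hunk pairs torch.set_grad_enabled(train_text_encoder) with accelerator.autocast() so the text-encoder forward runs under autocast whether or not the encoder is being trained, while gradient tracking still follows the flag. A minimal sketch with a placeholder encoder, not the repo's conditioning code:

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()  # mixed_precision would come from the accelerate config in practice
encoder = nn.Linear(16, 16)  # stand-in for a text encoder
dummy_ids = torch.randn(2, 16)

for train_text_encoder in (False, True):
    # autocast applies either way; only grad tracking depends on the flag
    with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
        conds = encoder(dummy_ids)
    print(train_text_encoder, conds.requires_grad)  # False -> no graph, True -> graph kept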