mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix: use strategy for tokenizer and latent caching
This commit is contained in:
@@ -12,7 +12,7 @@ import toml
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import deepspeed_utils
|
||||
from library import deepspeed_utils, strategy_base, strategy_sd
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -73,7 +73,14 @@ def train(args):
|
||||
args.seed = random.randint(0, 2**32)
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
tokenizer = tokenize_strategy.tokenizer
|
||||
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
)
|
||||
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
@@ -100,7 +107,7 @@ def train(args):
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
@@ -243,12 +250,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
vae,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -267,6 +269,7 @@ def train(args):
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
train_dataset_group.set_current_strategies()
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
@@ -451,7 +454,7 @@ def train(args):
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
Reference in New Issue
Block a user