fix: use strategy for tokenizer and latent caching

Kohya S
2025-08-16 22:03:52 +09:00
parent 6f24bce7cc
commit f61c442f0b


@@ -12,7 +12,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library import deepspeed_utils
+from library import deepspeed_utils, strategy_base, strategy_sd
 from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
@@ -73,7 +73,14 @@ def train(args):
         args.seed = random.randint(0, 2**32)
     set_seed(args.seed)
-    tokenizer = train_util.load_tokenizer(args)
+    tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
+    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+    tokenizer = tokenize_strategy.tokenizer
+    # prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization
+    latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
+        True, args.cache_latents_to_disk, args.vae_batch_size, False
+    )
+    strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
     # prepare the dataset
     blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
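
Note on the strategy hand-off above: the concrete tokenize and latents-caching strategies are registered once on their base classes before the dataset is built, and later consumers look them up from there instead of receiving a tokenizer or caching flags as arguments. Below is a minimal sketch of that set/get singleton pattern; only set_strategy() appears in this commit, so the get_strategy() accessor and the dummy subclass are assumptions for illustration, not the strategy_base implementation.

# Minimal sketch, assuming a simple class-level registry; not the actual
# strategy_base code. Only set_strategy() is shown in this commit.
from typing import Optional


class TokenizeStrategy:
    _current: Optional["TokenizeStrategy"] = None

    @classmethod
    def set_strategy(cls, strategy: "TokenizeStrategy") -> None:
        # register the concrete strategy once, before the dataset is built
        cls._current = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        # assumed accessor: the dataset and training loop read the strategy here
        return cls._current


class DummySdTokenizeStrategy(TokenizeStrategy):
    # stand-in for strategy_sd.SdTokenizeStrategy, just to show the flow
    def __init__(self, v2: bool, max_token_length: Optional[int]) -> None:
        self.v2 = v2
        self.max_token_length = max_token_length


TokenizeStrategy.set_strategy(DummySdTokenizeStrategy(v2=False, max_token_length=225))
assert isinstance(TokenizeStrategy.get_strategy(), DummySdTokenizeStrategy)
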
@@ -100,7 +107,7 @@ def train(args):
         ]
     }
-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+    blueprint = blueprint_generator.generate(user_config, args)
     train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     current_epoch = Value("i", 0)
@@ -243,12 +250,7 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(
-                vae,
-                args.vae_batch_size,
-                args.cache_latents_to_disk,
-                accelerator.is_main_process,
-            )
+            train_dataset_group.new_cache_latents(vae, accelerator)
         vae.to("cpu")
         clean_memory_on_device(accelerator.device)
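
Note on new_cache_latents above: the VAE batch size, the cache-to-disk flag, and the main-process check no longer travel as call arguments; they are carried by the SdSdxlLatentsCachingStrategy registered near the top of train(). The snippet below is a hedged, self-contained sketch of that hand-off with invented names (CachingParams, cache_latents_for_paths); it is not the sd-scripts implementation.

# Hedged sketch: caching parameters live on a registered object instead of being
# threaded through every call (invented names, not sd-scripts code).
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CachingParams:
    cache_to_disk: bool
    batch_size: int


_current_params: Optional[CachingParams] = None


def set_caching_params(params: CachingParams) -> None:
    # registered once in train(), mirroring LatentsCachingStrategy.set_strategy(...)
    global _current_params
    _current_params = params


def cache_latents_for_paths(image_paths: List[str]) -> None:
    # the caller only supplies the data; batch size and disk flag come from the registry
    params = _current_params or CachingParams(cache_to_disk=False, batch_size=1)
    for start in range(0, len(image_paths), params.batch_size):
        batch = image_paths[start : start + params.batch_size]
        print(f"caching {len(batch)} latents (to_disk={params.cache_to_disk})")


set_caching_params(CachingParams(cache_to_disk=True, batch_size=4))
cache_latents_for_paths([f"img_{i:03d}.png" for i in range(10)])
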
@@ -267,6 +269,7 @@ def train(args):
     # prepare the dataloader
     # note: persistent_workers cannot be used when the number of DataLoader processes is 0
+    train_dataset_group.set_current_strategies()
     n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
     train_dataloader = torch.utils.data.DataLoader(
@@ -451,7 +454,7 @@ def train(args):
                 latents = latents * 0.18215
                 b_size = latents.shape[0]
-                input_ids = batch["input_ids"].to(accelerator.device)
+                input_ids = batch["input_ids_list"][0].to(accelerator.device)
                 encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
                 # Sample noise that we'll add to the latents
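
Note on the input_ids change above: batches now expose tokenized inputs under "input_ids_list", presumably one tensor per text encoder so the same batch layout can serve models with more than one text encoder; this SD 1.x/2.x script reads element 0. A tiny hedged illustration of the key change follows (synthetic tensors, not the real batch contents).

# Hedged illustration of the batch key change (synthetic data, not sd-scripts code).
import torch

# before: a single tensor of token ids per batch
batch_old = {"input_ids": torch.zeros(2, 77, dtype=torch.long)}

# after: a list with one tensor per text encoder; SD has a single text encoder,
# so the training loop reads element 0
batch_new = {"input_ids_list": [torch.zeros(2, 77, dtype=torch.long)]}

input_ids = batch_new["input_ids_list"][0]
assert input_ids.shape == (2, 77)
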