diff --git a/train_control_net.py b/train_control_net.py index 97cd1ebb..c12693ba 100644 --- a/train_control_net.py +++ b/train_control_net.py @@ -12,7 +12,7 @@ import toml from tqdm import tqdm import torch -from library import deepspeed_utils +from library import deepspeed_utils, strategy_base, strategy_sd from library.device_utils import init_ipex, clean_memory_on_device init_ipex() @@ -73,7 +73,14 @@ def train(args): args.seed = random.randint(0, 2**32) set_seed(args.seed) - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + tokenizer = tokenize_strategy.tokenizer + # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization. + latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy( + True, args.cache_latents_to_disk, args.vae_batch_size, False + ) + strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy) # データセットを準備する blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) @@ -100,7 +107,7 @@ def train(args): ] } - blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + blueprint = blueprint_generator.generate(user_config, args) train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) @@ -243,12 +250,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents( - vae, - args.vae_batch_size, - args.cache_latents_to_disk, - accelerator.is_main_process, - ) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) @@ -267,6 +269,7 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 + train_dataset_group.set_current_strategies() n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers train_dataloader = torch.utils.data.DataLoader( @@ -451,7 +454,7 @@ def train(args): latents = latents * 0.18215 b_size = latents.shape[0] - input_ids = batch["input_ids"].to(accelerator.device) + input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents