diff --git a/fine_tune.py b/fine_tune.py index 62a545a1..fd63385b 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -366,22 +366,17 @@ def train(args): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning if args.weighted_captions: - # TODO move to strategy_sd.py - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index aaf77b8d..dc3887c3 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -363,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # ) # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました") - assert ( - not hasattr(args, "weighted_captions") or not args.weighted_captions - ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" + # assert ( + # not hasattr(args, "weighted_captions") or not args.weighted_captions + # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" if supportTextEncoderCaching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: diff --git a/library/strategy_base.py b/library/strategy_base.py index 7981bd0b..2bff4178 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -323,12 +323,18 @@ class TextEncoderOutputsCachingStrategy: _strategy = None # strategy instance: actual strategy class def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check self._is_partial = is_partial + self._is_weighted = is_weighted @classmethod def set_strategy(cls, strategy): @@ -352,6 +358,10 @@ class TextEncoderOutputsCachingStrategy: def is_partial(self): return self._is_partial + @property + def is_weighted(self): + return self._is_weighted + def get_outputs_npz_path(self, image_abs_path: str) -> str: raise NotImplementedError diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 83ffaa31..4e7931fd 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -40,6 +40,16 @@ class SdTokenizeStrategy(TokenizeStrategy): text = [text] if isinstance(text, str) else text return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]: + text = [text] if isinstance(text, str) else text + tokens_list = [] + weights_list = [] + for t in text: + tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + tokens_list.append(tokens) + weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] + class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: @@ -58,6 +68,8 @@ class SdTextEncodingStrategy(TextEncodingStrategy): model_max_length = sd_tokenize_strategy.tokenizer.model_max_length tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 + tokens = tokens.to(text_encoder.device) + if self.clip_skip is None: encoder_hidden_states = text_encoder(tokens)[0] else: @@ -93,6 +105,30 @@ class SdTextEncodingStrategy(TextEncodingStrategy): return [encoder_hidden_states] + def encode_tokens_with_weights( + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], + ) -> List[torch.Tensor]: + encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] + + weights = weights_list[0].to(encoder_hidden_states.device) + + # apply weights + if weights.shape[1] == 1: # no max_token_length + # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) + encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + else: + # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) + for i in range(weights.shape[1]): + encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ + :, i, 1:-1 + ].unsqueeze(-1) + + return [encoder_hidden_states] + class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. diff --git a/library/strategy_sdxl.py b/library/strategy_sdxl.py index 6650e2b4..6b3e2afa 100644 --- a/library/strategy_sdxl.py +++ b/library/strategy_sdxl.py @@ -42,16 +42,16 @@ class SdxlTokenizeStrategy(TokenizeStrategy): tokens1_list, tokens2_list = [], [] weights1_list, weights2_list = [], [] for t in text: - tokens1, weights1 = self._get_weighted_input_ids(self.tokenizer1, t, self.max_length) - tokens2, weights2 = self._get_weighted_input_ids(self.tokenizer2, t, self.max_length) + tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True) + tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True) tokens1_list.append(tokens1) tokens2_list.append(tokens2) weights1_list.append(weights1) weights2_list.append(weights2) - return (torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)), ( + return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [ torch.stack(weights1_list, dim=0), torch.stack(weights2_list, dim=0), - ) + ] class SdxlTextEncodingStrategy(TextEncodingStrategy): @@ -193,20 +193,28 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy): return [hidden_states1, hidden_states2, pool2] def encode_tokens_with_weights( - self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor] + self, + tokenize_strategy: TokenizeStrategy, + models: List[Any], + tokens_list: List[torch.Tensor], + weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: - hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens) + hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list) + + weights_list = [weights.to(hidden_states1.device) for weights in weights_list] # apply weights - if weights[0].shape[1] == 1: # no max_token_length + if weights_list[0].shape[1] == 1: # no max_token_length # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) - hidden_states1 = hidden_states1 * weights[0].squeeze(1).unsqueeze(2) - hidden_states2 = hidden_states2 * weights[1].squeeze(1).unsqueeze(2) + hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2) + hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2) else: # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) - for weight, hidden_states in zip(weights, [hidden_states1, hidden_states2]): + for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]): for i in range(weight.shape[1]): - hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[:, i, 1:-1] + hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[ + :, i, 1:-1 + ].unsqueeze(-1) return [hidden_states1, hidden_states2, pool2] @@ -215,9 +223,14 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz" def __init__( - self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False + self, + cache_to_disk: bool, + batch_size: int, + skip_disk_cache_validity_check: bool, + is_partial: bool = False, + is_weighted: bool = False, ) -> None: - super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial) + super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted) def get_outputs_npz_path(self, image_abs_path: str) -> str: return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX @@ -253,11 +266,19 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy captions = [info.caption for info in infos] - tokens1, tokens2 = tokenize_strategy.tokenize(captions) - with torch.no_grad(): - hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( - tokenize_strategy, models, [tokens1, tokens2] - ) + if self.is_weighted: + tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, models, tokens_list, weights_list + ) + else: + tokens1, tokens2 = tokenize_strategy.tokenize(captions) + with torch.no_grad(): + hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens( + tokenize_strategy, models, [tokens1, tokens2] + ) + if hidden_state1.dtype == torch.bfloat16: hidden_state1 = hidden_state1.float() if hidden_state2.dtype == torch.bfloat16: diff --git a/sdxl_train.py b/sdxl_train.py index 320169d7..aeff9c46 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -321,7 +321,7 @@ def train(args): if args.cache_text_encoder_outputs: # Text Encodes are eval and no grad text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions ) strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4d6e3f18..20e32155 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -79,7 +79,9 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False) + return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy( + args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions + ) else: return None diff --git a/train_db.py b/train_db.py index a5d520b1..e49a7e70 100644 --- a/train_db.py +++ b/train_db.py @@ -356,21 +356,17 @@ def train(args): # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenize_strategy.tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) + input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) + encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights( + tokenize_strategy, [text_encoder], input_ids_list, weights_list + )[0] else: input_ids = batch["input_ids_list"][0].to(accelerator.device) encoder_hidden_states = text_encoding_strategy.encode_tokens( tokenize_strategy, [text_encoder], [input_ids] )[0] - if args.full_fp16: - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + if args.full_fp16: + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified