diff --git a/library/strategy_lumina.py b/library/strategy_lumina.py index 392d6594..964d9f7a 100644 --- a/library/strategy_lumina.py +++ b/library/strategy_lumina.py @@ -25,20 +25,26 @@ GEMMA_ID = "google/gemma-2-2b" class LuminaTokenizeStrategy(TokenizeStrategy): def __init__( - self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None + self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None ) -> None: self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained( GEMMA_ID, cache_dir=tokenizer_cache_dir ) self.tokenizer.padding_side = "right" + if system_prompt is None: + system_prompt = "" + system_prompt_special_token = "<Prompt Start>" + system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else "" + self.system_prompt = system_prompt + if max_length is None: self.max_length = 256 else: self.max_length = max_length def tokenize( - self, text: Union[str, List[str]] + self, text: Union[str, List[str]], is_negative: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -49,6 +55,12 @@ class LuminaTokenizeStrategy(TokenizeStrategy): token input ids, attention_masks """ text = [text] if isinstance(text, str) else text + + # In training, we always add system prompt (is_negative=False) + if not is_negative: + # Add system prompt to the beginning of each text + text = [self.system_prompt + t for t in text] + encodings = self.tokenizer( text, max_length=self.max_length, diff --git a/lumina_train.py b/lumina_train.py index 4b733c9e..0a91f4a0 100644 --- a/lumina_train.py +++ b/lumina_train.py @@ -166,7 +166,7 @@ def train(args): ) ) strategy_base.TokenizeStrategy.set_strategy( - strategy_lumina.LuminaTokenizeStrategy() + strategy_lumina.LuminaTokenizeStrategy(args.system_prompt) ) train_dataset_group.set_current_strategies() @@ -221,7 +221,7 @@ def train(args): gemma2_max_token_length = args.gemma2_max_token_length lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy( - 
gemma2_max_token_length + args.system_prompt, gemma2_max_token_length ) strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy) @@ -266,19 +266,17 @@ def train(args): strategy_base.TextEncodingStrategy.get_strategy() ) - system_prompt_special_token = "<Prompt Start>" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in prompts: - for p in [ - system_prompt + prompt_dict.get("prompt", ""), + for i, p in enumerate([ + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), - ]: + ]): if p not in sample_prompts_te_outputs: logger.info(f"cache Text Encoder outputs for prompt: {p}") - tokens_and_masks = lumina_tokenize_strategy.tokenize(p) + tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt sample_prompts_te_outputs[p] = ( text_encoding_strategy.encode_tokens( lumina_tokenize_strategy, diff --git a/lumina_train_network.py b/lumina_train_network.py index 037ddac6..b08e3143 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -86,7 +86,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model def get_tokenize_strategy(self, args): - return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir) + return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir) def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy): return [tokenize_strategy.tokenizer] @@ -156,25 +156,20 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy) assert isinstance(text_encoding_strategy, 
strategy_lumina.LuminaTextEncodingStrategy) - system_prompt_special_token = "<Prompt Start>" - system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else "" sample_prompts = train_util.load_prompts(args.sample_prompts) sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs with accelerator.autocast(), torch.no_grad(): for prompt_dict in sample_prompts: prompts = [ - system_prompt + prompt_dict.get("prompt", ""), + prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", ""), ] for i, prompt in enumerate(prompts): - # Add system prompt only to positive prompt - if i == 0: - prompt = system_prompt + prompt if prompt in sample_prompts_te_outputs: continue logger.info(f"cache Text Encoder outputs for prompt: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) + tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens( tokenize_strategy, text_encoders, diff --git a/tests/library/test_strategy_lumina.py b/tests/library/test_strategy_lumina.py index 9bb0edf7..d77d2738 100644 --- a/tests/library/test_strategy_lumina.py +++ b/tests/library/test_strategy_lumina.py @@ -41,7 +41,7 @@ class SimpleMockGemma2Model: def test_lumina_tokenize_strategy(): # Test default initialization try: - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) except OSError as e: # If the tokenizer is not found (due to gated repo), we can skip the test print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") @@ -67,7 +67,7 @@ def test_lumina_tokenize_strategy(): def test_lumina_text_encoding_strategy(): # Create strategies try: - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) except OSError as e: # If the tokenizer is not found 
(due to gated repo), we can skip the test print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}") @@ -148,7 +148,7 @@ def test_lumina_text_encoder_outputs_caching_strategy(): # Create mock strategies and model try: - tokenize_strategy = LuminaTokenizeStrategy(max_length=None) + tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None) except OSError as e: # If the tokenizer is not found (due to gated repo), we can skip the test print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")