fix: update system prompt handling

This commit is contained in:
Kohya S
2025-06-29 22:21:48 +09:00
parent 078ee28a94
commit 6731d8a57f
4 changed files with 26 additions and 21 deletions

View File

@@ -25,20 +25,26 @@ GEMMA_ID = "google/gemma-2-2b"
class LuminaTokenizeStrategy(TokenizeStrategy):
def __init__(
self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
) -> None:
self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
GEMMA_ID, cache_dir=tokenizer_cache_dir
)
self.tokenizer.padding_side = "right"
if system_prompt is None:
system_prompt = ""
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
self.system_prompt = system_prompt
if max_length is None:
self.max_length = 256
else:
self.max_length = max_length
def tokenize(
self, text: Union[str, List[str]]
self, text: Union[str, List[str]], is_negative: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -49,6 +55,12 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
token input ids, attention_masks
"""
text = [text] if isinstance(text, str) else text
# In training, we always add system prompt (is_negative=False)
if not is_negative:
# Add system prompt to the beginning of each text
text = [self.system_prompt + t for t in text]
encodings = self.tokenizer(
text,
max_length=self.max_length,

View File

@@ -166,7 +166,7 @@ def train(args):
)
)
strategy_base.TokenizeStrategy.set_strategy(
strategy_lumina.LuminaTokenizeStrategy()
strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
)
train_dataset_group.set_current_strategies()
@@ -221,7 +221,7 @@ def train(args):
gemma2_max_token_length = args.gemma2_max_token_length
lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
gemma2_max_token_length
args.system_prompt, gemma2_max_token_length
)
strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)
@@ -266,19 +266,17 @@ def train(args):
strategy_base.TextEncodingStrategy.get_strategy()
)
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [
system_prompt + prompt_dict.get("prompt", ""),
for i, p in enumerate([
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]:
]):
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = lumina_tokenize_strategy.tokenize(p)
tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt
sample_prompts_te_outputs[p] = (
text_encoding_strategy.encode_tokens(
lumina_tokenize_strategy,

View File

@@ -86,7 +86,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
def get_tokenize_strategy(self, args):
return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir)
return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
return [tokenize_strategy.tokenizer]
@@ -156,25 +156,20 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
system_prompt_special_token = "<Prompt Start>"
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in sample_prompts:
prompts = [
system_prompt + prompt_dict.get("prompt", ""),
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]
for i, prompt in enumerate(prompts):
# Add system prompt only to positive prompt
if i == 0:
prompt = system_prompt + prompt
if prompt in sample_prompts_te_outputs:
continue
logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,

View File

@@ -41,7 +41,7 @@ class SimpleMockGemma2Model:
def test_lumina_tokenize_strategy():
# Test default initialization
try:
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
@@ -67,7 +67,7 @@ def test_lumina_tokenize_strategy():
def test_lumina_text_encoding_strategy():
# Create strategies
try:
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
@@ -148,7 +148,7 @@ def test_lumina_text_encoder_outputs_caching_strategy():
# Create mock strategies and model
try:
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")