mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix: update system prompt handling
This commit is contained in:
@@ -25,20 +25,26 @@ GEMMA_ID = "google/gemma-2-2b"
|
||||
|
||||
class LuminaTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(
|
||||
self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
|
||||
self, system_prompt:str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
|
||||
) -> None:
|
||||
self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
|
||||
GEMMA_ID, cache_dir=tokenizer_cache_dir
|
||||
)
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
if system_prompt is None:
|
||||
system_prompt = ""
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{system_prompt} {system_prompt_special_token} " if system_prompt else ""
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = 256
|
||||
else:
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize(
|
||||
self, text: Union[str, List[str]]
|
||||
self, text: Union[str, List[str]], is_negative: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@@ -49,6 +55,12 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
|
||||
token input ids, attention_masks
|
||||
"""
|
||||
text = [text] if isinstance(text, str) else text
|
||||
|
||||
# In training, we always add system prompt (is_negative=False)
|
||||
if not is_negative:
|
||||
# Add system prompt to the beginning of each text
|
||||
text = [self.system_prompt + t for t in text]
|
||||
|
||||
encodings = self.tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
|
||||
@@ -166,7 +166,7 @@ def train(args):
|
||||
)
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(
|
||||
strategy_lumina.LuminaTokenizeStrategy()
|
||||
strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
|
||||
)
|
||||
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -221,7 +221,7 @@ def train(args):
|
||||
gemma2_max_token_length = args.gemma2_max_token_length
|
||||
|
||||
lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
|
||||
gemma2_max_token_length
|
||||
args.system_prompt, gemma2_max_token_length
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)
|
||||
|
||||
@@ -266,19 +266,17 @@ def train(args):
|
||||
strategy_base.TextEncodingStrategy.get_strategy()
|
||||
)
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
for p in [
|
||||
system_prompt + prompt_dict.get("prompt", ""),
|
||||
for i, p in enumerate([
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]:
|
||||
]):
|
||||
if p not in sample_prompts_te_outputs:
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {p}")
|
||||
tokens_and_masks = lumina_tokenize_strategy.tokenize(p)
|
||||
tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1) # i == 1 means negative prompt
|
||||
sample_prompts_te_outputs[p] = (
|
||||
text_encoding_strategy.encode_tokens(
|
||||
lumina_tokenize_strategy,
|
||||
|
||||
@@ -86,7 +86,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_lumina.LuminaTokenizeStrategy(args.gemma2_max_token_length, args.tokenizer_cache_dir)
|
||||
return strategy_lumina.LuminaTokenizeStrategy(args.system_prompt, args.gemma2_max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
|
||||
return [tokenize_strategy.tokenizer]
|
||||
@@ -156,25 +156,20 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
|
||||
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)
|
||||
|
||||
system_prompt_special_token = "<Prompt Start>"
|
||||
system_prompt = f"{args.system_prompt} {system_prompt_special_token} " if args.system_prompt else ""
|
||||
sample_prompts = train_util.load_prompts(args.sample_prompts)
|
||||
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
for prompt_dict in sample_prompts:
|
||||
prompts = [
|
||||
system_prompt + prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("prompt", ""),
|
||||
prompt_dict.get("negative_prompt", ""),
|
||||
]
|
||||
for i, prompt in enumerate(prompts):
|
||||
# Add system prompt only to positive prompt
|
||||
if i == 0:
|
||||
prompt = system_prompt + prompt
|
||||
if prompt in sample_prompts_te_outputs:
|
||||
continue
|
||||
|
||||
logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt
|
||||
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
text_encoders,
|
||||
|
||||
@@ -41,7 +41,7 @@ class SimpleMockGemma2Model:
|
||||
def test_lumina_tokenize_strategy():
|
||||
# Test default initialization
|
||||
try:
|
||||
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
|
||||
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
@@ -67,7 +67,7 @@ def test_lumina_tokenize_strategy():
|
||||
def test_lumina_text_encoding_strategy():
|
||||
# Create strategies
|
||||
try:
|
||||
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
|
||||
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
@@ -148,7 +148,7 @@ def test_lumina_text_encoder_outputs_caching_strategy():
|
||||
|
||||
# Create mock strategies and model
|
||||
try:
|
||||
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
|
||||
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
|
||||
except OSError as e:
|
||||
# If the tokenizer is not found (due to gated repo), we can skip the test
|
||||
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
|
||||
|
||||
Reference in New Issue
Block a user