diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py
index 87dc9a19..47d6d30b 100644
--- a/lumina_minimal_inference.py
+++ b/lumina_minimal_inference.py
@@ -48,7 +48,7 @@ def generate_image(
     steps: int,
     guidance_scale: float,
     negative_prompt: Optional[str],
-    args,
+    args: argparse.Namespace,
     cfg_trunc_ratio: float = 0.25,
     renorm_cfg: float = 1.0,
 ):
@@ -88,7 +88,9 @@ def generate_image(
     with torch.no_grad():
         gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)

-    tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
+    tokens_and_masks = tokenize_strategy.tokenize(
+        negative_prompt, is_negative=True and not args.add_system_prompt_to_negative_prompt
+    )
     with torch.no_grad():
         neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)

@@ -215,6 +217,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
     parser.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
     parser.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
+    parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt")
     parser.add_argument(
         "--gemma2_max_token_length",
         type=int,
@@ -231,7 +234,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--cfg_trunc_ratio",
         type=float,
         default=0.25,
-        help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25% of timesteps will be guided.",
+        help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.",
     )
     parser.add_argument(
         "--renorm_cfg",
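
Note on the guarded `is_negative` argument above: `True and not args.add_system_prompt_to_negative_prompt` reduces to `not args.add_system_prompt_to_negative_prompt`, so passing the new flag causes the negative prompt to be tokenized as if it were a positive prompt, i.e. with the system prompt included. A minimal sketch of that behavior follows, using a hypothetical stand-in for `tokenize_strategy.tokenize` (the real strategy class in this codebase returns token tensors and masks, not strings; the system-prompt-prepending shown here is an assumption inferred from the `is_negative` flag, not the actual implementation):

```python
import argparse

def tokenize(prompt: str, system_prompt: str, is_negative: bool = False) -> str:
    # Hypothetical stand-in for tokenize_strategy.tokenize: in this sketch,
    # negative prompts skip the system prompt; everything else gets it prepended.
    return prompt if is_negative else f"{system_prompt} {prompt}".strip()

parser = argparse.ArgumentParser()
parser.add_argument("--add_system_prompt_to_negative_prompt", action="store_true")
args = parser.parse_args(["--add_system_prompt_to_negative_prompt"])

# Mirrors the patched call site: with the flag set,
# `True and not args.add_system_prompt_to_negative_prompt` evaluates to False,
# so the negative prompt is tokenized with the system prompt attached.
neg = tokenize(
    "blurry, low quality",
    system_prompt="<system prompt here>",
    is_negative=True and not args.add_system_prompt_to_negative_prompt,
)
print(neg)  # -> "<system prompt here> blurry, low quality"
```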