Add truncation when tokenized prompt length exceeds max_length

This commit is contained in:
sdbds
2025-02-26 01:00:35 +08:00
parent fc772affbe
commit 5f9047c8cf
2 changed files with 1 addition and 1 deletion

View File

@@ -320,7 +320,6 @@ def sample_image_inference(
# Load sample prompts from Gemma 2
if gemma2_model is not None:
logger.info(f"Encoding prompt with Gemma2: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)

View File

@@ -54,6 +54,7 @@ class LuminaTokenizeStrategy(TokenizeStrategy):
max_length=self.max_length,
return_tensors="pt",
padding="max_length",
truncation=True,
pad_to_multiple_of=8,
)
return (encodings.input_ids, encodings.attention_mask)