re-fix sample generation is not working in FLUX1 split mode #1647

This commit is contained in:
Kohya S
2024-09-29 00:35:29 +09:00
parent 822fe57859
commit 1a0f5b0c38
2 changed files with 3 additions and 1 deletions

View File

@@ -300,6 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.flux_lower = flux_lower
self.target_device = device
def prepare_block_swap_before_forward(self):
pass
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)

View File

@@ -196,7 +196,6 @@ def sample_image_inference(
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0: