diff --git a/anima_train_network.py b/anima_train_network.py index eaad7197..ff770a9f 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -286,7 +286,9 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) # Unpack text encoder conditions - prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds + prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds[ + :4 + ] # ignore caption_dropout_rate, which is not needed for the training step # Move to device prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype) @@ -353,7 +355,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs( *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates ) - batch["text_encoder_outputs_list"] = text_encoder_outputs_list + # Add the caption dropout rates back to the list for the validation dataset (whose batch items are re-used) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + [caption_dropout_rates] return super().process_batch( batch,