diff --git a/flux_train_network.py b/flux_train_network.py index 704c4d32..679db62b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -445,6 +445,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if len(diff_output_pr_indices) > 0: network.set_multiplier(0.0) + unet.prepare_block_swap_before_forward() with torch.no_grad(): model_pred_prior = call_dit( img=packed_noisy_model_input[diff_output_pr_indices],