diff --git a/fine_tune.py b/fine_tune.py index cdc005d9..0b7cc510 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -177,7 +177,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/flux_train.py b/flux_train.py index 46a8babd..91ae3af5 100644 --- a/flux_train.py +++ b/flux_train.py @@ -190,7 +190,7 @@ def train(args): ae.requires_grad_(False) ae.eval() - train_dataset_group.new_cache_latents(ae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(ae, accelerator) ae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/sd3_train.py b/sd3_train.py index 7290956a..ef18c32c 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -243,7 +243,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device) diff --git a/sdxl_train.py b/sdxl_train.py index 9b2d1916..79a2fbb6 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -272,7 +272,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 74b3a64a..24080afb 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -209,7 +209,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_db.py b/train_db.py index 683b4233..4a58e27b 100644 --- a/train_db.py +++ b/train_db.py @@ -156,7 +156,7 @@ def train(args): vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) vae.to("cpu") clean_memory_on_device(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4d8a3abb..77b5d717 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -378,7 +378,7 @@ class TextualInversionTrainer: vae.requires_grad_(False) vae.eval() - train_dataset_group.new_cache_latents(vae, accelerator.is_main_process) + train_dataset_group.new_cache_latents(vae, accelerator) clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone()