mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
add clean_memory_on_device and use it from training
This commit is contained in:
@@ -10,7 +10,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -156,7 +156,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,21 @@ def clean_memory():
|
|||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def clean_memory_on_device(device: torch.device):
|
||||||
|
r"""
|
||||||
|
Clean memory on the specified device. This is called from the training scripts.
|
||||||
|
"""
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# device may be "cuda" or "cuda:0", so we need to check the type of device
|
||||||
|
if device.type == "cuda":
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if device.type == "xpu":
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
if device.type == "mps":
|
||||||
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.lru_cache(maxsize=None)
|
||||||
def get_preferred_device() -> torch.device:
|
def get_preferred_device() -> torch.device:
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
@@ -50,7 +50,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
unet.to(accelerator.device)
|
unet.to(accelerator.device)
|
||||||
vae.to(accelerator.device)
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
@@ -2285,7 +2285,7 @@ def cache_batch_latents(
|
|||||||
info.latents_flipped = flipped_latent
|
info.latents_flipped = flipped_latent
|
||||||
|
|
||||||
if not HIGH_VRAM:
|
if not HIGH_VRAM:
|
||||||
clean_memory()
|
clean_memory_on_device(vae.device)
|
||||||
|
|
||||||
|
|
||||||
def cache_batch_text_encoder_outputs(
|
def cache_batch_text_encoder_outputs(
|
||||||
@@ -4026,7 +4026,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
|||||||
unet.to(accelerator.device)
|
unet.to(accelerator.device)
|
||||||
vae.to(accelerator.device)
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||||
@@ -4695,7 +4695,7 @@ def sample_images_common(
|
|||||||
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||||
|
|
||||||
org_vae_device = vae.device # CPUにいるはず
|
org_vae_device = vae.device # CPUにいるはず
|
||||||
vae.to(distributed_state.device)
|
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||||
|
|
||||||
# unwrap unet and text_encoder(s)
|
# unwrap unet and text_encoder(s)
|
||||||
unet = accelerator.unwrap_model(unet)
|
unet = accelerator.unwrap_model(unet)
|
||||||
@@ -4752,7 +4752,11 @@ def sample_images_common(
|
|||||||
|
|
||||||
# save random state to restore later
|
# save random state to restore later
|
||||||
rng_state = torch.get_rng_state()
|
rng_state = torch.get_rng_state()
|
||||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support
|
cuda_rng_state = None
|
||||||
|
try:
|
||||||
|
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
if distributed_state.num_processes <= 1:
|
if distributed_state.num_processes <= 1:
|
||||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||||
@@ -4774,8 +4778,10 @@ def sample_images_common(
|
|||||||
# clear pipeline and cache to reduce vram usage
|
# clear pipeline and cache to reduce vram usage
|
||||||
del pipeline
|
del pipeline
|
||||||
|
|
||||||
with torch.cuda.device(torch.cuda.current_device()):
|
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
|
||||||
torch.cuda.empty_cache()
|
# with torch.cuda.device(torch.cuda.current_device()):
|
||||||
|
# torch.cuda.empty_cache()
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
if cuda_rng_state is not None:
|
if cuda_rng_state is not None:
|
||||||
@@ -4870,10 +4876,6 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
|
|||||||
pass
|
pass
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# # clear pipeline and cache to reduce vram usage
|
|
||||||
# del pipeline
|
|
||||||
# torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# region 前処理用
|
# region 前処理用
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -250,7 +250,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -403,7 +403,7 @@ def train(args):
|
|||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
else:
|
else:
|
||||||
# make sure Text Encoders are on GPU
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@@ -162,7 +162,7 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -287,7 +287,7 @@ def train(args):
|
|||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
else:
|
else:
|
||||||
# make sure Text Encoders are on GPU
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@@ -161,7 +161,7 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -260,7 +260,7 @@ def train(args):
|
|||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
else:
|
else:
|
||||||
# make sure Text Encoders are on GPU
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||||
@@ -64,7 +64,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
org_unet_device = unet.device
|
org_unet_device = unet.device
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
unet.to("cpu")
|
unet.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
# When TE is not being trained, it will not be prepared, so we need to use explicit autocast
|
# When TE is not being trained, it will not be prepared, so we need to use explicit autocast
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -79,7 +79,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
print("move vae and unet back to original device")
|
print("move vae and unet back to original device")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@@ -217,8 +217,8 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -136,7 +136,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@@ -265,7 +265,7 @@ class NetworkTrainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -361,7 +361,7 @@ class TextualInversionTrainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from multiprocessing import Value
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -284,7 +284,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory()
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user