add clean_memory_on_device and use it from training

commit e24d9606a2 (parent 75ecb047e2)
Kohya S committed 2024-02-12 11:10:52 +09:00

13 changed files with 55 additions and 38 deletions
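The change is mechanical at each call site: the old clean_memory() emptied the cache of every available backend, while the new clean_memory_on_device(device) frees only the backend the given device actually uses. A minimal sketch of the new call pattern, assuming the usual accelerate setup these training scripts already use:

    import torch
    from accelerate import Accelerator
    from library.device_utils import clean_memory_on_device  # helper added in this commit

    accelerator = Accelerator()

    # after offloading a model to the CPU, reclaim the device-side cache
    # vae.to("cpu")
    clean_memory_on_device(accelerator.device)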

View File

@@ -10,7 +10,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from accelerate.utils import set_seed

@@ -156,7 +156,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

View File

@@ -31,6 +31,21 @@ def clean_memory():
         torch.mps.empty_cache()


+def clean_memory_on_device(device: torch.device):
+    r"""
+    Clean memory on the specified device; called from training scripts.
+    """
+    gc.collect()
+
+    # device may be "cuda" or "cuda:0", so we need to check the type of the device
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
+    if device.type == "xpu":
+        torch.xpu.empty_cache()
+    if device.type == "mps":
+        torch.mps.empty_cache()
+
+
 @functools.lru_cache(maxsize=None)
 def get_preferred_device() -> torch.device:
     r"""

View File

@@ -4,7 +4,7 @@ import os
 from typing import Optional
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from accelerate import init_empty_weights

@@ -50,7 +50,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)

-            clean_memory()
+            clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()

     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

View File

@@ -32,7 +32,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()

@@ -2285,7 +2285,7 @@ def cache_batch_latents(
                 info.latents_flipped = flipped_latent

     if not HIGH_VRAM:
-        clean_memory()
+        clean_memory_on_device(vae.device)


 def cache_batch_text_encoder_outputs(

@@ -4026,7 +4026,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
                 unet.to(accelerator.device)
                 vae.to(accelerator.device)

-            clean_memory()
+            clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()

     return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4695,7 +4695,7 @@ def sample_images_common(
     distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

     org_vae_device = vae.device  # should be on the CPU
-    vae.to(distributed_state.device)
+    vae.to(distributed_state.device)  # distributed_state.device is the same as accelerator.device

     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet)
@@ -4752,7 +4752,11 @@ def sample_images_common(
     # save random state to restore later
     rng_state = torch.get_rng_state()
-    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None  # TODO mps etc. support
+    cuda_rng_state = None
+    try:
+        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    except Exception:
+        pass

     if distributed_state.num_processes <= 1:
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
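The new try/except addresses the earlier TODO about non-CUDA backends: torch.cuda.get_rng_state() can raise on builds or machines without a usable CUDA runtime, so the failure is swallowed and only the CPU RNG state is saved. A standalone sketch of the same save/restore pattern (the helper names are illustrative, not part of this commit):

    import torch

    def snapshot_rng_states():
        # the CPU RNG state is always available; the CUDA state only when CUDA works
        cpu_state = torch.get_rng_state()
        cuda_state = None
        try:
            if torch.cuda.is_available():
                cuda_state = torch.cuda.get_rng_state()
        except Exception:
            pass  # e.g. MPS-only machines, or CUDA builds without a device
        return cpu_state, cuda_state

    def restore_rng_states(cpu_state, cuda_state):
        torch.set_rng_state(cpu_state)
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state)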
@@ -4774,8 +4778,10 @@ def sample_images_common(
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    with torch.cuda.device(torch.cuda.current_device()):
-        torch.cuda.empty_cache()
+
+    # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
+    # with torch.cuda.device(torch.cuda.current_device()):
+    #     torch.cuda.empty_cache()
+    clean_memory_on_device(accelerator.device)

     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
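For a CUDA device the new call is equivalent to the commented-out block: clean_memory_on_device sees device.type == "cuda" and runs the same torch.cuda.empty_cache(). A small sketch of that equivalence, guarded so it is a no-op on CPU-only machines:

    import torch
    from library.device_utils import clean_memory_on_device

    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
        # previous approach: empty the cache of the current CUDA device explicitly
        with torch.cuda.device(device):
            torch.cuda.empty_cache()
        # new approach: dispatch on device.type and empty the same cache
        clean_memory_on_device(device)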
@@ -4870,10 +4876,6 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
         pass


 # endregion

-# # clear pipeline and cache to reduce vram usage
-# del pipeline
-# torch.cuda.empty_cache()
-
 # region for preprocessing

View File

@@ -10,7 +10,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from accelerate.utils import set_seed

@@ -250,7 +250,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

@@ -403,7 +403,7 @@ def train(args):
             # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
             text_encoder1.to("cpu", dtype=torch.float32)
             text_encoder2.to("cpu", dtype=torch.float32)
-            clean_memory()
+            clean_memory_on_device(accelerator.device)
         else:
             # make sure Text Encoders are on GPU
             text_encoder1.to(accelerator.device)

View File

@@ -14,7 +14,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -162,7 +162,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

@@ -287,7 +287,7 @@ def train(args):
             # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
             text_encoder1.to("cpu", dtype=torch.float32)
             text_encoder2.to("cpu", dtype=torch.float32)
-            clean_memory()
+            clean_memory_on_device(accelerator.device)
         else:
             # make sure Text Encoders are on GPU
             text_encoder1.to(accelerator.device)

View File

@@ -11,7 +11,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -161,7 +161,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

@@ -260,7 +260,7 @@ def train(args):
             # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
             text_encoder1.to("cpu", dtype=torch.float32)
             text_encoder2.to("cpu", dtype=torch.float32)
-            clean_memory()
+            clean_memory_on_device(accelerator.device)
         else:
             # make sure Text Encoders are on GPU
             text_encoder1.to(accelerator.device)

View File

@@ -1,7 +1,7 @@
 import argparse
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from library import sdxl_model_util, sdxl_train_util, train_util

@@ -64,7 +64,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 org_unet_device = unet.device
                 vae.to("cpu")
                 unet.to("cpu")
-                clean_memory()
+                clean_memory_on_device(accelerator.device)

             # When the TE is not trained, it will not be prepared, so we need to use explicit autocast
             with accelerator.autocast():

@@ -79,7 +79,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
-            clean_memory()
+            clean_memory_on_device(accelerator.device)

             if not args.lowram:
                 print("move vae and unet back to original device")

View File

@@ -11,7 +11,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -217,8 +217,8 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

     if args.gradient_checkpointing:

View File

@@ -11,7 +11,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from accelerate.utils import set_seed

@@ -136,7 +136,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

View File

@@ -12,7 +12,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -265,7 +265,7 @@ class NetworkTrainer:
             with torch.no_grad():
                 train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
             vae.to("cpu")
-            clean_memory()
+            clean_memory_on_device(accelerator.device)

             accelerator.wait_for_everyone()

View File

@@ -7,7 +7,7 @@ import toml
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from accelerate.utils import set_seed

@@ -361,7 +361,7 @@ class TextualInversionTrainer:
             with torch.no_grad():
                 train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
             vae.to("cpu")
-            clean_memory()
+            clean_memory_on_device(accelerator.device)

             accelerator.wait_for_everyone()

View File

@@ -8,7 +8,7 @@ from multiprocessing import Value
 from tqdm import tqdm
 import torch
-from library.device_utils import init_ipex, clean_memory
+from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
 from accelerate.utils import set_seed

@@ -284,7 +284,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        clean_memory()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()