mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
* fix: update extend-exclude list in _typos.toml to include configs * fix: exclude anima tests from pytest * feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE * fix: update default value for --discrete_flow_shift in anima training guide * feat: add Qwen-Image VAE * feat: simplify encode_tokens * feat: use unified attention module, add wrapper for state dict compatibility * feat: loading with dynamic fp8 optimization and LoRA support * feat: add anima minimal inference script (WIP) * format: format * feat: simplify target module selection by regular expression patterns * feat: kept caption dropout rate in cache and handle in training script * feat: update train_llm_adapter and verbose default values to string type * fix: use strategy instead of using tokenizers directly * feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock * feat: support 5d tensor in get_noisy_model_input_and_timesteps * feat: update loss calculation to support 5d tensor * fix: update argument names in anima_train_utils to align with other archtectures * feat: simplify Anima training script and update empty caption handling * feat: support LoRA format without `net.` prefix * fix: update to work fp8_scaled option * feat: add regex-based learning rates and dimensions handling in create_network * fix: improve regex matching for module selection and learning rates in LoRANetwork * fix: update logging message for regex match in LoRANetwork * fix: keep latents 4D except DiT call * feat: enhance block swap functionality for inference and training in Anima model * feat: refactor Anima training script * feat: optimize VAE processing by adjusting tensor dimensions and data types * fix: wait all block trasfer before siwtching offloader mode * feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude! 
* feat: support LORA for Qwen3 * feat: update Anima SAI model spec metadata handling * fix: remove unused code * feat: split CFG processing in do_sample function to reduce memory usage * feat: add VAE chunking and caching options to reduce memory usage * feat: optimize RMSNorm forward method and remove unused torch_attention_op * Update library/strategy_anima.py Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/safetensors_utils.py Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_minimal_inference.py Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_train.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/anima_train_utils.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: review with Copilot * feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet) * feat: add process_escape function to handle escape sequences in prompts * feat: enhance LoRA weight handling in model loading and add text encoder loading function * feat: improve ComfyUI conversion script with prefix constants and module name adjustments * feat: update caption dropout documentation to clarify cache regeneration requirement * feat: add clarification on learning rate adjustments * feat: add note on PyTorch version requirement to prevent NaN loss --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
346 lines
13 KiB
Python
346 lines
13 KiB
Python
from concurrent.futures import ThreadPoolExecutor
|
|
import gc
|
|
import time
|
|
from typing import Any, Optional, Union, Callable, Tuple
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
# Keep these functions here for portability, and private to avoid confusion with the ones in device_utils.py
|
|
def _clean_memory_on_device(device: torch.device):
    r"""
    Free memory for ``device``; intended to be called from training scripts.

    Always runs ``gc.collect()``. Then, if ``device.type`` is one of ``cuda`` / ``xpu`` / ``mps``,
    calls that backend's ``empty_cache()``. Only the device *type* is inspected, so ``cuda`` and
    ``cuda:0`` behave the same. Any other device type gets only the garbage-collection pass.
    """
    gc.collect()

    backend = device.type
    if backend in ("cuda", "xpu", "mps"):
        # torch.cuda / torch.xpu / torch.mps each expose empty_cache()
        getattr(torch, backend).empty_cache()
|
|
|
|
|
|
def _synchronize_device(device: torch.device):
    """
    Call the backend-level ``synchronize()`` that matches ``device.type``.

    Handles ``cuda``, ``xpu`` and ``mps``; other device types (e.g. ``cpu``) are a no-op. No device
    index is passed to the backend call, so the backend's current device is the one synchronized.
    """
    backend = device.type
    if backend not in ("cuda", "xpu", "mps"):
        return
    getattr(torch, backend).synchronize()
|
|
|
|
|
|
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """
    Swap the ``.weight`` tensors of two layers of the same class, using a side CUDA stream.

    For every submodule present in both layers under the same qualified name and with equal weight
    shapes:
    - ``layer_to_cpu``'s weight (expected to be on the CUDA device) is copied to a new CPU tensor;
    - ``layer_to_cuda``'s weight (expected to be on the CPU) is copied into the device tensor that
      ``layer_to_cpu`` just gave up, so the device tensors are reused instead of reallocated.

    A ``layer_to_cuda`` submodule with no name/shape match is not swapped. If its weight's device
    type differs from ``device.type``, the weight is moved to ``device`` directly.

    Only ``.weight`` is handled here; biases and buffers are not touched.

    Args:
        device: Device that unmatched weights of ``layer_to_cuda`` are moved to.
        layer_to_cpu: Layer whose weights leave the device.
        layer_to_cuda: Layer whose weights are brought onto the device.

    NOTE(review): the side stream (``torch.Stream(device="cuda")``) and ``torch.cuda.current_stream()``
    are not derived from ``device``. They refer to the calling thread's current CUDA device. Confirm
    the behavior when ``device`` is not that device (e.g. multi-GPU, or when called from a worker
    thread).
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # Each job: (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)
    weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []

    # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
    # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
    #     print(module_to_cpu.__class__, module_to_cuda.__class__)
    #     if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
    #         weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    # Pair submodules by qualified name instead of iteration order (see the note above).
    modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
    for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
        if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
            module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
            if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
                weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
            else:
                # No counterpart (or a shape mismatch): cannot reuse a buffer, so move this weight directly.
                if module_to_cuda.weight.data.device.type != device.type:
                    # print(
                    #     f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
                    # )
                    module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)

    # Wait for all work queued on the current stream before touching the weights from the side stream.
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.Stream(device="cuda")
    with torch.cuda.stream(stream):
        # cuda to cpu
        # (in each job the third element is the outgoing device tensor; the fourth is unused here)
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            # Mark the device tensor as used on `stream` so the caching allocator does not recycle it early.
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        # The non-blocking device->host copies must finish before the device tensors are overwritten below.
        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            # Reuse the outgoing layer's device tensor as storage for the incoming layer's weight.
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
|
|
|
|
|
|
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """
    Swap the ``.weight`` tensors of two layers of the same class on a non-CUDA ``device``.

    For every submodule present in both layers under the same qualified name and with equal weight
    shapes:
    - ``layer_to_cpu``'s weight is copied to a new CPU tensor;
    - ``layer_to_cuda``'s weight is copied into the tensor ``layer_to_cpu`` just gave up, reusing it.

    A ``layer_to_cuda`` submodule with no name/shape match is not swapped. If its weight's device
    type differs from ``device.type``, the weight is moved to ``device`` directly. This mirrors
    ``swap_weight_devices_cuda``. Pairing submodules by iteration order can mis-pair them for some
    architectures, and a shape mismatch must never reach ``copy_`` (it would raise or silently
    broadcast).

    Only ``.weight`` is handled; biases and buffers are not touched.

    NOTE: this path is not tested.

    Args:
        device: The accelerator device; also the target for unmatched weights.
        layer_to_cpu: Layer whose weights leave the device.
        layer_to_cuda: Layer whose weights are brought onto the device.
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    # Each job: (module_to_cpu, module_to_cuda, outgoing device data, incoming host data)
    weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []

    modules_to_cpu = dict(layer_to_cpu.named_modules())
    for name, module_to_cuda in layer_to_cuda.named_modules():
        incoming_weight = getattr(module_to_cuda, "weight", None)
        if incoming_weight is None:
            continue

        module_to_cpu = modules_to_cpu.get(name)
        outgoing_weight = getattr(module_to_cpu, "weight", None)  # None when module_to_cpu is None
        if outgoing_weight is not None and outgoing_weight.shape == incoming_weight.shape:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, outgoing_weight.data, incoming_weight.data))
        elif incoming_weight.data.device.type != device.type:
            # No counterpart (or a shape mismatch): cannot reuse a buffer, so move this weight directly.
            module_to_cuda.weight.data = incoming_weight.data.to(device)

    # device -> cpu: give the outgoing layer host copies; keep its device tensors for reuse below
    for module_to_cpu, _, device_buffer, _ in weight_swap_jobs:
        module_to_cpu.weight.data = device_buffer.data.to("cpu", non_blocking=True)

    # The copies above must complete before the device buffers are overwritten.
    _synchronize_device(device)

    # cpu -> device: copy the incoming layer's weights into the reused device buffers
    for _, module_to_cuda, device_buffer, _ in weight_swap_jobs:
        device_buffer.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = device_buffer

    _synchronize_device(device)
|
|
|
|
|
|
def weighs_to_device(layer: nn.Module, device: torch.device):
    """
    Move the ``.weight`` tensor of ``layer`` and of each of its submodules to ``device``.

    Modules with no ``weight`` attribute, or with ``weight`` set to None, are skipped. Biases and
    buffers are left where they are. The copy is requested with ``non_blocking=True``, so it may be
    asynchronous; synchronize before relying on the data if that matters.
    """
    modules_with_weight = (m for m in layer.modules() if getattr(m, "weight", None) is not None)
    for m in modules_with_weight:
        m.weight.data = m.weight.data.to(device, non_blocking=True)
|
|
|
|
|
|
class Offloader:
    """
    Common base for block-wise weight offloading.

    Tracks ``num_blocks`` blocks, of which ``blocks_to_swap`` take part in swapping. Each swap runs on a
    single-worker ``ThreadPoolExecutor``. Pending swaps are kept in ``self.futures``, keyed by the index
    of the block being moved onto ``device``.
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        # With a single worker, swaps execute one at a time in submission order.
        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        """Swap block weights, using the CUDA-stream implementation on CUDA devices and the generic one otherwise."""
        swap = swap_weight_devices_cuda if self.cuda_available else swap_weight_devices_no_cuda
        swap(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        """
        Queue a background swap: ``blocks[block_idx_to_cpu]`` leaves the device, ``blocks[block_idx_to_cuda]`` arrives.

        The future is stored under ``block_idx_to_cuda`` and resolves to
        ``(block_idx_to_cpu, block_idx_to_cuda)``.
        """

        def _run_swap(idx_out, block_out, idx_in, block_in):
            if self.debug:
                t_start = time.perf_counter()
                target = "CUDA" if self.cuda_available else "device"
                print(f"Move block {idx_out} to CPU and block {idx_in} to {target}")

            self.swap_weight_devices(block_out, block_in)

            if self.debug:
                print(f"Moved blocks {idx_out} and {idx_in} in {time.perf_counter() - t_start:.2f}s")
            return idx_out, idx_in

        outgoing = blocks[block_idx_to_cpu]
        incoming = blocks[block_idx_to_cuda]
        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            _run_swap, block_idx_to_cpu, outgoing, block_idx_to_cuda, incoming
        )

    def _wait_blocks_move(self, block_idx):
        """
        Block until the swap that brings ``block_idx`` onto the device is done; no-op if none is pending.

        An exception raised inside the worker is re-raised here by ``Future.result()``.
        """
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            t_start = time.perf_counter()

        _, loaded_idx = self.futures.pop(block_idx).result()
        assert block_idx == loaded_idx, f"Block index mismatch: {block_idx} != {loaded_idx}"

        if self.debug:
            print(f"Waited for block {block_idx}: {time.perf_counter() - t_start:.2f}s")
|
|
|
|
|
|
# Type alias for the grad_input / grad_output arguments in backward hook signatures:
# either a single tensor or a tuple of tensors.
_grad_t = Union[torch.Tensor, tuple[torch.Tensor, ...]]
|
|
|
|
|
|
class ModelOffloader(Offloader):
    """
    Block-swap offloader for a sequence of blocks; supports forward offloading and, optionally, backward.

    After ``prepare_block_devices_before_forward``, the first ``num_blocks - blocks_to_swap`` blocks are on
    ``device``. The remaining blocks are on ``device`` too, except for their ``.weight`` tensors, which are
    on the CPU. During forward, swaps are requested by the caller through ``wait_for_block`` and
    ``submit_move_blocks``; nothing in this class calls them. When ``supports_backward`` is True, full
    backward hooks registered on the blocks drive the swaps during backward.

    A ``blocks_to_swap`` of ``None`` or ``0`` disables swapping.
    """

    def __init__(
        self,
        blocks: Union[list[nn.Module], nn.ModuleList],
        blocks_to_swap: int,
        device: torch.device,
        supports_backward: bool = True,
        debug: bool = False,
    ):
        super().__init__(len(blocks), blocks_to_swap, device, debug)

        self.supports_backward = supports_backward
        self.forward_only = not supports_backward  # forward only offloading: can be changed to True for inference

        if self.supports_backward:
            # register backward hooks
            self.remove_handles = []
            for i, block in enumerate(blocks):
                hook = self.create_backward_hook(blocks, i)
                if hook is not None:
                    handle = block.register_full_backward_hook(hook)
                    self.remove_handles.append(handle)

    def _swap_disabled(self) -> bool:
        """Return True when swapping is turned off (``blocks_to_swap`` is None or 0)."""
        return self.blocks_to_swap is None or self.blocks_to_swap == 0

    def _wait_all_blocks_move(self):
        """Wait for every pending transfer."""
        # Iterate over a snapshot: _wait_blocks_move pops entries from self.futures.
        for block_idx in list(self.futures.keys()):
            self._wait_blocks_move(block_idx)

    def set_forward_only(self, forward_only: bool):
        """Switch between forward-only and forward+backward swapping once all pending transfers have finished."""
        # switching must wait for all pending transfers
        self._wait_all_blocks_move()
        self.forward_only = forward_only

    def __del__(self):
        # getattr guards: __init__ may have raised before these attributes were assigned.
        if getattr(self, "supports_backward", False):
            for handle in getattr(self, "remove_handles", []):
                handle.remove()

    def create_backward_hook(
        self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
    ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
        """
        Build the full backward hook for ``blocks[block_index]``, or return None if the block needs none.

        Let ``num_blocks_propagated = num_blocks - block_index - 1`` (the number of blocks after this one).
        The hook does the following:
        - if ``0 < num_blocks_propagated <= blocks_to_swap``, it queues a swap that moves block
          ``block_index + 1`` to the CPU and block ``blocks_to_swap - num_blocks_propagated`` to the device;
        - if ``0 < block_index <= blocks_to_swap``, it waits for the pending transfer of block
          ``block_index - 1`` onto the device.

        Returns None when swapping is disabled.
        """
        if self._swap_disabled():
            # blocks_to_swap may be None (the other entry points accept it); the comparisons below would
            # raise TypeError for None, and for 0 they already yield no hook.
            return None

        # -1 for 0-based index
        num_blocks_propagated = self.num_blocks - block_index - 1
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap

        if not swapping and not waiting:
            return None

        # create hook
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
            if self.debug:
                print(f"Backward hook for block {block_index}")

            if swapping:
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
            if waiting:
                self._wait_blocks_move(block_idx_to_wait)
            return None  # gradients are left unchanged

        return backward_hook

    def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
        """
        Reset block placement before a forward pass.

        Pending transfers are waited for first. The first ``num_blocks - blocks_to_swap`` blocks are moved
        entirely to ``device``. The remaining blocks are also moved to ``device`` (so buffers and other
        non-weight tensors land there), and then only their ``.weight`` tensors are moved back to the CPU.
        """
        if self._swap_disabled():
            return

        if self.debug:
            print("Prepare block devices before forward")

        # wait for all pending transfers
        self._wait_all_blocks_move()

        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
            b.to(self.device)  # move block to device first. this makes sure that buffers (non weights) are on the device
            weighs_to_device(b, torch.device("cpu"))  # make sure weights are on cpu

        _synchronize_device(self.device)
        _clean_memory_on_device(self.device)

    def wait_for_block(self, block_idx: int):
        """Wait until block ``block_idx`` has been moved onto the device (no-op if swapping is disabled or nothing is pending)."""
        if self._swap_disabled():
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
        """
        Queue a forward-pass swap for block ``block_idx``; presumably called once its computation is done.

        Block ``block_idx`` goes to the CPU, and block ``(num_blocks - blocks_to_swap + block_idx) % num_blocks``
        comes to the device. When backward support is active (``forward_only`` is False), only blocks with
        index below ``blocks_to_swap`` are swapped, because later blocks must remain on the device for
        backward.
        """
        # check if blocks_to_swap is enabled
        if self._swap_disabled():
            return

        # if backward is enabled, we do not swap blocks in forward pass more than blocks_to_swap, because it should be on GPU
        if not self.forward_only and block_idx >= self.blocks_to_swap:
            return

        block_idx_to_cpu = block_idx
        block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
        # this works for forward-only offloading. move upstream blocks to cuda
        block_idx_to_cuda = block_idx_to_cuda % self.num_blocks
        self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
|
|
|
|
|
# endregion
|
|
|
|
# region cpu offload utils
|
|
|
|
|
|
def to_device(x: Any, device: torch.device) -> Any:
    """
    Recursively move tensors, including those nested in lists, tuples and dicts, to ``device``.

    Args:
        x: A tensor, or an arbitrarily nested list / tuple / dict that may contain tensors.
        device: Destination device passed to ``Tensor.to``.

    Returns:
        A structure mirroring ``x``, with every tensor replaced by ``tensor.to(device)``. Containers are
        rebuilt as plain ``list`` / ``tuple`` / ``dict``; any other object is returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {key: to_device(value, device) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        moved = [to_device(item, device) for item in x]
        return moved if isinstance(x, list) else tuple(moved)
    return x
|
|
|
|
|
|
def to_cpu(x: Any) -> Any:
    """
    Recursively move tensors, including those nested in lists, tuples and dicts, to the CPU.

    Args:
        x: A tensor, or an arbitrarily nested list / tuple / dict that may contain tensors.

    Returns:
        A structure mirroring ``x``, with every tensor replaced by ``tensor.cpu()``. Containers are rebuilt
        as plain ``list`` / ``tuple`` / ``dict``; any other object is returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    if isinstance(x, dict):
        return {key: to_cpu(value) for key, value in x.items()}
    if isinstance(x, list):
        return [to_cpu(item) for item in x]
    if isinstance(x, tuple):
        return tuple(to_cpu(item) for item in x)
    return x
|
|
|
|
|
|
def create_cpu_offloading_wrapper(func: Callable, device: torch.device) -> Callable:
    """
    Wrap ``func`` so that it computes on ``device`` while its results are returned on the CPU.

    The returned callable recursively moves every tensor in its positional and keyword arguments to
    ``device`` (via ``to_device``), calls ``func``, and recursively moves every tensor in the result to
    the CPU (via ``to_cpu``).

    Args:
        func: The callable to wrap.
        device: The device the arguments are moved to before ``func`` runs.

    Returns:
        A callable that accepts the same arguments as ``func``. Keyword arguments are forwarded too,
        after being moved to ``device``.
    """

    def custom_forward(*inputs, **kwargs):
        device_inputs = to_device(inputs, device)  # tuple in -> tuple out
        device_kwargs = to_device(kwargs, device)  # dict in -> dict out (empty when no kwargs)
        outputs = func(*device_inputs, **device_kwargs)
        return to_cpu(outputs)

    return custom_forward
|
|
|
|
|
|
# endregion
|