mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
feat: implement block swapping for FLUX.1 LoRA (WIP)
@@ -519,7 +519,7 @@ def train(args):
                         num_parameters_per_group[opt_idx] += 1

     # add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
-    if is_swapping_blocks:
+    if False:  # is_swapping_blocks:
         import library.custom_offloading_utils as custom_offloading_utils

         num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
@@ -25,6 +25,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         super().__init__()
         self.sample_prompts_te_outputs = None
         self.is_schnell: Optional[bool] = None
+        self.is_swapping_blocks: bool = False

     def assert_extra_args(self, args, train_dataset_group):
         super().assert_extra_args(args, train_dataset_group)
@@ -78,6 +79,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if args.split_mode:
             model = self.prepare_split_model(model, weight_dtype, accelerator)

+        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+        if self.is_swapping_blocks:
+            # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
+            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
+            model.enable_block_swap(args.blocks_to_swap, accelerator.device)
+
         clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         clip_l.eval()
@@ -285,6 +292,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

         if not args.split_mode:
+            if self.is_swapping_blocks:
+                accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
             flux_train_utils.sample_images(
                 accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
             )
@@ -539,6 +548,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 text_encoder.to(te_weight_dtype)  # fp8
                 prepare_fp8(text_encoder, weight_dtype)

+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        if not self.is_swapping_blocks:
+            return super().prepare_unet_with_accelerator(args, accelerator, unet)
+
+        # if we don't swap blocks, accelerator moves the model to the device as usual
+        flux: flux_models.Flux = unet
+        flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
+        accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+
+        return flux
+

 def setup_parser() -> argparse.ArgumentParser:
     parser = train_network.setup_parser()
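The interesting part of the prepare_unet_with_accelerator override above is the device_placement argument: Accelerator.prepare accepts one bool per passed object, and False tells Accelerate to wrap the model without moving it to the GPU, leaving the trainer free to place everything except the swappable blocks itself. A minimal, self-contained sketch of that pattern (the small nn.Sequential is a stand-in for the Flux model and is not part of this commit):

import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])  # stand-in for the Flux transformer

# one bool per prepared object; False = wrap the model (DDP, mixed precision, ...) but do not move it to the device
model = accelerator.prepare(model, device_placement=[False])

# the caller then keeps only the blocks it wants resident on the GPU, e.g. all but the last two,
# which would be swapped in and out on demand by the offloader
for block in accelerator.unwrap_model(model)[:-2]:
    block.to(accelerator.device)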
@@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser:
         help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
         + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
     )
+
+    parser.add_argument(
+        "--blocks_to_swap",
+        type=int,
+        default=None,
+        help="[EXPERIMENTAL] "
+        "Sets the number of blocks to swap during the forward and backward passes."
+        "Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
+        " / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
+        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
+    )
     return parser
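For reference, a short runnable sketch of how the new flag feeds the gating added earlier in this commit (the parse_args argument list is illustrative only):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--blocks_to_swap", type=int, default=None)
args = parser.parse_args(["--blocks_to_swap", "6"])

# same condition as in FluxNetworkTrainer: swapping is enabled only when the flag is given and positive
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
print(is_swapping_blocks)  # True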
@@ -183,9 +183,47 @@ class ModelOffloader(Offloader):
     supports forward offloading
     """

-    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
+    def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
         super().__init__(num_blocks, blocks_to_swap, device, debug)

+        # register backward hooks
+        self.remove_handles = []
+        for i, block in enumerate(blocks):
+            hook = self.create_backward_hook(blocks, i)
+            if hook is not None:
+                handle = block.register_full_backward_hook(hook)
+                self.remove_handles.append(handle)
+
+    def __del__(self):
+        for handle in self.remove_handles:
+            handle.remove()
+
+    def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
+        # -1 for 0-based index
+        num_blocks_propagated = self.num_blocks - block_index - 1
+        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
+        waiting = block_index > 0 and block_index <= self.blocks_to_swap
+
+        if not swapping and not waiting:
+            return None
+
+        # create hook
+        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
+        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
+        block_idx_to_wait = block_index - 1
+
+        def backward_hook(module, grad_input, grad_output):
+            if self.debug:
+                print(f"Backward hook for block {block_index}")
+
+            if swapping:
+                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
+            if waiting:
+                self._wait_blocks_move(block_idx_to_wait)
+            return None
+
+        return backward_hook
+
     def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
         if self.blocks_to_swap is None or self.blocks_to_swap == 0:
             return
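To make the hook scheduling above concrete, here is a small dry-run of the same index arithmetic with the _submit_move_blocks / _wait_blocks_move calls replaced by prints; the num_blocks and blocks_to_swap values are made up for illustration:

num_blocks, blocks_to_swap = 8, 3  # illustrative values

# backward runs from the last block to the first; a hook on a block either submits an async swap
# (a finished block goes to CPU, an early block is prefetched back to the GPU) or waits for a
# previously submitted swap so the next block to run is already on the device
for block_index in range(num_blocks):
    num_blocks_propagated = num_blocks - block_index - 1
    swapping = 0 < num_blocks_propagated <= blocks_to_swap
    waiting = 0 < block_index <= blocks_to_swap
    if not (swapping or waiting):
        continue  # no hook is registered for this block
    actions = []
    if swapping:
        block_idx_to_cpu = num_blocks - num_blocks_propagated
        block_idx_to_cuda = blocks_to_swap - num_blocks_propagated
        actions.append(f"submit: block {block_idx_to_cpu} -> CPU, block {block_idx_to_cuda} -> GPU")
    if waiting:
        actions.append(f"wait for block {block_index - 1} to arrive")
    print(f"backward of block {block_index}: " + "; ".join(actions))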
@@ -970,8 +970,12 @@ class Flux(nn.Module):
         double_blocks_to_swap = num_blocks // 2
         single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2

-        self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
-        self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
+        self.offloader_double = custom_offloading_utils.ModelOffloader(
+            self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device  # , debug=True
+        )
+        self.offloader_single = custom_offloading_utils.ModelOffloader(
+            self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device  # , debug=True
+        )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
         )
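A quick worked example of the split above (the num_blocks value is illustrative): a FLUX single block is roughly half the size of a double block, which appears to be the rationale for doubling the remaining budget when it is applied to single blocks:

num_blocks = 6                                                      # --blocks_to_swap 6
double_blocks_to_swap = num_blocks // 2                             # 3 double blocks
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2    # 6 single blocks
print(double_blocks_to_swap, single_blocks_to_swap)                 # 3 6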
@@ -18,6 +18,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()

 from accelerate.utils import set_seed
+from accelerate import Accelerator
 from diffusers import DDPMScheduler
 from library import deepspeed_utils, model_util, strategy_base, strategy_sd
@@ -272,6 +273,11 @@ class NetworkTrainer:
     def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
         text_encoder.text_model.embeddings.to(dtype=weight_dtype)

+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        return accelerator.prepare(unet)
+
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
         pass
@@ -627,7 +633,8 @@ class NetworkTrainer:
                 training_model = ds_model
             else:
                 if train_unet:
-                    unet = accelerator.prepare(unet)
+                    # default implementation is: unet = accelerator.prepare(unet)
+                    unet = self.prepare_unet_with_accelerator(args, accelerator, unet)  # accelerator does some magic here
                 else:
                     unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
                 if train_text_encoder: