feat: implement block swapping for FLUX.1 LoRA (WIP)

This commit is contained in:
Kohya S
2024-11-12 08:49:05 +09:00
parent 7feaae5f06
commit cde90b8903
5 changed files with 87 additions and 5 deletions

View File

@@ -519,7 +519,7 @@ def train(args):
num_parameters_per_group[opt_idx] += 1
# add hooks for block swapping: this hook is called after fused_backward_pass hook or blockwise_fused_optimizers hook
if is_swapping_blocks:
if False: # is_swapping_blocks:
import library.custom_offloading_utils as custom_offloading_utils
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)

View File

@@ -25,6 +25,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -78,6 +79,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if args.split_mode:
model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
@@ -285,6 +292,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
if not args.split_mode:
if self.is_swapping_blocks:
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
@@ -539,6 +548,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we doesn't swap blocks, we can move the model to device
flux: flux_models.Flux = unet
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
return flux
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
@@ -550,6 +572,17 @@ def setup_parser() -> argparse.ArgumentParser:
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
)
parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロックの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
return parser

View File

@@ -183,9 +183,47 @@ class ModelOffloader(Offloader):
supports forward offloading
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
def __del__(self):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap
if not swapping and not waiting:
return None
# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
if self.debug:
print(f"Backward hook for block {block_index}")
if swapping:
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
if waiting:
self._wait_blocks_move(block_idx_to_wait)
return None
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return

View File

@@ -970,8 +970,12 @@ class Flux(nn.Module):
double_blocks_to_swap = num_blocks // 2
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
self.offloader_double = custom_offloading_utils.ModelOffloader(self.num_double_blocks, double_blocks_to_swap, device)
self.offloader_single = custom_offloading_utils.ModelOffloader(self.num_single_blocks, single_blocks_to_swap, device)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device #, debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device #, debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
)

View File

@@ -18,6 +18,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from accelerate import Accelerator
from diffusers import DDPMScheduler
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
@@ -272,6 +273,11 @@ class NetworkTrainer:
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
return accelerator.prepare(unet)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass
@@ -627,7 +633,8 @@ class NetworkTrainer:
training_model = ds_model
else:
if train_unet:
unet = accelerator.prepare(unet)
# default implementation is: unet = accelerator.prepare(unet)
unet = self.prepare_unet_with_accelerator(args, accelerator, unet) # accelerator does some magic here
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder: