Merge pull request #2224 from rockerBOO/block-swap-fp8-scaled

Block swap and fp8 scaled, fp8 already quantized warning
This commit is contained in:
Kohya S.
2026-01-18 14:35:32 +09:00
committed by GitHub
7 changed files with 40 additions and 17 deletions

View File

@@ -190,7 +190,7 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
* `--fp8_vl`
- Use FP8 for the VLM (Qwen2.5-VL) text encoder.
* `--text_encoder_cpu`
- Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts.
- Runs the text encoders on the CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding a single text may take a few minutes (depending on the CPU). It is highly recommended to combine this option with `--cache_text_encoder_outputs_to_disk` so that texts are not re-encoded every time training starts. **In addition, increasing `--num_cpu_threads_per_process` in the `accelerate launch` command (e.g., `--num_cpu_threads_per_process=8` or `16`) can speed up encoding in some environments.**
* `--blocks_to_swap=<integer>` **[Experimental Feature]**
- Reduces VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage further but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used together with `gradient_checkpointing`.
* `--cache_text_encoder_outputs`

View File

@@ -101,6 +101,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
@@ -125,8 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")

View File

@@ -249,7 +249,15 @@ def sample_image_inference(
arg_c_null = None
gen_args = SimpleNamespace(
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
image_size=(height, width),
infer_steps=sample_steps,
flow_shift=flow_shift,
guidance_scale=cfg_scale,
fp8=args.fp8_scaled,
apg_start_step_ocr=38,
apg_start_step_general=5,
guidance_rescale=0.0,
guidance_rescale_apg=0.0,
)
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import

View File

@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
state_dict[key] = value
continue
original_dtype = value.dtype
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
logger.warning(
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
"Loading checkpoint as-is without re-quantization."
)
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
value = value.to(target_device)
state_dict[key] = value
continue
# Move to calculation device
if calc_device is not None:
value = value.to(calc_device)
original_dtype = value.dtype
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)

View File

@@ -327,14 +327,17 @@ def save_sd_model_on_epoch_end_or_stepwise(
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
if support_text_encoder_caching:
parser.add_argument(
"--cache_text_encoder_outputs",
action="store_true",
help="cache text encoder outputs / text encoderの出力をキャッシュする",
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
@@ -342,7 +345,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.clip_skip is not None:
@@ -365,7 +368,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# not hasattr(args, "weighted_captions") or not args.weighted_captions
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if support_text_encoder_caching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
logger.warning(

View File

@@ -20,7 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
# super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
train_dataset_group.verify_bucket_reso_steps(32)
if val_dataset_group is not None:

View File

@@ -57,7 +57,7 @@ def convert(args):
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
# make reverse map from diffusers map
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38)
# iterate over three safetensors files to reduce memory usage
flux_sd = {}