diff --git a/docs/hunyuan_image_train_network.md b/docs/hunyuan_image_train_network.md index b2bf113d..b0e9cdd9 100644 --- a/docs/hunyuan_image_train_network.md +++ b/docs/hunyuan_image_train_network.md @@ -190,7 +190,7 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like * `--fp8_vl` - Use FP8 for the VLM (Qwen2.5-VL) text encoder. * `--text_encoder_cpu` - - Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts. + - Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts. **In addition, increasing `--num_cpu_threads_per_process` in the `accelerate launch` command, like `--num_cpu_threads_per_process=8` or `16`, can speed up encoding in some environments.** * `--blocks_to_swap=` **[Experimental Feature]** - Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`. * `--cache_text_encoder_outputs` diff --git a/flux_train_network.py b/flux_train_network.py index db61f15d..b0892671 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -101,6 +101,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def load_target_model(self, args, weight_dtype, accelerator): # currently offload to cpu for some models + self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 + # if the file is fp8 and we are using fp8_base, we can load it as is (fp8) loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype loading_device = "cpu" if self.is_swapping_blocks else accelerator.device @@ -125,8 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # if args.split_mode: # model = self.prepare_split_model(model, weight_dtype, accelerator) - - self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if self.is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") diff --git a/hunyuan_image_train_network.py b/hunyuan_image_train_network.py index a67e931d..9ab351ea 100644 --- a/hunyuan_image_train_network.py +++ b/hunyuan_image_train_network.py @@ -249,7 +249,15 @@ def sample_image_inference( arg_c_null = None gen_args = SimpleNamespace( - image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled + image_size=(height, width), + infer_steps=sample_steps, + flow_shift=flow_shift, + guidance_scale=cfg_scale, + fp8=args.fp8_scaled, + apg_start_step_ocr=38, + apg_start_step_general=5, + guidance_rescale=0.0, + guidance_rescale_apg=0.0, ) from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py index 02f99ab6..9ea62a58 100644 --- a/library/fp8_optimization_utils.py +++ b/library/fp8_optimization_utils.py @@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization( state_dict[key] = value continue + original_dtype = value.dtype + + if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz): + logger.warning( + f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). " + "Loading checkpoint as-is without re-quantization." + ) + target_device = calc_device if (calc_device is not None and move_to_device) else original_device + value = value.to(target_device) + state_dict[key] = value + continue + # Move to calculation device if calc_device is not None: value = value.to(calc_device) - original_dtype = value.dtype quantized_weight, scale_tensor = quantize_weight( key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size ) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 7c5e6860..e559e718 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -327,14 +327,17 @@ def save_sd_model_on_epoch_end_or_stepwise( def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True): - parser.add_argument( - "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" - ) - parser.add_argument( - "--cache_text_encoder_outputs_to_disk", - action="store_true", - help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", - ) + if support_text_encoder_caching: + parser.add_argument( + "--cache_text_encoder_outputs", + action="store_true", + help="cache text encoder outputs / text encoderの出力をキャッシュする", + ) + parser.add_argument( + "--cache_text_encoder_outputs_to_disk", + action="store_true", + help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", + ) parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", @@ -342,7 +345,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en ) -def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): +def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True): assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません" if args.clip_skip is not None: @@ -365,7 +368,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin # not hasattr(args, "weighted_captions") or not args.weighted_captions # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません" - if supportTextEncoderCaching: + if support_text_encoder_caching: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: args.cache_text_encoder_outputs = True logger.warning( diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index be538cdd..6dec31de 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -20,7 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine self.is_sdxl = True def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): - sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False) + # super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64 + sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) train_dataset_group.verify_bucket_reso_steps(32) if val_dataset_group is not None: diff --git a/tools/convert_diffusers_to_flux.py b/tools/convert_diffusers_to_flux.py index a11093c9..9dcd8fed 100644 --- a/tools/convert_diffusers_to_flux.py +++ b/tools/convert_diffusers_to_flux.py @@ -57,7 +57,7 @@ def convert(args): save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None # make reverse map from diffusers map - diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map() + diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38) # iterate over three safetensors files to reduce memory usage flux_sd = {}