mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Merge pull request #2224 from rockerBOO/block-swap-fp8-scaled
Block swap and fp8 scaled, fp8 already quantized warning
This commit is contained in:
@@ -190,7 +190,7 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
|
||||
* `--fp8_vl`
|
||||
- Use FP8 for the VLM (Qwen2.5-VL) text encoder.
|
||||
* `--text_encoder_cpu`
|
||||
- Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts.
|
||||
- Runs the text encoders on CPU to reduce VRAM usage. This is useful when VRAM is insufficient (less than 12GB). Encoding one text may take a few minutes (depending on CPU). It is highly recommended to use this option with `--cache_text_encoder_outputs_to_disk` to avoid repeated encoding every time training starts. **In addition, increasing `--num_cpu_threads_per_process` in the `accelerate launch` command, like `--num_cpu_threads_per_process=8` or `16`, can speed up encoding in some environments.**
|
||||
* `--blocks_to_swap=<integer>` **[Experimental Feature]**
|
||||
- Setting to reduce VRAM usage by swapping parts of the model (Transformer blocks) between CPU and GPU. Specify the number of blocks to swap as an integer (e.g., `18`). Larger values reduce VRAM usage but decrease training speed. Adjust according to your GPU's VRAM capacity. Can be used with `gradient_checkpointing`.
|
||||
* `--cache_text_encoder_outputs`
|
||||
|
||||
@@ -101,6 +101,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
|
||||
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
|
||||
loading_dtype = None if args.fp8_base or args.fp8_scaled else weight_dtype
|
||||
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
|
||||
@@ -125,8 +127,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# if args.split_mode:
|
||||
# model = self.prepare_split_model(model, weight_dtype, accelerator)
|
||||
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if self.is_swapping_blocks:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
|
||||
@@ -249,7 +249,15 @@ def sample_image_inference(
|
||||
arg_c_null = None
|
||||
|
||||
gen_args = SimpleNamespace(
|
||||
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
|
||||
image_size=(height, width),
|
||||
infer_steps=sample_steps,
|
||||
flow_shift=flow_shift,
|
||||
guidance_scale=cfg_scale,
|
||||
fp8=args.fp8_scaled,
|
||||
apg_start_step_ocr=38,
|
||||
apg_start_step_general=5,
|
||||
guidance_rescale=0.0,
|
||||
guidance_rescale_apg=0.0,
|
||||
)
|
||||
|
||||
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import
|
||||
|
||||
@@ -306,11 +306,22 @@ def load_safetensors_with_fp8_optimization(
|
||||
state_dict[key] = value
|
||||
continue
|
||||
|
||||
original_dtype = value.dtype
|
||||
|
||||
if original_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
|
||||
logger.warning(
|
||||
f"Skipping FP8 quantization for key {key} as it is already in FP8 format ({original_dtype}). "
|
||||
"Loading checkpoint as-is without re-quantization."
|
||||
)
|
||||
target_device = calc_device if (calc_device is not None and move_to_device) else original_device
|
||||
value = value.to(target_device)
|
||||
state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Move to calculation device
|
||||
if calc_device is not None:
|
||||
value = value.to(calc_device)
|
||||
|
||||
original_dtype = value.dtype
|
||||
quantized_weight, scale_tensor = quantize_weight(
|
||||
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
|
||||
)
|
||||
|
||||
@@ -327,14 +327,17 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
|
||||
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
if support_text_encoder_caching:
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs / text encoderの出力をキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_mmap_load_safetensors",
|
||||
action="store_true",
|
||||
@@ -342,7 +345,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
|
||||
)
|
||||
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, support_text_encoder_caching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
|
||||
if args.clip_skip is not None:
|
||||
@@ -365,7 +368,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
||||
# not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
||||
|
||||
if supportTextEncoderCaching:
|
||||
if support_text_encoder_caching:
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
args.cache_text_encoder_outputs = True
|
||||
logger.warning(
|
||||
|
||||
@@ -20,7 +20,8 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
||||
# super().assert_extra_args(args, train_dataset_group) # do not call parent because it checks reso steps with 64
|
||||
sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
if val_dataset_group is not None:
|
||||
|
||||
@@ -57,7 +57,7 @@ def convert(args):
|
||||
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
|
||||
|
||||
# make reverse map from diffusers map
|
||||
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
|
||||
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map(19, 38)
|
||||
|
||||
# iterate over three safetensors files to reduce memory usage
|
||||
flux_sd = {}
|
||||
|
||||
Reference in New Issue
Block a user