mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-07 14:08:49 +00:00
523 lines
20 KiB
Python
523 lines
20 KiB
Python
import logging
|
||
import sys
|
||
import threading
|
||
from typing import *
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
from torchvision import transforms
|
||
from diffusers import EulerAncestralDiscreteScheduler
|
||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
||
import cv2
|
||
from PIL import Image
|
||
import numpy as np
|
||
|
||
|
||
def fire_in_thread(f, *args, **kwargs):
|
||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||
|
||
|
||
# region Logging
|
||
|
||
|
||
def add_logging_arguments(parser):
|
||
parser.add_argument(
|
||
"--console_log_level",
|
||
type=str,
|
||
default=None,
|
||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
|
||
)
|
||
parser.add_argument(
|
||
"--console_log_file",
|
||
type=str,
|
||
default=None,
|
||
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
|
||
)
|
||
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
|
||
|
||
|
||
def setup_logging(args=None, log_level=None, reset=False):
|
||
if logging.root.handlers:
|
||
if reset:
|
||
# remove all handlers
|
||
for handler in logging.root.handlers[:]:
|
||
logging.root.removeHandler(handler)
|
||
else:
|
||
return
|
||
|
||
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
|
||
if log_level is None and args is not None:
|
||
log_level = args.console_log_level
|
||
if log_level is None:
|
||
log_level = "INFO"
|
||
log_level = getattr(logging, log_level)
|
||
|
||
msg_init = None
|
||
if args is not None and args.console_log_file:
|
||
handler = logging.FileHandler(args.console_log_file, mode="w")
|
||
else:
|
||
handler = None
|
||
if not args or not args.console_log_simple:
|
||
try:
|
||
from rich.logging import RichHandler
|
||
from rich.console import Console
|
||
from rich.logging import RichHandler
|
||
|
||
handler = RichHandler(console=Console(stderr=True))
|
||
except ImportError:
|
||
# print("rich is not installed, using basic logging")
|
||
msg_init = "rich is not installed, using basic logging"
|
||
|
||
if handler is None:
|
||
handler = logging.StreamHandler(sys.stdout) # same as print
|
||
handler.propagate = False
|
||
|
||
formatter = logging.Formatter(
|
||
fmt="%(message)s",
|
||
datefmt="%Y-%m-%d %H:%M:%S",
|
||
)
|
||
handler.setFormatter(formatter)
|
||
logging.root.setLevel(log_level)
|
||
logging.root.addHandler(handler)
|
||
|
||
if msg_init is not None:
|
||
logger = logging.getLogger(__name__)
|
||
logger.info(msg_init)
|
||
|
||
|
||
setup_logging()
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# endregion
|
||
|
||
# region PyTorch utils
|
||
|
||
|
||
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||
|
||
weight_swap_jobs = []
|
||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||
|
||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||
|
||
stream = torch.cuda.Stream()
|
||
with torch.cuda.stream(stream):
|
||
# cuda to cpu
|
||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||
cuda_data_view.record_stream(stream)
|
||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||
|
||
stream.synchronize()
|
||
|
||
# cpu to cuda
|
||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||
module_to_cuda.weight.data = cuda_data_view
|
||
|
||
stream.synchronize()
|
||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||
|
||
|
||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||
for module in layer.modules():
|
||
if hasattr(module, "weight") and module.weight is not None:
|
||
module.weight.data = module.weight.data.to(device, non_blocking=True)
|
||
|
||
|
||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||
"""
|
||
Convert a string to a torch.dtype
|
||
|
||
Args:
|
||
s: string representation of the dtype
|
||
default_dtype: default dtype to return if s is None
|
||
|
||
Returns:
|
||
torch.dtype: the corresponding torch.dtype
|
||
|
||
Raises:
|
||
ValueError: if the dtype is not supported
|
||
|
||
Examples:
|
||
>>> str_to_dtype("float32")
|
||
torch.float32
|
||
>>> str_to_dtype("fp32")
|
||
torch.float32
|
||
>>> str_to_dtype("float16")
|
||
torch.float16
|
||
>>> str_to_dtype("fp16")
|
||
torch.float16
|
||
>>> str_to_dtype("bfloat16")
|
||
torch.bfloat16
|
||
>>> str_to_dtype("bf16")
|
||
torch.bfloat16
|
||
>>> str_to_dtype("fp8")
|
||
torch.float8_e4m3fn
|
||
>>> str_to_dtype("fp8_e4m3fn")
|
||
torch.float8_e4m3fn
|
||
>>> str_to_dtype("fp8_e4m3fnuz")
|
||
torch.float8_e4m3fnuz
|
||
>>> str_to_dtype("fp8_e5m2")
|
||
torch.float8_e5m2
|
||
>>> str_to_dtype("fp8_e5m2fnuz")
|
||
torch.float8_e5m2fnuz
|
||
"""
|
||
if s is None:
|
||
return default_dtype
|
||
if s in ["bf16", "bfloat16"]:
|
||
return torch.bfloat16
|
||
elif s in ["fp16", "float16"]:
|
||
return torch.float16
|
||
elif s in ["fp32", "float32", "float"]:
|
||
return torch.float32
|
||
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
|
||
return torch.float8_e4m3fn
|
||
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
|
||
return torch.float8_e4m3fnuz
|
||
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
|
||
return torch.float8_e5m2
|
||
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
|
||
return torch.float8_e5m2fnuz
|
||
elif s in ["fp8", "float8"]:
|
||
return torch.float8_e4m3fn # default fp8
|
||
else:
|
||
raise ValueError(f"Unsupported dtype: {s}")
|
||
|
||
|
||
# endregion
|
||
|
||
# region Image utils
|
||
|
||
|
||
def pil_resize(image, size, interpolation):
|
||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||
|
||
if has_alpha:
|
||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
||
else:
|
||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||
|
||
resized_pil = pil_image.resize(size, resample=interpolation)
|
||
|
||
# Convert back to cv2 format
|
||
if has_alpha:
|
||
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
|
||
else:
|
||
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
|
||
|
||
return resized_cv2
|
||
|
||
|
||
def resize_image(
|
||
image: np.ndarray,
|
||
width: int,
|
||
height: int,
|
||
resized_width: int,
|
||
resized_height: int,
|
||
resize_interpolation: Optional[str] = None,
|
||
):
|
||
"""
|
||
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
|
||
|
||
Args:
|
||
image: numpy.ndarray
|
||
width: int Original image width
|
||
height: int Original image height
|
||
resized_width: int Resized image width
|
||
resized_height: int Resized image height
|
||
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
|
||
|
||
Returns:
|
||
image
|
||
"""
|
||
|
||
# Ensure all size parameters are actual integers
|
||
width = int(width)
|
||
height = int(height)
|
||
resized_width = int(resized_width)
|
||
resized_height = int(resized_height)
|
||
|
||
if resize_interpolation is None:
|
||
if width >= resized_width and height >= resized_height:
|
||
resize_interpolation = "area"
|
||
else:
|
||
resize_interpolation = "lanczos"
|
||
|
||
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
|
||
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]
|
||
|
||
resized_size = (resized_width, resized_height)
|
||
if use_pil:
|
||
interpolation = get_pil_interpolation(resize_interpolation)
|
||
image = pil_resize(image, resized_size, interpolation=interpolation)
|
||
logger.debug(f"resize image using {resize_interpolation} (PIL)")
|
||
else:
|
||
interpolation = get_cv2_interpolation(resize_interpolation)
|
||
image = cv2.resize(image, resized_size, interpolation=interpolation)
|
||
logger.debug(f"resize image using {resize_interpolation} (cv2)")
|
||
|
||
return image
|
||
|
||
|
||
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||
"""
|
||
Convert interpolation value to cv2 interpolation integer
|
||
|
||
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
|
||
"""
|
||
if interpolation is None:
|
||
return None
|
||
|
||
if interpolation == "lanczos" or interpolation == "lanczos4":
|
||
# Lanczos interpolation over 8x8 neighborhood
|
||
return cv2.INTER_LANCZOS4
|
||
elif interpolation == "nearest":
|
||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||
return cv2.INTER_NEAREST_EXACT
|
||
elif interpolation == "bilinear" or interpolation == "linear":
|
||
# bilinear interpolation
|
||
return cv2.INTER_LINEAR
|
||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||
# bicubic interpolation
|
||
return cv2.INTER_CUBIC
|
||
elif interpolation == "area":
|
||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||
return cv2.INTER_AREA
|
||
elif interpolation == "box":
|
||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||
return cv2.INTER_AREA
|
||
else:
|
||
return None
|
||
|
||
|
||
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||
"""
|
||
Convert interpolation value to PIL interpolation
|
||
|
||
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
|
||
"""
|
||
if interpolation is None:
|
||
return None
|
||
|
||
if interpolation == "lanczos":
|
||
return Image.Resampling.LANCZOS
|
||
elif interpolation == "nearest":
|
||
# Pick one nearest pixel from the input image. Ignore all other input pixels.
|
||
return Image.Resampling.NEAREST
|
||
elif interpolation == "bilinear" or interpolation == "linear":
|
||
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
|
||
return Image.Resampling.BILINEAR
|
||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
|
||
return Image.Resampling.BICUBIC
|
||
elif interpolation == "area":
|
||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||
# Area interpolation is related to cv2.INTER_AREA
|
||
# Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX.
|
||
return Image.Resampling.HAMMING
|
||
elif interpolation == "box":
|
||
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
|
||
return Image.Resampling.BOX
|
||
else:
|
||
return None
|
||
|
||
|
||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||
"""
|
||
Check if a interpolation function is supported
|
||
"""
|
||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||
|
||
|
||
# endregion
|
||
|
||
# TODO make inf_utils.py
|
||
# region Gradual Latent hires fix
|
||
|
||
|
||
class GradualLatent:
|
||
def __init__(
|
||
self,
|
||
ratio,
|
||
start_timesteps,
|
||
every_n_steps,
|
||
ratio_step,
|
||
s_noise=1.0,
|
||
gaussian_blur_ksize=None,
|
||
gaussian_blur_sigma=0.5,
|
||
gaussian_blur_strength=0.5,
|
||
unsharp_target_x=True,
|
||
):
|
||
self.ratio = ratio
|
||
self.start_timesteps = start_timesteps
|
||
self.every_n_steps = every_n_steps
|
||
self.ratio_step = ratio_step
|
||
self.s_noise = s_noise
|
||
self.gaussian_blur_ksize = gaussian_blur_ksize
|
||
self.gaussian_blur_sigma = gaussian_blur_sigma
|
||
self.gaussian_blur_strength = gaussian_blur_strength
|
||
self.unsharp_target_x = unsharp_target_x
|
||
|
||
def __str__(self) -> str:
|
||
return (
|
||
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
|
||
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
|
||
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
|
||
+ f"unsharp_target_x={self.unsharp_target_x})"
|
||
)
|
||
|
||
def apply_unshark_mask(self, x: torch.Tensor):
|
||
if self.gaussian_blur_ksize is None:
|
||
return x
|
||
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
|
||
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
|
||
mask = (x - blurred) * self.gaussian_blur_strength
|
||
sharpened = x + mask
|
||
return sharpened
|
||
|
||
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
|
||
org_dtype = x.dtype
|
||
if org_dtype == torch.bfloat16:
|
||
x = x.float()
|
||
|
||
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
|
||
|
||
# apply unsharp mask / アンシャープマスクを適用する
|
||
if unsharp and self.gaussian_blur_ksize:
|
||
x = self.apply_unshark_mask(x)
|
||
|
||
return x
|
||
|
||
|
||
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||
def __init__(self, *args, **kwargs):
|
||
super().__init__(*args, **kwargs)
|
||
self.resized_size = None
|
||
self.gradual_latent = None
|
||
|
||
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
|
||
self.resized_size = size
|
||
self.gradual_latent = gradual_latent
|
||
|
||
def step(
|
||
self,
|
||
model_output: torch.FloatTensor,
|
||
timestep: Union[float, torch.FloatTensor],
|
||
sample: torch.FloatTensor,
|
||
generator: Optional[torch.Generator] = None,
|
||
return_dict: bool = True,
|
||
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
||
"""
|
||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||
process from the learned model outputs (most often the predicted noise).
|
||
|
||
Args:
|
||
model_output (`torch.FloatTensor`):
|
||
The direct output from learned diffusion model.
|
||
timestep (`float`):
|
||
The current discrete timestep in the diffusion chain.
|
||
sample (`torch.FloatTensor`):
|
||
A current instance of a sample created by the diffusion process.
|
||
generator (`torch.Generator`, *optional*):
|
||
A random number generator.
|
||
return_dict (`bool`):
|
||
Whether or not to return a
|
||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
||
|
||
Returns:
|
||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
||
If return_dict is `True`,
|
||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
||
otherwise a tuple is returned where the first element is the sample tensor.
|
||
|
||
"""
|
||
|
||
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
||
raise ValueError(
|
||
(
|
||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||
" one of the `scheduler.timesteps` as a timestep."
|
||
),
|
||
)
|
||
|
||
if not self.is_scale_input_called:
|
||
# logger.warning(
|
||
print(
|
||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||
"See `StableDiffusionPipeline` for a usage example."
|
||
)
|
||
|
||
if self.step_index is None:
|
||
self._init_step_index(timestep)
|
||
|
||
sigma = self.sigmas[self.step_index]
|
||
|
||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||
if self.config.prediction_type == "epsilon":
|
||
pred_original_sample = sample - sigma * model_output
|
||
elif self.config.prediction_type == "v_prediction":
|
||
# * c_out + input * c_skip
|
||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||
elif self.config.prediction_type == "sample":
|
||
raise NotImplementedError("prediction_type not implemented yet: sample")
|
||
else:
|
||
raise ValueError(
|
||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||
)
|
||
|
||
sigma_from = self.sigmas[self.step_index]
|
||
sigma_to = self.sigmas[self.step_index + 1]
|
||
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||
|
||
# 2. Convert to an ODE derivative
|
||
derivative = (sample - pred_original_sample) / sigma
|
||
|
||
dt = sigma_down - sigma
|
||
|
||
device = model_output.device
|
||
if self.resized_size is None:
|
||
prev_sample = sample + derivative * dt
|
||
|
||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
|
||
)
|
||
s_noise = 1.0
|
||
else:
|
||
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
|
||
s_noise = self.gradual_latent.s_noise
|
||
|
||
if self.gradual_latent.unsharp_target_x:
|
||
prev_sample = sample + derivative * dt
|
||
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
|
||
else:
|
||
sample = self.gradual_latent.interpolate(sample, self.resized_size)
|
||
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
|
||
prev_sample = sample + derivative * dt
|
||
|
||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
|
||
dtype=model_output.dtype,
|
||
device=device,
|
||
generator=generator,
|
||
)
|
||
|
||
prev_sample = prev_sample + noise * sigma_up * s_noise
|
||
|
||
# upon completion increase step index by one
|
||
self._step_index += 1
|
||
|
||
if not return_dict:
|
||
return (prev_sample,)
|
||
|
||
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||
|
||
|
||
# endregion
|