import logging import sys import threading from typing import * import torch import torch.nn as nn from torchvision import transforms from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput import cv2 from PIL import Image import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() # region Logging def add_logging_arguments(parser): parser.add_argument( "--console_log_level", type=str, default=None, choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO", ) parser.add_argument( "--console_log_file", type=str, default=None, help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する", ) parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力") def setup_logging(args=None, log_level=None, reset=False): if logging.root.handlers: if reset: # remove all handlers for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) else: return # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO if log_level is None and args is not None: log_level = args.console_log_level if log_level is None: log_level = "INFO" log_level = getattr(logging, log_level) msg_init = None if args is not None and args.console_log_file: handler = logging.FileHandler(args.console_log_file, mode="w") else: handler = None if not args or not args.console_log_simple: try: from rich.logging import RichHandler from rich.console import Console from rich.logging import RichHandler handler = RichHandler(console=Console(stderr=True)) except ImportError: # print("rich is not installed, using basic logging") msg_init = "rich is not installed, using basic logging" if handler is None: handler = logging.StreamHandler(sys.stdout) # same as print handler.propagate = False formatter = logging.Formatter( fmt="%(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) handler.setFormatter(formatter) logging.root.setLevel(log_level) logging.root.addHandler(handler) if msg_init is not None: logger = logging.getLogger(__name__) logger.info(msg_init) setup_logging() logger = logging.getLogger(__name__) # endregion # region PyTorch utils def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module): assert layer_to_cpu.__class__ == layer_to_cuda.__class__ weight_swap_jobs = [] for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()): if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None: weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data)) torch.cuda.current_stream().synchronize() # this prevents the illegal loss value stream = torch.cuda.Stream() with torch.cuda.stream(stream): # cuda to cpu for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: cuda_data_view.record_stream(stream) module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) stream.synchronize() # cpu to cuda for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) module_to_cuda.weight.data = cuda_data_view stream.synchronize() torch.cuda.current_stream().synchronize() # this prevents the illegal loss value def weighs_to_device(layer: nn.Module, device: torch.device): for module in layer.modules(): if hasattr(module, "weight") and module.weight is not None: module.weight.data = module.weight.data.to(device, non_blocking=True) def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype: """ Convert a string to a torch.dtype Args: s: string representation of the dtype default_dtype: default dtype to return if s is None Returns: torch.dtype: the corresponding torch.dtype Raises: ValueError: if the dtype is not supported Examples: >>> str_to_dtype("float32") torch.float32 >>> str_to_dtype("fp32") torch.float32 >>> str_to_dtype("float16") torch.float16 >>> str_to_dtype("fp16") torch.float16 >>> str_to_dtype("bfloat16") torch.bfloat16 >>> str_to_dtype("bf16") torch.bfloat16 >>> str_to_dtype("fp8") torch.float8_e4m3fn >>> str_to_dtype("fp8_e4m3fn") torch.float8_e4m3fn >>> str_to_dtype("fp8_e4m3fnuz") torch.float8_e4m3fnuz >>> str_to_dtype("fp8_e5m2") torch.float8_e5m2 >>> str_to_dtype("fp8_e5m2fnuz") torch.float8_e5m2fnuz """ if s is None: return default_dtype if s in ["bf16", "bfloat16"]: return torch.bfloat16 elif s in ["fp16", "float16"]: return torch.float16 elif s in ["fp32", "float32", "float"]: return torch.float32 elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]: return torch.float8_e4m3fn elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]: return torch.float8_e4m3fnuz elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]: return torch.float8_e5m2 elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]: return torch.float8_e5m2fnuz elif s in ["fp8", "float8"]: return torch.float8_e4m3fn # default fp8 else: raise ValueError(f"Unsupported dtype: {s}") # endregion # region Image utils def pil_resize(image, size, interpolation): has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False if has_alpha: pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) else: pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) resized_pil = pil_image.resize(size, resample=interpolation) # Convert back to cv2 format if has_alpha: resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) else: resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 def resize_image( image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None, ): """ Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. Args: image: numpy.ndarray width: int Original image width height: int Original image height resized_width: int Resized image width resized_height: int Resized image height resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box" Returns: image """ # Ensure all size parameters are actual integers width = int(width) height = int(height) resized_width = int(resized_width) resized_height = int(resized_height) if resize_interpolation is None: if width >= resized_width and height >= resized_height: resize_interpolation = "area" else: resize_interpolation = "lanczos" # we use PIL for lanczos (for backward compatibility) and box, cv2 for others use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"] resized_size = (resized_width, resized_height) if use_pil: interpolation = get_pil_interpolation(resize_interpolation) image = pil_resize(image, resized_size, interpolation=interpolation) logger.debug(f"resize image using {resize_interpolation} (PIL)") else: interpolation = get_cv2_interpolation(resize_interpolation) image = cv2.resize(image, resized_size, interpolation=interpolation) logger.debug(f"resize image using {resize_interpolation} (cv2)") return image def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: """ Convert interpolation value to cv2 interpolation integer https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 """ if interpolation is None: return None if interpolation == "lanczos" or interpolation == "lanczos4": # Lanczos interpolation over 8x8 neighborhood return cv2.INTER_LANCZOS4 elif interpolation == "nearest": # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. return cv2.INTER_NEAREST_EXACT elif interpolation == "bilinear" or interpolation == "linear": # bilinear interpolation return cv2.INTER_LINEAR elif interpolation == "bicubic" or interpolation == "cubic": # bicubic interpolation return cv2.INTER_CUBIC elif interpolation == "area": # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA elif interpolation == "box": # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA else: return None def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: """ Convert interpolation value to PIL interpolation https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters """ if interpolation is None: return None if interpolation == "lanczos": return Image.Resampling.LANCZOS elif interpolation == "nearest": # Pick one nearest pixel from the input image. Ignore all other input pixels. return Image.Resampling.NEAREST elif interpolation == "bilinear" or interpolation == "linear": # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used. return Image.Resampling.BILINEAR elif interpolation == "bicubic" or interpolation == "cubic": # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used. return Image.Resampling.BICUBIC elif interpolation == "area": # Image.Resampling.BOX may be more appropriate if upscaling # Area interpolation is related to cv2.INTER_AREA # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. return Image.Resampling.HAMMING elif interpolation == "box": # Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST. return Image.Resampling.BOX else: return None def validate_interpolation_fn(interpolation_str: str) -> bool: """ Check if a interpolation function is supported """ return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] # endregion # TODO make inf_utils.py # region Gradual Latent hires fix class GradualLatent: def __init__( self, ratio, start_timesteps, every_n_steps, ratio_step, s_noise=1.0, gaussian_blur_ksize=None, gaussian_blur_sigma=0.5, gaussian_blur_strength=0.5, unsharp_target_x=True, ): self.ratio = ratio self.start_timesteps = start_timesteps self.every_n_steps = every_n_steps self.ratio_step = ratio_step self.s_noise = s_noise self.gaussian_blur_ksize = gaussian_blur_ksize self.gaussian_blur_sigma = gaussian_blur_sigma self.gaussian_blur_strength = gaussian_blur_strength self.unsharp_target_x = unsharp_target_x def __str__(self) -> str: return ( f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, " + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, " + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, " + f"unsharp_target_x={self.unsharp_target_x})" ) def apply_unshark_mask(self, x: torch.Tensor): if self.gaussian_blur_ksize is None: return x blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma) # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength) mask = (x - blurred) * self.gaussian_blur_strength sharpened = x + mask return sharpened def interpolate(self, x: torch.Tensor, resized_size, unsharp=True): org_dtype = x.dtype if org_dtype == torch.bfloat16: x = x.float() x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype) # apply unsharp mask / アンシャープマスクを適用する if unsharp and self.gaussian_blur_ksize: x = self.apply_unshark_mask(x) return x class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.resized_size = None self.gradual_latent = None def set_gradual_latent_params(self, size, gradual_latent: GradualLatent): self.resized_size = size self.gradual_latent = gradual_latent def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): The direct output from learned diffusion model. timestep (`float`): The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): A random number generator. return_dict (`bool`): Whether or not to return a [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." ), ) if not self.is_scale_input_called: # logger.warning( print( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) if self.step_index is None: self._init_step_index(timestep) sigma = self.sigmas[self.step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) sigma_from = self.sigmas[self.step_index] sigma_to = self.sigmas[self.step_index + 1] sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 # 2. Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma dt = sigma_down - sigma device = model_output.device if self.resized_size is None: prev_sample = sample + derivative * dt noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( model_output.shape, dtype=model_output.dtype, device=device, generator=generator ) s_noise = 1.0 else: print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape) s_noise = self.gradual_latent.s_noise if self.gradual_latent.unsharp_target_x: prev_sample = sample + derivative * dt prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size) else: sample = self.gradual_latent.interpolate(sample, self.resized_size) derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False) prev_sample = sample + derivative * dt noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor( (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]), dtype=model_output.dtype, device=device, generator=generator, ) prev_sample = prev_sample + noise * sigma_up * s_noise # upon completion increase step index by one self._step_index += 1 if not return_dict: return (prev_sample,) return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # endregion