diff --git a/XTI_hijack.py b/XTI_hijack.py index 1dbc263a..93bc1c0b 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,7 +1,7 @@ import torch -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex init_ipex() + from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index 11e94e56..a6a5c1e2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -8,11 +8,9 @@ from multiprocessing import Value import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 524f80b9..e7590faa 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -9,13 +9,16 @@ from pathlib import Path from PIL import Image from tqdm import tqdm import numpy as np + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms from torchvision.transforms.functional import InterpolationMode sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util -from library.device_utils import get_preferred_device DEVICE = get_preferred_device() diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index 2b650eb0..549c477a 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -5,12 +5,15 @@ import re from pathlib import Path from PIL import Image from tqdm import tqdm + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from transformers import AutoProcessor, AutoModelForCausalLM from transformers.generation.utils import GenerationMixin import library.train_util as train_util -from library.device_utils import get_preferred_device DEVICE = get_preferred_device() diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 9d352dd6..29710ca9 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -8,14 +8,16 @@ from tqdm import tqdm import numpy as np from PIL import Image import cv2 + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torchvision import transforms import library.model_util as model_util import library.train_util as train_util -from library.device_utils import get_preferred_device - DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 38b1ceab..ae79e2f9 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -64,11 +64,9 @@ import re import diffusers import numpy as np + import torch - -from library.device_utils import clean_memory, get_preferred_device -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() import torchvision diff --git a/library/device_utils.py b/library/device_utils.py index 353bfa9f..93371ca6 100644 --- a/library/device_utils.py +++ b/library/device_utils.py @@ -13,11 +13,19 @@ try: except Exception: HAS_MPS = False +try: + import intel_extension_for_pytorch as ipex # noqa + HAS_XPU = torch.xpu.is_available() +except Exception: + HAS_XPU = False + def clean_memory(): gc.collect() if HAS_CUDA: torch.cuda.empty_cache() + if HAS_XPU: + torch.xpu.empty_cache() if HAS_MPS: torch.mps.empty_cache() @@ -26,9 +34,30 @@ def clean_memory(): def get_preferred_device() -> torch.device: if HAS_CUDA: device = torch.device("cuda") + elif HAS_XPU: + device = torch.device("xpu") elif HAS_MPS: device = torch.device("mps") else: device = torch.device("cpu") print(f"get_preferred_device() -> {device}") return device + +def init_ipex(): + """ + Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`. + + This function should run right after importing torch and before doing anything else. + + If IPEX is not available, this function does nothing. + """ + try: + if HAS_XPU: + from library.ipex import ipex_init + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + else: + return + except Exception as e: + print("failed to initialize ipex:", e) \ No newline at end of file diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 9f2e7c41..972a3bf6 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -9,167 +9,171 @@ from .hijacks import ipex_hijacks def ipex_init(): # pylint: disable=too-many-statements try: - # Replace cuda with xpu: - torch.cuda.current_device = torch.xpu.current_device - torch.cuda.current_stream = torch.xpu.current_stream - torch.cuda.device = torch.xpu.device - torch.cuda.device_count = torch.xpu.device_count - torch.cuda.device_of = torch.xpu.device_of - torch.cuda.get_device_name = torch.xpu.get_device_name - torch.cuda.get_device_properties = torch.xpu.get_device_properties - torch.cuda.init = torch.xpu.init - torch.cuda.is_available = torch.xpu.is_available - torch.cuda.is_initialized = torch.xpu.is_initialized - torch.cuda.is_current_stream_capturing = lambda: False - torch.cuda.set_device = torch.xpu.set_device - torch.cuda.stream = torch.xpu.stream - torch.cuda.synchronize = torch.xpu.synchronize - torch.cuda.Event = torch.xpu.Event - torch.cuda.Stream = torch.xpu.Stream - torch.cuda.FloatTensor = torch.xpu.FloatTensor - torch.Tensor.cuda = torch.Tensor.xpu - torch.Tensor.is_cuda = torch.Tensor.is_xpu - torch.UntypedStorage.cuda = torch.UntypedStorage.xpu - torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock - torch.cuda._initialized = torch.xpu.lazy_init._initialized - torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker - torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls - torch.cuda._tls = torch.xpu.lazy_init._tls - torch.cuda.threading = torch.xpu.lazy_init.threading - torch.cuda.traceback = torch.xpu.lazy_init.traceback - torch.cuda.Optional = torch.xpu.Optional - torch.cuda.__cached__ = torch.xpu.__cached__ - torch.cuda.__loader__ = torch.xpu.__loader__ - torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage - torch.cuda.Tuple = torch.xpu.Tuple - torch.cuda.streams = torch.xpu.streams - torch.cuda._lazy_new = torch.xpu._lazy_new - torch.cuda.FloatStorage = torch.xpu.FloatStorage - torch.cuda.Any = torch.xpu.Any - torch.cuda.__doc__ = torch.xpu.__doc__ - torch.cuda.default_generators = torch.xpu.default_generators - torch.cuda.HalfTensor = torch.xpu.HalfTensor - torch.cuda._get_device_index = torch.xpu._get_device_index - torch.cuda.__path__ = torch.xpu.__path__ - torch.cuda.Device = torch.xpu.Device - torch.cuda.IntTensor = torch.xpu.IntTensor - torch.cuda.ByteStorage = torch.xpu.ByteStorage - torch.cuda.set_stream = torch.xpu.set_stream - torch.cuda.BoolStorage = torch.xpu.BoolStorage - torch.cuda.os = torch.xpu.os - torch.cuda.torch = torch.xpu.torch - torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage - torch.cuda.Union = torch.xpu.Union - torch.cuda.DoubleTensor = torch.xpu.DoubleTensor - torch.cuda.ShortTensor = torch.xpu.ShortTensor - torch.cuda.LongTensor = torch.xpu.LongTensor - torch.cuda.IntStorage = torch.xpu.IntStorage - torch.cuda.LongStorage = torch.xpu.LongStorage - torch.cuda.__annotations__ = torch.xpu.__annotations__ - torch.cuda.__package__ = torch.xpu.__package__ - torch.cuda.__builtins__ = torch.xpu.__builtins__ - torch.cuda.CharTensor = torch.xpu.CharTensor - torch.cuda.List = torch.xpu.List - torch.cuda._lazy_init = torch.xpu._lazy_init - torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor - torch.cuda.DoubleStorage = torch.xpu.DoubleStorage - torch.cuda.ByteTensor = torch.xpu.ByteTensor - torch.cuda.StreamContext = torch.xpu.StreamContext - torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage - torch.cuda.ShortStorage = torch.xpu.ShortStorage - torch.cuda._lazy_call = torch.xpu._lazy_call - torch.cuda.HalfStorage = torch.xpu.HalfStorage - torch.cuda.random = torch.xpu.random - torch.cuda._device = torch.xpu._device - torch.cuda.classproperty = torch.xpu.classproperty - torch.cuda.__name__ = torch.xpu.__name__ - torch.cuda._device_t = torch.xpu._device_t - torch.cuda.warnings = torch.xpu.warnings - torch.cuda.__spec__ = torch.xpu.__spec__ - torch.cuda.BoolTensor = torch.xpu.BoolTensor - torch.cuda.CharStorage = torch.xpu.CharStorage - torch.cuda.__file__ = torch.xpu.__file__ - torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: + return True, "Skipping IPEX hijack" + else: + # Replace cuda with xpu: + torch.cuda.current_device = torch.xpu.current_device + torch.cuda.current_stream = torch.xpu.current_stream + torch.cuda.device = torch.xpu.device + torch.cuda.device_count = torch.xpu.device_count + torch.cuda.device_of = torch.xpu.device_of + torch.cuda.get_device_name = torch.xpu.get_device_name + torch.cuda.get_device_properties = torch.xpu.get_device_properties + torch.cuda.init = torch.xpu.init + torch.cuda.is_available = torch.xpu.is_available + torch.cuda.is_initialized = torch.xpu.is_initialized + torch.cuda.is_current_stream_capturing = lambda: False + torch.cuda.set_device = torch.xpu.set_device + torch.cuda.stream = torch.xpu.stream + torch.cuda.synchronize = torch.xpu.synchronize + torch.cuda.Event = torch.xpu.Event + torch.cuda.Stream = torch.xpu.Stream + torch.cuda.FloatTensor = torch.xpu.FloatTensor + torch.Tensor.cuda = torch.Tensor.xpu + torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu + torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock + torch.cuda._initialized = torch.xpu.lazy_init._initialized + torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker + torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls + torch.cuda._tls = torch.xpu.lazy_init._tls + torch.cuda.threading = torch.xpu.lazy_init.threading + torch.cuda.traceback = torch.xpu.lazy_init.traceback + torch.cuda.Optional = torch.xpu.Optional + torch.cuda.__cached__ = torch.xpu.__cached__ + torch.cuda.__loader__ = torch.xpu.__loader__ + torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage + torch.cuda.Tuple = torch.xpu.Tuple + torch.cuda.streams = torch.xpu.streams + torch.cuda._lazy_new = torch.xpu._lazy_new + torch.cuda.FloatStorage = torch.xpu.FloatStorage + torch.cuda.Any = torch.xpu.Any + torch.cuda.__doc__ = torch.xpu.__doc__ + torch.cuda.default_generators = torch.xpu.default_generators + torch.cuda.HalfTensor = torch.xpu.HalfTensor + torch.cuda._get_device_index = torch.xpu._get_device_index + torch.cuda.__path__ = torch.xpu.__path__ + torch.cuda.Device = torch.xpu.Device + torch.cuda.IntTensor = torch.xpu.IntTensor + torch.cuda.ByteStorage = torch.xpu.ByteStorage + torch.cuda.set_stream = torch.xpu.set_stream + torch.cuda.BoolStorage = torch.xpu.BoolStorage + torch.cuda.os = torch.xpu.os + torch.cuda.torch = torch.xpu.torch + torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage + torch.cuda.Union = torch.xpu.Union + torch.cuda.DoubleTensor = torch.xpu.DoubleTensor + torch.cuda.ShortTensor = torch.xpu.ShortTensor + torch.cuda.LongTensor = torch.xpu.LongTensor + torch.cuda.IntStorage = torch.xpu.IntStorage + torch.cuda.LongStorage = torch.xpu.LongStorage + torch.cuda.__annotations__ = torch.xpu.__annotations__ + torch.cuda.__package__ = torch.xpu.__package__ + torch.cuda.__builtins__ = torch.xpu.__builtins__ + torch.cuda.CharTensor = torch.xpu.CharTensor + torch.cuda.List = torch.xpu.List + torch.cuda._lazy_init = torch.xpu._lazy_init + torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor + torch.cuda.DoubleStorage = torch.xpu.DoubleStorage + torch.cuda.ByteTensor = torch.xpu.ByteTensor + torch.cuda.StreamContext = torch.xpu.StreamContext + torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage + torch.cuda.ShortStorage = torch.xpu.ShortStorage + torch.cuda._lazy_call = torch.xpu._lazy_call + torch.cuda.HalfStorage = torch.xpu.HalfStorage + torch.cuda.random = torch.xpu.random + torch.cuda._device = torch.xpu._device + torch.cuda.classproperty = torch.xpu.classproperty + torch.cuda.__name__ = torch.xpu.__name__ + torch.cuda._device_t = torch.xpu._device_t + torch.cuda.warnings = torch.xpu.warnings + torch.cuda.__spec__ = torch.xpu.__spec__ + torch.cuda.BoolTensor = torch.xpu.BoolTensor + torch.cuda.CharStorage = torch.xpu.CharStorage + torch.cuda.__file__ = torch.xpu.__file__ + torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - # Memory: - torch.cuda.memory = torch.xpu.memory - if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): - torch.xpu.empty_cache = lambda: None - torch.cuda.empty_cache = torch.xpu.empty_cache - torch.cuda.memory_stats = torch.xpu.memory_stats - torch.cuda.memory_summary = torch.xpu.memory_summary - torch.cuda.memory_snapshot = torch.xpu.memory_snapshot - torch.cuda.memory_allocated = torch.xpu.memory_allocated - torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated - torch.cuda.memory_reserved = torch.xpu.memory_reserved - torch.cuda.memory_cached = torch.xpu.memory_reserved - torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved - torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved - torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats - torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats - torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict - torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats + # Memory: + torch.cuda.memory = torch.xpu.memory + if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): + torch.xpu.empty_cache = lambda: None + torch.cuda.empty_cache = torch.xpu.empty_cache + torch.cuda.memory_stats = torch.xpu.memory_stats + torch.cuda.memory_summary = torch.xpu.memory_summary + torch.cuda.memory_snapshot = torch.xpu.memory_snapshot + torch.cuda.memory_allocated = torch.xpu.memory_allocated + torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated + torch.cuda.memory_reserved = torch.xpu.memory_reserved + torch.cuda.memory_cached = torch.xpu.memory_reserved + torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved + torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved + torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats + torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats + torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict + torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - # RNG: - torch.cuda.get_rng_state = torch.xpu.get_rng_state - torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all - torch.cuda.set_rng_state = torch.xpu.set_rng_state - torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all - torch.cuda.manual_seed = torch.xpu.manual_seed - torch.cuda.manual_seed_all = torch.xpu.manual_seed_all - torch.cuda.seed = torch.xpu.seed - torch.cuda.seed_all = torch.xpu.seed_all - torch.cuda.initial_seed = torch.xpu.initial_seed + # RNG: + torch.cuda.get_rng_state = torch.xpu.get_rng_state + torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all + torch.cuda.set_rng_state = torch.xpu.set_rng_state + torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all + torch.cuda.manual_seed = torch.xpu.manual_seed + torch.cuda.manual_seed_all = torch.xpu.manual_seed_all + torch.cuda.seed = torch.xpu.seed + torch.cuda.seed_all = torch.xpu.seed_all + torch.cuda.initial_seed = torch.xpu.initial_seed - # AMP: - torch.cuda.amp = torch.xpu.amp - torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled - torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype + # AMP: + torch.cuda.amp = torch.xpu.amp + torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled + torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype - if not hasattr(torch.cuda.amp, "common"): - torch.cuda.amp.common = contextlib.nullcontext() - torch.cuda.amp.common.amp_definitely_not_available = lambda: False + if not hasattr(torch.cuda.amp, "common"): + torch.cuda.amp.common = contextlib.nullcontext() + torch.cuda.amp.common.amp_definitely_not_available = lambda: False - try: - torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler - except Exception: # pylint: disable=broad-exception-caught try: - from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error - gradscaler_init() torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler except Exception: # pylint: disable=broad-exception-caught - torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler + try: + from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error + gradscaler_init() + torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler + except Exception: # pylint: disable=broad-exception-caught + torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - # C - torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream - ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count - ipex._C._DeviceProperties.major = 2023 - ipex._C._DeviceProperties.minor = 2 + # C + torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream + ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count + ipex._C._DeviceProperties.major = 2023 + ipex._C._DeviceProperties.minor = 2 - # Fix functions with ipex: - torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] - torch._utils._get_available_device_type = lambda: "xpu" - torch.has_cuda = True - torch.cuda.has_half = True - torch.cuda.is_bf16_supported = lambda *args, **kwargs: True - torch.cuda.is_fp16_supported = lambda *args, **kwargs: True - torch.backends.cuda.is_built = lambda *args, **kwargs: True - torch.version.cuda = "12.1" - torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] - torch.cuda.get_device_properties.major = 12 - torch.cuda.get_device_properties.minor = 1 - torch.cuda.ipc_collect = lambda *args, **kwargs: None - torch.cuda.utilization = lambda *args, **kwargs: 0 + # Fix functions with ipex: + torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] + torch._utils._get_available_device_type = lambda: "xpu" + torch.has_cuda = True + torch.cuda.has_half = True + torch.cuda.is_bf16_supported = lambda *args, **kwargs: True + torch.cuda.is_fp16_supported = lambda *args, **kwargs: True + torch.backends.cuda.is_built = lambda *args, **kwargs: True + torch.version.cuda = "12.1" + torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] + torch.cuda.get_device_properties.major = 12 + torch.cuda.get_device_properties.minor = 1 + torch.cuda.ipc_collect = lambda *args, **kwargs: None + torch.cuda.utilization = lambda *args, **kwargs: 0 - ipex_hijacks() - if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + ipex_hijacks() + if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None: + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass + torch.cuda.is_xpu_hijacked = True except Exception as e: return False, e return True, None diff --git a/library/ipex_interop.py b/library/ipex_interop.py deleted file mode 100644 index 6fe320c5..00000000 --- a/library/ipex_interop.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - - -def init_ipex(): - """ - Try to import `intel_extension_for_pytorch`, and apply - the hijacks using `library.ipex.ipex_init`. - - If IPEX is not installed, this function does nothing. - """ - try: - import intel_extension_for_pytorch as ipex # noqa - except ImportError: - return - - try: - from library.ipex import ipex_init - - if torch.xpu.is_available(): - is_initialized, error_message = ipex_init() - if not is_initialized: - print("failed to initialize ipex:", error_message) - except Exception as e: - print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 4361b499..6398e8a0 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -3,11 +3,11 @@ import math import os + import torch - -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex init_ipex() + import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index d2becad6..0ecf4feb 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -2,12 +2,15 @@ import argparse import math import os from typing import Optional + import torch +from library.device_utils import init_ipex, clean_memory +init_ipex() + from accelerate import init_empty_weights from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet -from library.device_utils import clean_memory from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline TOKENIZER1_PATH = "openai/clip-vit-large-patch14" diff --git a/library/train_util.py b/library/train_util.py index d59f4258..3b74ddce 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -30,7 +30,11 @@ from io import BytesIO import toml from tqdm import tqdm + import torch +from library.device_utils import init_ipex, clean_memory +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms @@ -66,7 +70,6 @@ import library.sai_model_spec as sai_model_spec # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork -from library.device_utils import clean_memory from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 0056ac78..cbba44f7 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -9,9 +9,10 @@ from diffusers import UNet2DConditionModel import numpy as np from tqdm import tqdm from transformers import CLIPTextModel -import torch -from library.device_utils import get_preferred_device +import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() def make_unet_conversion_map() -> Dict[str, str]: diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 83942d7c..6cda391b 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -5,11 +5,13 @@ from library import model_util import library.train_util as train_util import argparse from transformers import CLIPTokenizer + import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() import library.model_util as model_util import lora -from library.device_utils import get_preferred_device TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 0722b93f..8824a3a1 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -16,11 +16,9 @@ import re import diffusers import numpy as np + import torch - -from library.device_utils import clean_memory, get_preferred_device -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory, get_preferred_device init_ipex() import torchvision diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 3eae2044..4f509b6d 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -8,11 +8,9 @@ import os import random from einops import repeat import numpy as np + import torch - -from library.device_utils import get_preferred_device -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, get_preferred_device init_ipex() from tqdm import tqdm diff --git a/sdxl_train.py b/sdxl_train.py index 78cfaf49..b0dcdbe9 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -8,11 +8,9 @@ from typing import List import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 95b755f1..1b069624 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -12,11 +12,9 @@ from types import SimpleNamespace import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index fd24898c..b74e3b90 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -9,11 +9,9 @@ from types import SimpleNamespace import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/sdxl_train_network.py b/sdxl_train_network.py index af0c8d1d..205b526e 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,9 +1,7 @@ import argparse + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index df393713..b9a948bb 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -2,10 +2,11 @@ import argparse import os import regex -import torch -from library.ipex_interop import init_ipex +import torch +from library.device_utils import init_ipex init_ipex() + import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index 27d13ef6..81f51f02 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -11,12 +11,13 @@ from typing import Dict, List import numpy as np import torch +from library.device_utils import init_ipex, get_preferred_device +init_ipex() + from torch import nn from tqdm import tqdm from PIL import Image -from library.device_utils import get_preferred_device - class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): diff --git a/train_controlnet.py b/train_controlnet.py index e6bea2c9..e7a06ae1 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -9,11 +9,9 @@ from types import SimpleNamespace import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP diff --git a/train_db.py b/train_db.py index daeb6d66..f6795dce 100644 --- a/train_db.py +++ b/train_db.py @@ -9,11 +9,9 @@ from multiprocessing import Value import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/train_network.py b/train_network.py index 9aabd4d7..f8dd6ab2 100644 --- a/train_network.py +++ b/train_network.py @@ -10,14 +10,13 @@ from multiprocessing import Value import toml from tqdm import tqdm + import torch -from torch.nn.parallel import DistributedDataParallel as DDP - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() +from torch.nn.parallel import DistributedDataParallel as DDP + from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 821cfe78..d18837bd 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -5,11 +5,9 @@ from multiprocessing import Value import toml from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ecd6d087..9d4b0aef 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -6,11 +6,9 @@ import toml from multiprocessing import Value from tqdm import tqdm + import torch - -from library.device_utils import clean_memory -from library.ipex_interop import init_ipex - +from library.device_utils import init_ipex, clean_memory init_ipex() from accelerate.utils import set_seed