mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Deduplicate ipex initialization code
This commit is contained in:
@@ -1,11 +1,7 @@
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
|
||||
@@ -11,15 +11,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
|
||||
@@ -66,15 +66,10 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
import torchvision
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
||||
24
library/ipex_interop.py
Normal file
24
library/ipex_interop.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Try to import `intel_extension_for_pytorch`, and apply
|
||||
the hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
If IPEX is not installed, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
try:
|
||||
from library.ipex import ipex_init
|
||||
|
||||
if torch.xpu.is_available():
|
||||
is_initialized, error_message = ipex_init()
|
||||
if not is_initialized:
|
||||
print("failed to initialize ipex:", error_message)
|
||||
except Exception as e:
|
||||
print("failed to initialize ipex:", e)
|
||||
@@ -5,15 +5,9 @@ import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
init_ipex()
|
||||
import diffusers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||
|
||||
@@ -18,15 +18,10 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
import torchvision
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
||||
@@ -9,13 +9,11 @@ import random
|
||||
from einops import repeat
|
||||
import numpy as np
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
@@ -11,15 +11,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import sdxl_model_util
|
||||
|
||||
@@ -14,13 +14,11 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
import accelerate
|
||||
|
||||
@@ -11,13 +11,11 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
import train_network
|
||||
|
||||
|
||||
@@ -3,13 +3,9 @@ import os
|
||||
|
||||
import regex
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
|
||||
@@ -12,15 +12,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
|
||||
@@ -12,15 +12,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
|
||||
@@ -14,15 +14,10 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import model_util
|
||||
|
||||
@@ -8,15 +8,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
@@ -8,13 +8,11 @@ from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
Reference in New Issue
Block a user