From bce9a081dbb2df0b41afea25bce5db12c511e8b8 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 5 Dec 2023 14:17:31 +0300 Subject: [PATCH 1/4] Update IPEX hijacks --- library/ipex/__init__.py | 1 + library/ipex/attention.py | 5 +++++ library/ipex/hijacks.py | 35 ++++++++++++++++++++++++++++------- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index 43accd9f..dc1985ed 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -30,6 +30,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.FloatTensor = torch.xpu.FloatTensor torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu + torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker diff --git a/library/ipex/attention.py b/library/ipex/attention.py index 84848b6a..52016466 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -74,6 +74,11 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. shape_one, batch_size_attention, query_tokens, shape_four = query.shape no_shape_one = False + if query.dtype != key.dtype: + key = key.to(dtype=query.dtype) + if query.dtype != value.dtype: + value = value.to(dtype=query.dtype) + block_multiply = query.element_size() slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply block_size = batch_size_attention * slice_block_size diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 77ed5419..5c50c021 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -89,6 +89,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) +#Embedding BF16 original_torch_cat = torch.cat def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): @@ -96,6 +97,7 @@ def torch_cat(tensor, *args, **kwargs): else: return original_torch_cat(tensor, *args, **kwargs) +#Latent antialias: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -115,19 +117,28 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name else: return original_linalg_solve(A, B, *args, **kwargs) +def is_cuda(self): + return self.device.type == 'xpu' + def ipex_hijacks(): + CondFunc('torch.tensor', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) CondFunc('torch.Tensor.to', lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) CondFunc('torch.Tensor.cuda', lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) + CondFunc('torch.UntypedStorage.__init__', + lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), + lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.UntypedStorage.cuda', + lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs), + lambda orig_func, self, device=None, *args, **kwargs: check_device(device)) CondFunc('torch.empty', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.load', - lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs), - lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location)) CondFunc('torch.randn', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) @@ -137,17 +148,19 @@ def ipex_hijacks(): CondFunc('torch.zeros', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) - CondFunc('torch.tensor', - lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), - lambda orig_func, *args, device=None, **kwargs: check_device(device)) CondFunc('torch.linspace', lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs), lambda orig_func, *args, device=None, **kwargs: check_device(device)) + CondFunc('torch.load', + lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: + orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs), + lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location)) CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") + #TiledVAE and ControlNet: CondFunc('torch.batch_norm', lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, weight if weight is not None else torch.ones(input.size()[1], device=input.device), @@ -163,17 +176,23 @@ def ipex_hijacks(): CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + #Training: CondFunc('torch.nn.modules.linear.Linear.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + #BF16: CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) + #SwinIR BF16: + CondFunc('torch.nn.functional.pad', + lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), + lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) #Diffusers Float64 (ARC GPUs doesn't support double or Float64): if not torch.xpu.has_fp64_dtype(): @@ -182,6 +201,7 @@ def ipex_hijacks(): lambda orig_func, ndarray: ndarray.dtype == float) #Broken functions when torch.cuda.is_available is True: + #Pin Memory: CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), lambda orig_func, *args, **kwargs: True) @@ -192,5 +212,6 @@ def ipex_hijacks(): torch.autocast = ipex_autocast torch.cat = torch_cat torch.linalg.solve = linalg_solve + torch.UntypedStorage.is_cuda = is_cuda torch.nn.functional.interpolate = interpolate torch.backends.cuda.sdp_kernel = return_null_context From 3d70137d31f23990beaf0e7f1bca54397bd09967 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 5 Dec 2023 19:40:16 +0300 Subject: [PATCH 2/4] Disable IPEX attention if the GPU supports 64 bit --- library/ipex/__init__.py | 13 +++++++------ library/ipex/diffusers.py | 2 +- library/ipex/gradscaler.py | 6 +++++- library/ipex/hijacks.py | 22 +++++++++++----------- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index dc1985ed..cda32ccb 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -165,12 +165,13 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card ipex_hijacks() - attention_init() - try: - from .diffusers import ipex_diffusers - ipex_diffusers() - except Exception: # pylint: disable=broad-exception-caught - pass + if not torch.xpu.has_fp64_dtype(): + attention_init() + try: + from .diffusers import ipex_diffusers + ipex_diffusers() + except Exception: # pylint: disable=broad-exception-caught + pass except Exception as e: return False, e return True, None diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py index 005ee49f..c32af507 100644 --- a/library/ipex/diffusers.py +++ b/library/ipex/diffusers.py @@ -1,6 +1,6 @@ import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import -import diffusers #0.21.1 # pylint: disable=import-error +import diffusers #0.24.0 # pylint: disable=import-error from diffusers.models.attention_processor import Attention # pylint: disable=protected-access, missing-function-docstring, line-too-long diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py index 53021210..6eb56bc2 100644 --- a/library/ipex/gradscaler.py +++ b/library/ipex/gradscaler.py @@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un # pylint: disable=protected-access, missing-function-docstring, line-too-long +device_supports_fp64 = torch.xpu.has_fp64_dtype() OptState = ipex.cpu.autocast._grad_scaler.OptState _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state @@ -96,7 +97,10 @@ def unscale_(self, optimizer): # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) + if device_supports_fp64: + inv_scale = self._scale.double().reciprocal().float() + else: + inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device) found_inf = torch.full( (1,), 0.0, dtype=torch.float32, device=self._scale.device ) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 5c50c021..62d29605 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -89,7 +89,7 @@ def ipex_autocast(*args, **kwargs): else: return original_autocast(*args, **kwargs) -#Embedding BF16 +# Embedding BF16 original_torch_cat = torch.cat def torch_cat(tensor, *args, **kwargs): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): @@ -97,7 +97,7 @@ def torch_cat(tensor, *args, **kwargs): else: return original_torch_cat(tensor, *args, **kwargs) -#Latent antialias: +# Latent antialias: original_interpolate = torch.nn.functional.interpolate def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments if antialias or align_corners is not None: @@ -160,7 +160,7 @@ def ipex_hijacks(): lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)), lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu") - #TiledVAE and ControlNet: + # TiledVAE and ControlNet: CondFunc('torch.batch_norm', lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input, weight if weight is not None else torch.ones(input.size()[1], device=input.device), @@ -172,41 +172,41 @@ def ipex_hijacks(): bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs), lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu")) - #Functions with dtype errors: + # Functions with dtype errors: CondFunc('torch.nn.modules.GroupNorm.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - #Training: + # Training: CondFunc('torch.nn.modules.linear.Linear.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) - #BF16: + # BF16: CondFunc('torch.nn.functional.layer_norm', lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight is not None and input.dtype != weight.data.dtype) - #SwinIR BF16: + # SwinIR BF16: CondFunc('torch.nn.functional.pad', lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16), lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16) - #Diffusers Float64 (ARC GPUs doesn't support double or Float64): + # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): if not torch.xpu.has_fp64_dtype(): CondFunc('torch.from_numpy', lambda orig_func, ndarray: orig_func(ndarray.astype('float32')), lambda orig_func, ndarray: ndarray.dtype == float) - #Broken functions when torch.cuda.is_available is True: - #Pin Memory: + # Broken functions when torch.cuda.is_available is True: + # Pin Memory: CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__', lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs), lambda orig_func, *args, **kwargs: True) - #Functions that make compile mad with CondFunc: + # Functions that make compile mad with CondFunc: torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers torch.nn.DataParallel = DummyDataParallel torch.autocast = ipex_autocast From a9c6182b3fb61ad73375497f624e873e097242b8 Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 5 Dec 2023 19:52:31 +0300 Subject: [PATCH 3/4] Cleanup IPEX libs --- library/ipex/__init__.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py index cda32ccb..662572c8 100644 --- a/library/ipex/__init__.py +++ b/library/ipex/__init__.py @@ -4,13 +4,12 @@ import contextlib import torch import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from .hijacks import ipex_hijacks -from .attention import attention_init # pylint: disable=protected-access, missing-function-docstring, line-too-long def ipex_init(): # pylint: disable=too-many-statements try: - #Replace cuda with xpu: + # Replace cuda with xpu: torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.device = torch.xpu.device @@ -91,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.__file__ = torch.xpu.__file__ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork - #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing + # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing - #Memory: + # Memory: torch.cuda.memory = torch.xpu.memory if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): torch.xpu.empty_cache = lambda: None @@ -113,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats - #RNG: + # RNG: torch.cuda.get_rng_state = torch.xpu.get_rng_state torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all torch.cuda.set_rng_state = torch.xpu.set_rng_state @@ -124,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements torch.cuda.seed_all = torch.xpu.seed_all torch.cuda.initial_seed = torch.xpu.initial_seed - #AMP: + # AMP: torch.cuda.amp = torch.xpu.amp if not hasattr(torch.cuda.amp, "common"): torch.cuda.amp.common = contextlib.nullcontext() @@ -139,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements except Exception: # pylint: disable=broad-exception-caught torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler - #C + # C torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream ipex._C._DeviceProperties.major = 2023 ipex._C._DeviceProperties.minor = 2 - #Fix functions with ipex: + # Fix functions with ipex: torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] torch._utils._get_available_device_type = lambda: "xpu" torch.has_cuda = True @@ -166,7 +165,11 @@ def ipex_init(): # pylint: disable=too-many-statements ipex_hijacks() if not torch.xpu.has_fp64_dtype(): - attention_init() + try: + from .attention import attention_init + attention_init() + except Exception: # pylint: disable=broad-exception-caught + pass try: from .diffusers import ipex_diffusers ipex_diffusers() From dd7bb33ab60864cf86a93f6be03c0d92b62a1cdb Mon Sep 17 00:00:00 2001 From: Disty0 Date: Tue, 5 Dec 2023 22:18:47 +0300 Subject: [PATCH 4/4] IPEX fix torch.UntypedStorage.is_cuda --- library/ipex/hijacks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py index 62d29605..4a9a3569 100644 --- a/library/ipex/hijacks.py +++ b/library/ipex/hijacks.py @@ -117,6 +117,7 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name else: return original_linalg_solve(A, B, *args, **kwargs) +@property def is_cuda(self): return self.device.type == 'xpu'