diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py
index 33350493..9f2e7c41 100644
--- a/library/ipex/__init__.py
+++ b/library/ipex/__init__.py
@@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements
 
     # AMP:
     torch.cuda.amp = torch.xpu.amp
+    torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
+    torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
+
     if not hasattr(torch.cuda.amp, "common"):
         torch.cuda.amp.common = contextlib.nullcontext()
     torch.cuda.amp.common.amp_definitely_not_available = lambda: False
+
     try:
         torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
     except Exception: # pylint: disable=broad-exception-caught
@@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements
     torch.cuda.has_half = True
     torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
     torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
-    torch.version.cuda = "11.7"
-    torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
-    torch.cuda.get_device_properties.major = 11
-    torch.cuda.get_device_properties.minor = 7
+    torch.backends.cuda.is_built = lambda *args, **kwargs: True
+    torch.version.cuda = "12.1"
+    torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
+    torch.cuda.get_device_properties.major = 12
+    torch.cuda.get_device_properties.minor = 1
     torch.cuda.ipc_collect = lambda *args, **kwargs: None
     torch.cuda.utilization = lambda *args, **kwargs: 0
 
     ipex_hijacks()
-    if not torch.xpu.has_fp64_dtype():
+    if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
         try:
             from .diffusers import ipex_diffusers
             ipex_diffusers()
diff --git a/library/ipex/attention.py b/library/ipex/attention.py
index e98807a8..8253c5b1 100644
--- a/library/ipex/attention.py
+++ b/library/ipex/attention.py
@@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
                 )
     else:
         return original_torch_bmm(input, mat2, out=out)
+    torch.xpu.synchronize(input.device)
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
                 )
     else:
         return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+    torch.xpu.synchronize(query.device)
     return hidden_states
diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py
index 47b0375a..732a1856 100644
--- a/library/ipex/diffusers.py
+++ b/library/ipex/diffusers.py
@@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
 
                     hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
                     del attn_slice
+                torch.xpu.synchronize(query.device)
             else:
                 query_slice = query[start_idx:end_idx]
                 key_slice = key[start_idx:end_idx]
@@ -283,6 +284,7 @@ class AttnProcessor:
 
                 hidden_states[start_idx:end_idx] = attn_slice
                 del attn_slice
+            torch.xpu.synchronize(query.device)
         else:
             attention_probs = attn.get_attention_scores(query, key, attention_mask)
             hidden_states = torch.bmm(attention_probs, value)
diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py
index b6d246dd..7b2d26d4 100644
--- a/library/ipex/hijacks.py
+++ b/library/ipex/hijacks.py
@@ -1,6 +1,11 @@
-import contextlib
+import os
+from functools import wraps
+from contextlib import nullcontext
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import numpy as np
+
+device_supports_fp64 = torch.xpu.has_fp64_dtype()
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
 
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
         return module.to("xpu")
 
 def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
-    return contextlib.nullcontext()
+    return nullcontext()
 
 @property
 def is_cuda(self):
@@ -25,15 +30,17 @@ def return_xpu(device):
 
 
 # Autocast
-original_autocast = torch.autocast
-def ipex_autocast(*args, **kwargs):
-    if len(args) > 0 and args[0] == "cuda":
-        return original_autocast("xpu", *args[1:], **kwargs)
+original_autocast_init = torch.amp.autocast_mode.autocast.__init__
+@wraps(torch.amp.autocast_mode.autocast.__init__)
+def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
+    if device_type == "cuda":
+        return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
     else:
-        return original_autocast(*args, **kwargs)
+        return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
 
 # Latent Antialias CPU Offload:
 original_interpolate = torch.nn.functional.interpolate
+@wraps(torch.nn.functional.interpolate)
 def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
     if antialias or align_corners is not None:
         return_device = tensor.device
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
         return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
+
 
 # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
 original_from_numpy = torch.from_numpy
+@wraps(torch.from_numpy)
 def from_numpy(ndarray):
     if ndarray.dtype == float:
         return original_from_numpy(ndarray.astype('float32'))
     else:
         return original_from_numpy(ndarray)
 
 
-if torch.xpu.has_fp64_dtype():
+original_as_tensor = torch.as_tensor
+@wraps(torch.as_tensor)
+def as_tensor(data, dtype=None, device=None):
+    if check_device(device):
+        device = return_xpu(device)
+    if isinstance(data, np.ndarray) and data.dtype == float and not (
+        (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
+        return original_as_tensor(data, dtype=torch.float32, device=device)
+    else:
+        return original_as_tensor(data, dtype=dtype, device=device)
+
+
+if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
     original_torch_bmm = torch.bmm
     original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
 else:
@@ -66,20 +87,25 @@ else:
 
 
 # Data Type Errors:
+@wraps(torch.bmm)
 def torch_bmm(input, mat2, *, out=None):
     if input.dtype != mat2.dtype:
         mat2 = mat2.to(input.dtype)
     return original_torch_bmm(input, mat2, out=out)
 
+@wraps(torch.nn.functional.scaled_dot_product_attention)
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
     if query.dtype != key.dtype:
         key = key.to(dtype=query.dtype)
     if query.dtype != value.dtype:
         value = value.to(dtype=query.dtype)
+    if attn_mask is not None and query.dtype != attn_mask.dtype:
+        attn_mask = attn_mask.to(dtype=query.dtype)
     return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
 
 # A1111 FP16
 original_functional_group_norm = torch.nn.functional.group_norm
+@wraps(torch.nn.functional.group_norm)
 def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
     if weight is not None and input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
 
 # A1111 BF16
 original_functional_layer_norm = torch.nn.functional.layer_norm
+@wraps(torch.nn.functional.layer_norm)
 def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
     if weight is not None and input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1
 
 # Training
 original_functional_linear = torch.nn.functional.linear
+@wraps(torch.nn.functional.linear)
 def functional_linear(input, weight, bias=None):
     if input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
     return original_functional_linear(input, weight, bias=bias)
 
 original_functional_conv2d = torch.nn.functional.conv2d
+@wraps(torch.nn.functional.conv2d)
 def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
     if input.dtype != weight.data.dtype:
         input = input.to(dtype=weight.data.dtype)
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
 
 # A1111 Embedding BF16
 original_torch_cat = torch.cat
+@wraps(torch.cat)
 def torch_cat(tensor, *args, **kwargs):
     if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
         return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):
 
 # SwinIR BF16:
 original_functional_pad = torch.nn.functional.pad
+@wraps(torch.nn.functional.pad)
 def functional_pad(input, pad, mode='constant', value=None):
     if mode == 'reflect' and input.dtype == torch.bfloat16:
         return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):
 
 
 original_torch_tensor = torch.tensor
-def torch_tensor(*args, device=None, **kwargs):
+@wraps(torch.tensor)
+def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
     if check_device(device):
-        return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
-    else:
-        return original_torch_tensor(*args, device=device, **kwargs)
+        device = return_xpu(device)
+    if not device_supports_fp64:
+        if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
+            if dtype == torch.float64:
+                dtype = torch.float32
+            elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
+                dtype = torch.float32
+    return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
 
 original_Tensor_to = torch.Tensor.to
+@wraps(torch.Tensor.to)
 def Tensor_to(self, device=None, *args, **kwargs):
     if check_device(device):
         return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
         return original_Tensor_to(self, device, *args, **kwargs)
 
 original_Tensor_cuda = torch.Tensor.cuda
+@wraps(torch.Tensor.cuda)
 def Tensor_cuda(self, device=None, *args, **kwargs):
     if check_device(device):
         return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
         return original_Tensor_cuda(self, device, *args, **kwargs)
 
 original_UntypedStorage_init = torch.UntypedStorage.__init__
+@wraps(torch.UntypedStorage.__init__)
 def UntypedStorage_init(*args, device=None, **kwargs):
     if check_device(device):
         return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
         return original_UntypedStorage_init(*args, device=device, **kwargs)
 
 original_UntypedStorage_cuda = torch.UntypedStorage.cuda
+@wraps(torch.UntypedStorage.cuda)
 def UntypedStorage_cuda(self, device=None, *args, **kwargs):
     if check_device(device):
         return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
         return original_UntypedStorage_cuda(self, device, *args, **kwargs)
 
 original_torch_empty = torch.empty
+@wraps(torch.empty)
 def torch_empty(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_empty(*args, device=return_xpu(device), **kwargs)
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
         return original_torch_empty(*args, device=device, **kwargs)
 
 original_torch_randn = torch.randn
+@wraps(torch.randn)
 def torch_randn(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
         return original_torch_randn(*args, device=device, **kwargs)
 
 original_torch_ones = torch.ones
+@wraps(torch.ones)
 def torch_ones(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_ones(*args, device=return_xpu(device), **kwargs)
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
         return original_torch_ones(*args, device=device, **kwargs)
 
 original_torch_zeros = torch.zeros
+@wraps(torch.zeros)
 def torch_zeros(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
         return original_torch_zeros(*args, device=device, **kwargs)
 
 original_torch_linspace = torch.linspace
+@wraps(torch.linspace)
 def torch_linspace(*args, device=None, **kwargs):
     if check_device(device):
         return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
         return original_torch_linspace(*args, device=device, **kwargs)
 
 original_torch_Generator = torch.Generator
+@wraps(torch.Generator)
 def torch_Generator(device=None):
     if check_device(device):
         return original_torch_Generator(return_xpu(device))
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
         return original_torch_Generator(device)
 
 original_torch_load = torch.load
+@wraps(torch.load)
 def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
     if check_device(map_location):
         return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
     else:
         return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
 
+
 # Hijack Functions:
 def ipex_hijacks():
     torch.tensor = torch_tensor
@@ -232,7 +281,7 @@ def ipex_hijacks():
     torch.backends.cuda.sdp_kernel = return_null_context
     torch.nn.DataParallel = DummyDataParallel
     torch.UntypedStorage.is_cuda = is_cuda
-    torch.autocast = ipex_autocast
+    torch.amp.autocast_mode.autocast.__init__ = autocast_init
 
     torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
     torch.nn.functional.group_norm = functional_group_norm
@@ -244,5 +293,6 @@ def ipex_hijacks():
 
     torch.bmm = torch_bmm
     torch.cat = torch_cat
-    if not torch.xpu.has_fp64_dtype():
+    if not device_supports_fp64:
        torch.from_numpy = from_numpy
+        torch.as_tensor = as_tensor
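Illustration (not part of the patch): a minimal sketch of how CUDA-targeted code is expected to behave once ipex_init() from this library has run on an Intel XPU machine. The import path library.ipex is taken from the files touched by this diff; the tensor shapes and the "expected" values are assumptions for the example, not captured output.

    import torch
    from library.ipex import ipex_init  # assumed import path, matching library/ipex/__init__.py above

    ipex_init()  # installs the CUDA -> XPU hijacks (calls ipex_hijacks() internally)

    # Factory functions such as torch.randn are hijacked, so "cuda" devices are rerouted to the XPU
    x = torch.randn(2, 3, device="cuda")
    print(x.device)            # expected: xpu:0
    print(torch.version.cuda)  # expected: "12.1" (spoofed by ipex_init so CUDA version checks pass)

    # autocast_init() rewrites device_type="cuda" to "xpu" before calling the original __init__
    with torch.autocast("cuda", dtype=torch.bfloat16):
        y = x @ x.transpose(0, 1)
    print(y.dtype)             # expected: torch.bfloat16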