|
|
|
|
@@ -1,6 +1,11 @@
|
|
|
|
|
import contextlib
|
|
|
|
|
import os
|
|
|
|
|
from functools import wraps
|
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
import torch
|
|
|
|
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
|
|
|
|
|
|
|
|
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
|
|
|
|
|
|
|
|
|
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
|
|
|
|
return module.to("xpu")
|
|
|
|
|
|
|
|
|
|
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
|
|
|
|
return contextlib.nullcontext()
|
|
|
|
|
return nullcontext()
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def is_cuda(self):
|
|
|
|
|
@@ -25,15 +30,17 @@ def return_xpu(device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Autocast
|
|
|
|
|
original_autocast = torch.autocast
|
|
|
|
|
def ipex_autocast(*args, **kwargs):
|
|
|
|
|
if len(args) > 0 and args[0] == "cuda":
|
|
|
|
|
return original_autocast("xpu", *args[1:], **kwargs)
|
|
|
|
|
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
|
|
|
|
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
|
|
|
|
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
|
|
|
|
if device_type == "cuda":
|
|
|
|
|
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
|
|
|
|
else:
|
|
|
|
|
return original_autocast(*args, **kwargs)
|
|
|
|
|
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
|
|
|
|
|
|
|
|
|
# Latent Antialias CPU Offload:
|
|
|
|
|
original_interpolate = torch.nn.functional.interpolate
|
|
|
|
|
@wraps(torch.nn.functional.interpolate)
|
|
|
|
|
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
|
|
|
|
if antialias or align_corners is not None:
|
|
|
|
|
return_device = tensor.device
|
|
|
|
|
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
|
|
|
|
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
|
|
|
|
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
|
|
|
|
original_from_numpy = torch.from_numpy
|
|
|
|
|
@wraps(torch.from_numpy)
|
|
|
|
|
def from_numpy(ndarray):
|
|
|
|
|
if ndarray.dtype == float:
|
|
|
|
|
return original_from_numpy(ndarray.astype('float32'))
|
|
|
|
|
else:
|
|
|
|
|
return original_from_numpy(ndarray)
|
|
|
|
|
|
|
|
|
|
if torch.xpu.has_fp64_dtype():
|
|
|
|
|
original_as_tensor = torch.as_tensor
|
|
|
|
|
@wraps(torch.as_tensor)
|
|
|
|
|
def as_tensor(data, dtype=None, device=None):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
device = return_xpu(device)
|
|
|
|
|
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
|
|
|
|
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
|
|
|
|
return original_as_tensor(data, dtype=torch.float32, device=device)
|
|
|
|
|
else:
|
|
|
|
|
return original_as_tensor(data, dtype=dtype, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
|
|
|
|
original_torch_bmm = torch.bmm
|
|
|
|
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
|
|
|
|
else:
|
|
|
|
|
@@ -66,20 +87,25 @@ else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Data Type Errors:
|
|
|
|
|
@wraps(torch.bmm)
|
|
|
|
|
def torch_bmm(input, mat2, *, out=None):
|
|
|
|
|
if input.dtype != mat2.dtype:
|
|
|
|
|
mat2 = mat2.to(input.dtype)
|
|
|
|
|
return original_torch_bmm(input, mat2, out=out)
|
|
|
|
|
|
|
|
|
|
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
|
|
|
|
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
|
|
|
|
if query.dtype != key.dtype:
|
|
|
|
|
key = key.to(dtype=query.dtype)
|
|
|
|
|
if query.dtype != value.dtype:
|
|
|
|
|
value = value.to(dtype=query.dtype)
|
|
|
|
|
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
|
|
|
|
attn_mask = attn_mask.to(dtype=query.dtype)
|
|
|
|
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
|
|
|
|
|
|
|
|
|
# A1111 FP16
|
|
|
|
|
original_functional_group_norm = torch.nn.functional.group_norm
|
|
|
|
|
@wraps(torch.nn.functional.group_norm)
|
|
|
|
|
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
|
|
|
|
if weight is not None and input.dtype != weight.data.dtype:
|
|
|
|
|
input = input.to(dtype=weight.data.dtype)
|
|
|
|
|
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
|
|
|
|
|
|
|
|
|
# A1111 BF16
|
|
|
|
|
original_functional_layer_norm = torch.nn.functional.layer_norm
|
|
|
|
|
@wraps(torch.nn.functional.layer_norm)
|
|
|
|
|
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
|
|
|
|
if weight is not None and input.dtype != weight.data.dtype:
|
|
|
|
|
input = input.to(dtype=weight.data.dtype)
|
|
|
|
|
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1
|
|
|
|
|
|
|
|
|
|
# Training
|
|
|
|
|
original_functional_linear = torch.nn.functional.linear
|
|
|
|
|
@wraps(torch.nn.functional.linear)
|
|
|
|
|
def functional_linear(input, weight, bias=None):
|
|
|
|
|
if input.dtype != weight.data.dtype:
|
|
|
|
|
input = input.to(dtype=weight.data.dtype)
|
|
|
|
|
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
|
|
|
|
|
return original_functional_linear(input, weight, bias=bias)
|
|
|
|
|
|
|
|
|
|
original_functional_conv2d = torch.nn.functional.conv2d
|
|
|
|
|
@wraps(torch.nn.functional.conv2d)
|
|
|
|
|
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|
|
|
|
if input.dtype != weight.data.dtype:
|
|
|
|
|
input = input.to(dtype=weight.data.dtype)
|
|
|
|
|
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
|
|
|
|
|
|
|
|
|
|
# A1111 Embedding BF16
|
|
|
|
|
original_torch_cat = torch.cat
|
|
|
|
|
@wraps(torch.cat)
|
|
|
|
|
def torch_cat(tensor, *args, **kwargs):
|
|
|
|
|
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
|
|
|
|
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
|
|
|
|
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):
|
|
|
|
|
|
|
|
|
|
# SwinIR BF16:
|
|
|
|
|
original_functional_pad = torch.nn.functional.pad
|
|
|
|
|
@wraps(torch.nn.functional.pad)
|
|
|
|
|
def functional_pad(input, pad, mode='constant', value=None):
|
|
|
|
|
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
|
|
|
|
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
|
|
|
|
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
original_torch_tensor = torch.tensor
|
|
|
|
|
def torch_tensor(*args, device=None, **kwargs):
|
|
|
|
|
@wraps(torch.tensor)
|
|
|
|
|
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return original_torch_tensor(*args, device=device, **kwargs)
|
|
|
|
|
device = return_xpu(device)
|
|
|
|
|
if not device_supports_fp64:
|
|
|
|
|
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
|
|
|
|
if dtype == torch.float64:
|
|
|
|
|
dtype = torch.float32
|
|
|
|
|
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
|
|
|
|
dtype = torch.float32
|
|
|
|
|
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_Tensor_to = torch.Tensor.to
|
|
|
|
|
@wraps(torch.Tensor.to)
|
|
|
|
|
def Tensor_to(self, device=None, *args, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
|
|
|
|
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
|
|
|
|
|
return original_Tensor_to(self, device, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_Tensor_cuda = torch.Tensor.cuda
|
|
|
|
|
@wraps(torch.Tensor.cuda)
|
|
|
|
|
def Tensor_cuda(self, device=None, *args, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
|
|
|
|
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
|
|
|
|
|
return original_Tensor_cuda(self, device, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
|
|
|
|
@wraps(torch.UntypedStorage.__init__)
|
|
|
|
|
def UntypedStorage_init(*args, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
|
|
|
|
|
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
|
|
|
|
@wraps(torch.UntypedStorage.cuda)
|
|
|
|
|
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
|
|
|
|
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
|
|
|
|
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_torch_empty = torch.empty
|
|
|
|
|
@wraps(torch.empty)
|
|
|
|
|
def torch_empty(*args, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
|
|
|
|
|
return original_torch_empty(*args, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_torch_randn = torch.randn
|
|
|
|
|
@wraps(torch.randn)
|
|
|
|
|
def torch_randn(*args, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
|
|
|
|
|
return original_torch_randn(*args, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_torch_ones = torch.ones
|
|
|
|
|
@wraps(torch.ones)
|
|
|
|
|
def torch_ones(*args, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
|
|
|
|
|
return original_torch_ones(*args, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_torch_zeros = torch.zeros
|
|
|
|
|
@wraps(torch.zeros)
|
|
|
|
|
def torch_zeros(*args, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
|
|
|
|
|
return original_torch_zeros(*args, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_torch_linspace = torch.linspace
|
|
|
|
|
@wraps(torch.linspace)
|
|
|
|
|
def torch_linspace(*args, device=None, **kwargs):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
|
|
|
|
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
|
|
|
|
|
return original_torch_linspace(*args, device=device, **kwargs)
|
|
|
|
|
|
|
|
|
|
original_torch_Generator = torch.Generator
|
|
|
|
|
@wraps(torch.Generator)
|
|
|
|
|
def torch_Generator(device=None):
|
|
|
|
|
if check_device(device):
|
|
|
|
|
return original_torch_Generator(return_xpu(device))
|
|
|
|
|
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
|
|
|
|
|
return original_torch_Generator(device)
|
|
|
|
|
|
|
|
|
|
original_torch_load = torch.load
|
|
|
|
|
@wraps(torch.load)
|
|
|
|
|
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
|
|
|
|
if check_device(map_location):
|
|
|
|
|
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hijack Functions:
|
|
|
|
|
def ipex_hijacks():
|
|
|
|
|
torch.tensor = torch_tensor
|
|
|
|
|
@@ -232,7 +281,7 @@ def ipex_hijacks():
|
|
|
|
|
torch.backends.cuda.sdp_kernel = return_null_context
|
|
|
|
|
torch.nn.DataParallel = DummyDataParallel
|
|
|
|
|
torch.UntypedStorage.is_cuda = is_cuda
|
|
|
|
|
torch.autocast = ipex_autocast
|
|
|
|
|
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
|
|
|
|
|
|
|
|
|
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
|
|
|
|
torch.nn.functional.group_norm = functional_group_norm
|
|
|
|
|
@@ -244,5 +293,6 @@ def ipex_hijacks():
|
|
|
|
|
|
|
|
|
|
torch.bmm = torch_bmm
|
|
|
|
|
torch.cat = torch_cat
|
|
|
|
|
if not torch.xpu.has_fp64_dtype():
|
|
|
|
|
if not device_supports_fp64:
|
|
|
|
|
torch.from_numpy = from_numpy
|
|
|
|
|
torch.as_tensor = as_tensor
|
|
|
|
|
|