mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
feat: add CUDA device compatibility validation and corresponding tests
This commit is contained in:
@@ -87,6 +87,49 @@ def get_preferred_device() -> torch.device:
|
||||
return device
|
||||
|
||||
|
||||
|
||||
def _normalize_cuda_arch(arch) -> Optional[str]:
|
||||
if isinstance(arch, str):
|
||||
return arch if arch.startswith("sm_") else None
|
||||
if isinstance(arch, (tuple, list)) and len(arch) >= 2:
|
||||
return f"sm_{int(arch[0])}{int(arch[1])}"
|
||||
return None
|
||||
|
||||
|
||||
def validate_cuda_device_compatibility(device: Optional[Union[str, torch.device]] = None):
    """Raise if the selected CUDA GPU is not covered by this PyTorch build.

    Compares the device's compute capability against the kernel architecture
    list compiled into the installed torch wheel. Silently returns when CUDA
    is unavailable, the device is not a CUDA device, or the arch information
    cannot be queried.

    Args:
        device: Target device as a ``torch.device``, a device string such as
            ``"cuda:0"``, or ``None`` (which defaults to ``"cuda"``).

    Raises:
        RuntimeError: when the GPU's ``sm_XY`` arch is absent from the
            build's supported arch list.
    """
    if not HAS_CUDA:
        return

    # Normalize the argument to a torch.device and bail out for non-CUDA targets.
    if device is None:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    if device.type != "cuda":
        return

    # Older torch releases may lack get_arch_list; skip validation in that case.
    get_arch_list = getattr(torch.cuda, "get_arch_list", None)
    if get_arch_list is None:
        return

    # Best-effort query: any failure here means we cannot validate,
    # not that the device is unsupported.
    try:
        arch_names = (_normalize_cuda_arch(a) for a in get_arch_list())
        supported_arches = sorted({name for name in arch_names if name is not None})
        device_arch = _normalize_cuda_arch(torch.cuda.get_device_capability(device))
        device_name = torch.cuda.get_device_name(device)
    except Exception:
        return

    # Supported build, unknown arch, or empty arch list: nothing to report.
    if not supported_arches or device_arch is None or device_arch in supported_arches:
        return

    cuda_version = getattr(torch.version, "cuda", None)
    cuda_suffix = f" with CUDA {cuda_version}" if cuda_version else ""
    supported = ", ".join(supported_arches)
    raise RuntimeError(
        f"CUDA device '{device_name}' reports {device_arch}, but this PyTorch build{cuda_suffix} only supports {supported}. "
        + "Install a PyTorch build that includes kernels for this GPU from https://pytorch.org/get-started/locally/ or build PyTorch from source."
    )
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
@@ -30,7 +30,7 @@ from tqdm import tqdm
|
||||
from packaging.version import Version
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.device_utils import init_ipex, clean_memory_on_device, validate_cuda_device_compatibility
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
|
||||
|
||||
init_ipex()
|
||||
@@ -5500,6 +5500,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
)
|
||||
validate_cuda_device_compatibility(accelerator.device)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
|
||||
|
||||
24
tests/test_device_utils.py
Normal file
24
tests/test_device_utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library import device_utils
|
||||
|
||||
|
||||
def test_validate_cuda_device_compatibility_raises_for_unsupported_arch(monkeypatch):
    """A GPU arch missing from the build's kernel list must raise RuntimeError."""
    # Pretend CUDA exists, the wheel ships only Ampere/Hopper kernels, and
    # the attached GPU reports Blackwell (compute capability 12.0 -> sm_120).
    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
    fake_cuda_attrs = {
        "get_arch_list": lambda: ["sm_80", "sm_90"],
        "get_device_capability": lambda device=None: (12, 0),
        "get_device_name": lambda device=None: "Blackwell Test GPU",
    }
    for attr_name, fake in fake_cuda_attrs.items():
        monkeypatch.setattr(torch.cuda, attr_name, fake)
    monkeypatch.setattr(torch.version, "cuda", "12.4", raising=False)

    # The error message should name the unsupported arch.
    with pytest.raises(RuntimeError, match="sm_120"):
        device_utils.validate_cuda_device_compatibility("cuda")
|
||||
|
||||
|
||||
def test_validate_cuda_device_compatibility_allows_supported_arch(monkeypatch):
    """A GPU whose arch appears in the build's kernel list passes silently."""
    # Hopper (compute capability 9.0 -> sm_90) is in the fake arch list,
    # so validation must complete without raising.
    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
    fake_cuda_attrs = {
        "get_arch_list": lambda: ["sm_80", "sm_90"],
        "get_device_capability": lambda device=None: (9, 0),
        "get_device_name": lambda device=None: "Hopper Test GPU",
    }
    for attr_name, fake in fake_cuda_attrs.items():
        monkeypatch.setattr(torch.cuda, attr_name, fake)

    device_utils.validate_cuda_device_compatibility("cuda")
|
||||
Reference in New Issue
Block a user