From 6d3e51431be4f207b2ebddc975c6b0a2196576ad Mon Sep 17 00:00:00 2001
From: umisetokikaze
Date: Wed, 11 Mar 2026 22:25:13 +0900
Subject: [PATCH] feat: add CUDA device compatibility validation and
 corresponding tests

---
 library/device_utils.py    | 58 ++++++++++++++++++++++++++++++++++++++
 library/train_util.py      |  3 ++-
 tests/test_device_utils.py | 24 ++++++++++++++++
 3 files changed, 84 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_device_utils.py

diff --git a/library/device_utils.py b/library/device_utils.py
index 2d59b64b..e91ec162 100644
--- a/library/device_utils.py
+++ b/library/device_utils.py
@@ -87,6 +87,64 @@ def get_preferred_device() -> torch.device:
     return device
 
 
+def _normalize_cuda_arch(arch):
+    """Normalize a CUDA arch descriptor to the "sm_XY" string form.
+
+    `arch` is either an entry from torch.cuda.get_arch_list() (e.g. "sm_90",
+    "compute_90") or a (major, minor) tuple from
+    torch.cuda.get_device_capability(). Returns None for anything that is
+    not a concrete "sm_*" target.
+    """
+    if isinstance(arch, str):
+        return arch if arch.startswith("sm_") else None
+    if isinstance(arch, (tuple, list)) and len(arch) >= 2:
+        return f"sm_{int(arch[0])}{int(arch[1])}"
+    return None
+
+
+def validate_cuda_device_compatibility(device=None):
+    """Raise RuntimeError if `device` is a CUDA GPU this PyTorch build cannot run on.
+
+    `device` may be None (defaults to "cuda"), a device string, or a
+    torch.device. Validation is best-effort: non-CUDA devices, builds
+    without CUDA, and any failure while querying torch are accepted silently.
+    """
+    if not HAS_CUDA:
+        return
+
+    if device is None:
+        device = torch.device("cuda")
+    elif isinstance(device, str):
+        device = torch.device(device)
+
+    if device.type != "cuda":
+        return
+
+    # Some torch builds (e.g. CPU-only) do not expose get_arch_list.
+    get_arch_list = getattr(torch.cuda, "get_arch_list", None)
+    if get_arch_list is None:
+        return
+
+    try:
+        supported_arches = sorted(
+            {arch_name for arch_name in (_normalize_cuda_arch(arch) for arch in get_arch_list()) if arch_name is not None}
+        )
+        device_arch = _normalize_cuda_arch(torch.cuda.get_device_capability(device))
+        device_name = torch.cuda.get_device_name(device)
+    except Exception:
+        # Never block startup because the compatibility query itself failed.
+        return
+
+    if supported_arches and device_arch is not None and device_arch not in supported_arches:
+        cuda_version = getattr(torch.version, "cuda", None)
+        cuda_suffix = f" with CUDA {cuda_version}" if cuda_version else ""
+        supported = ", ".join(supported_arches)
+        raise RuntimeError(
+            f"CUDA device '{device_name}' reports {device_arch}, but this PyTorch build{cuda_suffix} only supports {supported}. "
+            + "Install a PyTorch build that includes kernels for this GPU from https://pytorch.org/get-started/locally/ or build PyTorch from source."
+        )
+
+
 def init_ipex():
     """
     Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
diff --git a/library/train_util.py b/library/train_util.py
index efc51fb1..c8b45487 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -30,7 +30,7 @@ from tqdm import tqdm
 from packaging.version import Version
 
 import torch
-from library.device_utils import init_ipex, clean_memory_on_device
+from library.device_utils import init_ipex, clean_memory_on_device, validate_cuda_device_compatibility
 from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
 
 init_ipex()
@@ -5500,6 +5500,7 @@ def prepare_accelerator(args: argparse.Namespace):
         dynamo_backend=dynamo_backend,
         deepspeed_plugin=deepspeed_plugin,
     )
+    validate_cuda_device_compatibility(accelerator.device)
     print("accelerator device:", accelerator.device)
     return accelerator
 
diff --git a/tests/test_device_utils.py b/tests/test_device_utils.py
new file mode 100644
index 00000000..77d44b73
--- /dev/null
+++ b/tests/test_device_utils.py
@@ -0,0 +1,24 @@
+import pytest
+import torch
+
+from library import device_utils
+
+
+def test_validate_cuda_device_compatibility_raises_for_unsupported_arch(monkeypatch):
+    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
+    monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"])
+    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (12, 0))
+    monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Blackwell Test GPU")
+    monkeypatch.setattr(torch.version, "cuda", "12.4", raising=False)
+
+    with pytest.raises(RuntimeError, match="sm_120"):
+        device_utils.validate_cuda_device_compatibility("cuda")
+
+
+def test_validate_cuda_device_compatibility_allows_supported_arch(monkeypatch):
+    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
+    monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"])
+    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (9, 0))
+    monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Hopper Test GPU")
+
+    device_utils.validate_cuda_device_compatibility("cuda")