feat: add CUDA device compatibility validation and corresponding tests

This commit is contained in:
umisetokikaze
2026-03-11 22:25:13 +09:00
parent c42ad076c6
commit 6d3e51431b
3 changed files with 69 additions and 1 deletions

View File

@@ -87,6 +87,49 @@ def get_preferred_device() -> torch.device:
return device
def _normalize_cuda_arch(arch) -> Optional[str]:
if isinstance(arch, str):
return arch if arch.startswith("sm_") else None
if isinstance(arch, (tuple, list)) and len(arch) >= 2:
return f"sm_{int(arch[0])}{int(arch[1])}"
return None
def validate_cuda_device_compatibility(device: Optional[Union[str, torch.device]] = None):
    """Raise ``RuntimeError`` when *device* is a CUDA GPU whose compute
    architecture has no compiled kernels in this PyTorch build.

    The check is strictly best-effort: it silently returns on non-CUDA
    devices, CPU-only installs, torch versions that lack ``get_arch_list``,
    and whenever probing the device itself fails.
    """
    if not HAS_CUDA:
        return

    # Normalize the argument to a torch.device, defaulting to "cuda".
    if device is None:
        target = torch.device("cuda")
    elif isinstance(device, str):
        target = torch.device(device)
    else:
        target = device
    if target.type != "cuda":
        return

    arch_list_fn = getattr(torch.cuda, "get_arch_list", None)
    if arch_list_fn is None:
        # Older torch releases do not expose this introspection API.
        return

    try:
        normalized = (_normalize_cuda_arch(entry) for entry in arch_list_fn())
        build_arches = sorted({name for name in normalized if name is not None})
        gpu_arch = _normalize_cuda_arch(torch.cuda.get_device_capability(target))
        gpu_name = torch.cuda.get_device_name(target)
    except Exception:
        # Never let the compatibility probe itself take down the caller.
        return

    if not build_arches or gpu_arch is None or gpu_arch in build_arches:
        return

    cuda_version = getattr(torch.version, "cuda", None)
    cuda_suffix = f" with CUDA {cuda_version}" if cuda_version else ""
    supported = ", ".join(build_arches)
    raise RuntimeError(
        f"CUDA device '{gpu_name}' reports {gpu_arch}, but this PyTorch build{cuda_suffix} only supports {supported}. "
        + "Install a PyTorch build that includes kernels for this GPU from https://pytorch.org/get-started/locally/ or build PyTorch from source."
    )
def init_ipex():
"""
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.

View File

@@ -30,7 +30,7 @@ from tqdm import tqdm
from packaging.version import Version
import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import init_ipex, clean_memory_on_device, validate_cuda_device_compatibility
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
init_ipex()
@@ -5500,6 +5500,7 @@ def prepare_accelerator(args: argparse.Namespace):
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
validate_cuda_device_compatibility(accelerator.device)
print("accelerator device:", accelerator.device)
return accelerator

View File

@@ -0,0 +1,24 @@
import pytest
import torch
from library import device_utils
def test_validate_cuda_device_compatibility_raises_for_unsupported_arch(monkeypatch):
    """The check must raise when the GPU's arch is missing from the build's kernel list.

    All CUDA calls are monkeypatched so the test runs on any machine,
    with or without a GPU.
    """
    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
    # raising=False: very old torch builds do not expose get_arch_list at all,
    # and monkeypatch.setattr would otherwise fail with AttributeError.
    monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"], raising=False)
    # Pretend the attached GPU is newer (sm_120) than anything compiled in.
    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (12, 0))
    monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Blackwell Test GPU")
    monkeypatch.setattr(torch.version, "cuda", "12.4", raising=False)
    with pytest.raises(RuntimeError, match="sm_120"):
        device_utils.validate_cuda_device_compatibility("cuda")
def test_validate_cuda_device_compatibility_allows_supported_arch(monkeypatch):
    """No exception when the GPU arch is one of the compiled-in kernels."""
    monkeypatch.setattr(device_utils, "HAS_CUDA", True)
    # raising=False: very old torch builds do not expose get_arch_list at all,
    # and monkeypatch.setattr would otherwise fail with AttributeError.
    monkeypatch.setattr(torch.cuda, "get_arch_list", lambda: ["sm_80", "sm_90"], raising=False)
    monkeypatch.setattr(torch.cuda, "get_device_capability", lambda device=None: (9, 0))
    monkeypatch.setattr(torch.cuda, "get_device_name", lambda device=None: "Hopper Test GPU")
    # Must complete without raising.
    device_utils.validate_cuda_device_compatibility("cuda")