Compare commits

...

2 Commits

Author SHA1 Message Date
Kohya S
0e8b5d4af9 Merge branch 'sd3' into no-grad-block-swap 2025-05-24 18:50:25 +09:00
kohya-ss
c898e4e536 fix: optimize weight device swapping with no_grad context 2025-03-17 21:28:02 +09:00

View File

@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
with torch.no_grad():
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
@@ -75,14 +76,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
synchronize_device(device)
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
synchronize_device(device)
def weighs_to_device(layer: nn.Module, device: torch.device):