mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
remove unused weight swapping functions from utils.py
This commit is contained in:
185
library/utils.py
185
library/utils.py
@@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
|
||||
# region PyTorch utils
|
||||
|
||||
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
|
||||
# # cpu_tensor = module_to_cuda.weight.data
|
||||
# # cuda_tensor = module_to_cpu.weight.data
|
||||
# # assert cuda_tensor.device.type == "cuda"
|
||||
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
|
||||
# # torch.cuda.current_stream().synchronize()
|
||||
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
|
||||
# # torch.cuda.current_stream().synchronize()
|
||||
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
|
||||
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
|
||||
# cuda_tensor_view = module_to_cpu.weight.data
|
||||
# cpu_tensor_view = module_to_cuda.weight.data
|
||||
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
|
||||
# module_to_cuda.weight.data = cuda_tensor_view
|
||||
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
|
||||
|
||||
|
||||
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
@@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
|
||||
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||
|
||||
stream_to_cpu = torch.cuda.Stream()
|
||||
stream_to_cuda = torch.cuda.Stream()
|
||||
|
||||
events = []
|
||||
with torch.cuda.stream(stream_to_cpu):
|
||||
# cuda to offload
|
||||
offloaded_weights = []
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
|
||||
event = torch.cuda.Event()
|
||||
event.record(stream=stream_to_cpu)
|
||||
events.append(event)
|
||||
|
||||
with torch.cuda.stream(stream_to_cuda):
|
||||
# cpu to cuda
|
||||
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
|
||||
event.synchronize()
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
# offload to cpu
|
||||
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
|
||||
weight_swap_jobs, offloaded_weights
|
||||
):
|
||||
module_to_cpu.weight.data = offloaded_weight
|
||||
|
||||
stream_to_cuda.synchronize()
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
|
||||
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||
|
||||
stream_to_cpu = torch.cuda.Stream()
|
||||
stream_to_cuda = torch.cuda.Stream()
|
||||
|
||||
# cuda to offload
|
||||
events = []
|
||||
with torch.cuda.stream(stream_to_cpu):
|
||||
offloaded_weights = []
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream_to_cpu)
|
||||
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
|
||||
|
||||
event = torch.cuda.Event()
|
||||
event.record(stream=stream_to_cpu)
|
||||
events.append(event)
|
||||
|
||||
# cpu to cuda
|
||||
with torch.cuda.stream(stream_to_cuda):
|
||||
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
|
||||
weight_swap_jobs, events, offloaded_weights
|
||||
):
|
||||
event.synchronize()
|
||||
cuda_data_view.record_stream(stream_to_cuda)
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
module_to_cpu.weight.data = offloaded_weight
|
||||
|
||||
stream_to_cuda.synchronize()
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
# torch.cuda.current_stream().wait_stream(stream_to_cuda)
|
||||
# for job in weight_swap_jobs:
|
||||
# job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor
|
||||
|
||||
|
||||
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
|
||||
# one of the modules must have the tensor to offload
|
||||
module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
|
||||
module_to_cpu.offloaded_weight.pin_memory()
|
||||
offloaded_weight = (
|
||||
module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
|
||||
)
|
||||
assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
|
||||
weight_swap_jobs.append(
|
||||
(module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
|
||||
)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to offload
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
|
||||
|
||||
stream.synchronize()
|
||||
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
# offload to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = offloaded_weight
|
||||
offloaded_weight = cpu_data_view
|
||||
module_to_cpu.offloaded_weight = offloaded_weight
|
||||
module_to_cuda.offloaded_weight = offloaded_weight
|
||||
|
||||
stream.synchronize()
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
|
||||
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
|
||||
weight_swap_jobs = []
|
||||
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
|
||||
# one of the modules must have the tensor to cache
|
||||
module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
|
||||
module_to_cpu.__cached_cpu_weight.pin_memory()
|
||||
|
||||
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
|
||||
|
||||
for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
|
||||
module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
|
||||
|
||||
torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
|
||||
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
|
||||
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
|
||||
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
|
||||
# weight_on_cuda = module_to_cpu.weight
|
||||
# weight_on_cpu = module_to_cuda.weight
|
||||
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
|
||||
# event = torch.cuda.current_stream().record_event()
|
||||
# event.synchronize()
|
||||
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
|
||||
# weight_on_cpu.data = cuda_to_cpu_data
|
||||
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
|
||||
|
||||
# module_to_cpu.weight = weight_on_cpu
|
||||
# module_to_cuda.weight = weight_on_cuda
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
for module in layer.modules():
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
|
||||
Reference in New Issue
Block a user