remove unused weight swapping functions from utils.py

This commit is contained in:
Kohya S
2024-11-05 23:27:41 +09:00
parent 81c0c965a2
commit aab943cea3

View File

@@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False):
# region PyTorch utils
# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
# # cpu_tensor = module_to_cuda.weight.data
# # cuda_tensor = module_to_cpu.weight.data
# # assert cuda_tensor.device.type == "cuda"
# # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
# # torch.cuda.current_stream().synchronize()
# # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
# # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
# cuda_tensor_view = module_to_cpu.weight.data
# cpu_tensor_view = module_to_cuda.weight.data
# module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
# module_to_cuda.weight.data = cuda_tensor_view
# module_to_cuda.weight.data.copy_(cpu_tensor_view)
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
@@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
events = []
with torch.cuda.stream(stream_to_cpu):
# cuda to offload
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
with torch.cuda.stream(stream_to_cuda):
# cpu to cuda
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
event.synchronize()
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
weight_swap_jobs, offloaded_weights
):
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
stream_to_cpu = torch.cuda.Stream()
stream_to_cuda = torch.cuda.Stream()
# cuda to offload
events = []
with torch.cuda.stream(stream_to_cpu):
offloaded_weights = []
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream_to_cpu)
offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
event = torch.cuda.Event()
event.record(stream=stream_to_cpu)
events.append(event)
# cpu to cuda
with torch.cuda.stream(stream_to_cuda):
for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
weight_swap_jobs, events, offloaded_weights
):
event.synchronize()
cuda_data_view.record_stream(stream_to_cuda)
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
module_to_cpu.weight.data = offloaded_weight
stream_to_cuda.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
# torch.cuda.current_stream().wait_stream(stream_to_cuda)
# for job in weight_swap_jobs:
# job[2].record_stream(torch.cuda.current_stream()) # record the ownership of the tensor
def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
# one of the modules must have the tensor to offload
module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.offloaded_weight.pin_memory()
offloaded_weight = (
module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
)
assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
weight_swap_jobs.append(
(module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
)
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to offload
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.record_stream(stream)
offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# offload to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
module_to_cpu.weight.data = offloaded_weight
offloaded_weight = cpu_data_view
module_to_cpu.offloaded_weight = offloaded_weight
module_to_cuda.offloaded_weight = offloaded_weight
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
# one of the modules must have the tensor to cache
module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
module_to_cpu.__cached_cpu_weight.pin_memory()
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)
torch.cuda.current_stream().synchronize() # wait for the copy from cache to cpu to finish
torch.cuda.empty_cache()
# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# assert layer_to_cpu.__class__ == layer_to_cuda.__class__
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
# weight_on_cuda = module_to_cpu.weight
# weight_on_cpu = module_to_cuda.weight
# cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
# event = torch.cuda.current_stream().record_event()
# event.synchronize()
# weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
# weight_on_cpu.data = cuda_to_cpu_data
# weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad
# module_to_cpu.weight = weight_on_cpu
# module_to_cuda.weight = weight_on_cuda
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None: