apply offloading method runable for all trainer

This commit is contained in:
BootsofLagrangian
2024-02-05 22:42:06 +09:00
parent 3970bf4080
commit 7d2a9268b9
3 changed files with 15 additions and 0 deletions

View File

@@ -251,6 +251,11 @@ def train(args):
if args.deepspeed:
# wrapping model
import deepspeed
if args.offload_optimizer_device is not None:
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
deepspeed.ops.op_builder.CPUAdamBuilder().load()
accelerator.print('[DeepSpeed] building cpu_adam done.')
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder) -> None:
super().__init__()