Set static graph flag when the U-Net is wrapped in DDP (ref #1363)

This commit is contained in:
Kohya S
2024-06-09 19:26:09 +09:00
parent e5bab69e3a
commit 58fb64819a

View File

@@ -289,6 +289,9 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else: