mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
set static graph flag when DDP ref #1363
This commit is contained in:
@@ -289,6 +289,9 @@ def train(args):
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if isinstance(unet, DDP):
|
||||
unet._set_static_graph() # avoid error for multiple use of the parameter
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user