Fix crash in DDP training (closes #1751)

This commit is contained in:
Kohya S
2024-11-02 15:32:16 +09:00
parent e0db59695f
commit 5e32ee26a1

View File

@@ -838,11 +838,31 @@ def train(args):
accelerator.log({}, step=0)
# show model device and dtype
logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None")
logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None")
logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None")
logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None")
logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None")
logger.info(
f"mmdit device: {accelerator.unwrap_model(mmdit).device}, dtype: {accelerator.unwrap_model(mmdit).dtype}"
if mmdit
else "mmdit is None"
)
logger.info(
f"clip_l device: {accelerator.unwrap_model(clip_l).device}, dtype: {accelerator.unwrap_model(clip_l).dtype}"
if clip_l
else "clip_l is None"
)
logger.info(
f"clip_g device: {accelerator.unwrap_model(clip_g).device}, dtype: {accelerator.unwrap_model(clip_g).dtype}"
if clip_g
else "clip_g is None"
)
logger.info(
f"t5xxl device: {accelerator.unwrap_model(t5xxl).device}, dtype: {accelerator.unwrap_model(t5xxl).dtype}"
if t5xxl
else "t5xxl is None"
)
logger.info(
f"vae device: {accelerator.unwrap_model(vae).device}, dtype: {accelerator.unwrap_model(vae).dtype}"
if vae is not None
else "vae is None"
)
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0